Skip to content

Commit 5fc73d6

Browse files
committed
Merge branch 'master' into dev-printf-opt
2 parents d0add33 + cc98f8d commit 5fc73d6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+4460
-4296
lines changed

.github/labeler.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ ggml:
7676
- changed-files:
7777
- any-glob-to-any-file:
7878
- ggml/**
79+
model:
80+
- changed-files:
81+
- any-glob-to-any-file:
82+
- src/models/**
7983
nix:
8084
- changed-files:
8185
- any-glob-to-any-file:

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ def fn(_m, input, output):
138138
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
139139
)
140140

141+
142+
print("Loading model and tokenizer using AutoTokenizer:", model_path)
143+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
141144
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
142145

143146
print("Model type: ", config.model_type)
@@ -147,10 +150,6 @@ def fn(_m, input, output):
147150
print("BOS token id: ", config.bos_token_id)
148151
print("EOS token id: ", config.eos_token_id)
149152

150-
print("Loading model and tokenizer using AutoTokenizer:", model_path)
151-
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
152-
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
153-
154153
if unreleased_model_name:
155154
model_name_lower = unreleased_model_name.lower()
156155
unreleased_module_path = (
@@ -171,7 +170,7 @@ def fn(_m, input, output):
171170
exit(1)
172171
else:
173172
model = AutoModelForCausalLM.from_pretrained(
174-
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
173+
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
175174
)
176175

177176
for name, module in model.named_modules():

ggml/include/ggml.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2108,6 +2108,7 @@ extern "C" {
21082108
enum ggml_scale_mode {
21092109
GGML_SCALE_MODE_NEAREST = 0,
21102110
GGML_SCALE_MODE_BILINEAR = 1,
2111+
GGML_SCALE_MODE_BICUBIC = 2,
21112112

21122113
GGML_SCALE_MODE_COUNT
21132114
};

ggml/src/ggml-cpu/ops.cpp

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7507,10 +7507,17 @@ static void ggml_compute_forward_upscale_f32(
75077507
float sf1 = (float)ne1/src0->ne[1];
75087508
float sf2 = (float)ne2/src0->ne[2];
75097509
float sf3 = (float)ne3/src0->ne[3];
7510+
float pixel_offset = 0.5f;
75107511

75117512
const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
75127513
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
75137514

7515+
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7516+
pixel_offset = 0.0f;
7517+
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7518+
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7519+
}
7520+
75147521
if (mode == GGML_SCALE_MODE_NEAREST) {
75157522
for (int64_t i3 = 0; i3 < ne3; i3++) {
75167523
const int64_t i03 = i3 / sf3;
@@ -7530,13 +7537,6 @@ static void ggml_compute_forward_upscale_f32(
75307537
}
75317538
}
75327539
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
7533-
float pixel_offset = 0.5f;
7534-
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7535-
pixel_offset = 0.0f;
7536-
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7537-
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7538-
}
7539-
75407540
for (int64_t i3 = 0; i3 < ne3; i3++) {
75417541
const int64_t i03 = i3 / sf3;
75427542
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7571,6 +7571,51 @@ static void ggml_compute_forward_upscale_f32(
75717571

75727572
const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
75737573

7574+
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7575+
*y_dst = val;
7576+
}
7577+
}
7578+
}
7579+
}
7580+
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
7581+
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7582+
const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7583+
auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7584+
auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7585+
auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7586+
const float w0 = weight2(x + 1);
7587+
const float w1 = weight1(x + 0);
7588+
const float w2 = weight1(1 - x);
7589+
const float w3 = weight2(2 - x);
7590+
return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7591+
};
7592+
7593+
for (int64_t i3 = 0; i3 < ne3; i3++) {
7594+
const int64_t i03 = i3 / sf3;
7595+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7596+
const int64_t i02 = i2 / sf2;
7597+
for (int64_t i1 = 0; i1 < ne1; i1++) {
7598+
const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7599+
const int64_t y0 = (int64_t)floorf(y);
7600+
const float dy = y - (float)y0;
7601+
7602+
for (int64_t i0 = 0; i0 < ne0; i0++) {
7603+
const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7604+
const int64_t x0 = (int64_t)floorf(x);
7605+
const float dx = x - (float)x0;
7606+
7607+
auto p = [=](int64_t x_off, int64_t y_off) -> float {
7608+
int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7609+
int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7610+
return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7611+
};
7612+
7613+
const float val = bicubic(
7614+
bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7615+
bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7616+
bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7617+
bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7618+
75747619
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
75757620
*y_dst = val;
75767621
}

ggml/src/ggml-cpu/repack.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,10 +1678,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
16781678
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
16791679
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
16801680

1681+
// Ensure minimum chunk size to avoid alignment issues with high thread counts
1682+
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
1683+
const int64_t min_chunk_size = NB_COLS;
1684+
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
1685+
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
1686+
}
1687+
16811688
if (nth == 1 || nchunk < nth || disable_chunking) {
16821689
nchunk = nth;
16831690
}
16841691

1692+
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
1693+
// This prevents creating too many tiny chunks that could overlap after alignment
1694+
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
1695+
if (nchunk > max_nchunk) {
1696+
nchunk = max_nchunk;
1697+
}
1698+
16851699
if (ith == 0) {
16861700
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
16871701
ggml_threadpool_chunk_set(params->threadpool, nth);
@@ -1695,8 +1709,15 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
16951709
while (current_chunk < nchunk) {
16961710
int64_t src0_start = (current_chunk * ne01) / nchunk;
16971711
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
1712+
1713+
// Align boundaries to NB_COLS - round up to ensure all data is included
1714+
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
16981715
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
16991716
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
1717+
if (src0_end > ne01) {
1718+
src0_end = ne01;
1719+
}
1720+
17001721
if (src0_start >= src0_end) {
17011722
break;
17021723
}
@@ -1808,8 +1829,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
18081829
int64_t src0_cur_start = (ith * ne01) / nth;
18091830
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
18101831

1832+
// Align boundaries to NB_COLS - round up to ensure all data is included
18111833
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
18121834
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
1835+
if (src0_cur_end > ne01) {
1836+
src0_cur_end = ne01;
1837+
}
18131838

18141839
if (src0_cur_start >= src0_cur_end) {
18151840
return;

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
1414
GGML_ASSERT(V->ne[0] == K->ne[0]);
1515
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
1616
} break;
17+
case 72: {
18+
GGML_ASSERT(V->ne[0] == K->ne[0]);
19+
ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
20+
} break;
1721
case 80: {
1822
GGML_ASSERT(V->ne[0] == K->ne[0]);
1923
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
// nbatch_K == number of K columns to load in parallel for KQ calculation
77

88
// TODO optimize kernel parameters for FP16 NVIDIA (P100)
9-
// TODO optimize kernel parameters for head sizes 40, 80, 96, 112
9+
// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
1010

1111
// The ROCm compiler cannot handle templating in __launch_bounds__.
1212
// As a workaround, define a macro to package the kernel parameters as uint32_t:
@@ -32,6 +32,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
3232
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
3333
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
3434

35+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
36+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
37+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
38+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
39+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
40+
3541
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
3642
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
3743
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
@@ -80,6 +86,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
8086
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
8187
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
8288

89+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
90+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
91+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
92+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
93+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
94+
8395
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
8496
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
8597
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -130,6 +142,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
130142
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
131143
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
132144

145+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
146+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
147+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
148+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
149+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
150+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
151+
133152
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
134153
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
135154
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -185,6 +204,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
185204
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
186205
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
187206

207+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
208+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
209+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
210+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
211+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
212+
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
213+
188214
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
189215
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
190216
GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
@@ -723,7 +749,7 @@ static __global__ void flash_attn_tile(
723749

724750
if (
725751
#ifdef GGML_USE_WMMA_FATTN
726-
(ncols2 != 1 && DV != 40 && DV != 512) ||
752+
(ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
727753
#endif // GGML_USE_WMMA_FATTN
728754
(use_logit_softcap && !(DV == 128 || DV == 256))
729755
) {
@@ -1198,6 +1224,7 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
11981224

11991225
extern DECL_FATTN_TILE_CASE( 40, 40);
12001226
extern DECL_FATTN_TILE_CASE( 64, 64);
1227+
extern DECL_FATTN_TILE_CASE( 72, 72);
12011228
extern DECL_FATTN_TILE_CASE( 80, 80);
12021229
extern DECL_FATTN_TILE_CASE( 96, 96);
12031230
extern DECL_FATTN_TILE_CASE(112, 112);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
223223
switch (K->ne[0]) {
224224
case 40:
225225
case 64:
226+
case 72:
226227
case 80:
227228
case 96:
228229
case 128:
@@ -275,7 +276,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
275276
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
276277

277278
// If Turing tensor cores available, use them:
278-
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
279+
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
279280
if (can_use_vector_kernel) {
280281
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
281282
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@@ -301,7 +302,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
301302
}
302303

303304
// Use the WMMA kernel if possible:
304-
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
305+
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
305306
if (can_use_vector_kernel && Q->ne[1] <= 2) {
306307
return BEST_FATTN_KERNEL_VEC;
307308
}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,6 +2115,14 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
21152115
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
21162116
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
21172117

2118+
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2119+
ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2120+
2121+
//TODO: add support for fusion for split buffers
2122+
if (split) {
2123+
return false;
2124+
}
2125+
21182126
//we only support fusion for ncols_dst = 1
21192127
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
21202128
return false;
@@ -2154,6 +2162,15 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
21542162
return false;
21552163
}
21562164

2165+
2166+
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2167+
ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2168+
2169+
//TODO: add support for fusion for split buffers
2170+
if (split) {
2171+
return false;
2172+
}
2173+
21572174
return use_mul_mat_vec_q;
21582175
}
21592176

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-tile.cuh"
4+
5+
DECL_FATTN_TILE_CASE(72, 72);

0 commit comments

Comments
 (0)