
Commit 99c30f7

Merge branch 'ggml-org:master' into glm45
2 parents da39c79 + 5c0eb5e commit 99c30f7

13 files changed, +144 -66 lines
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+name: Check Pre-Tokenizer Hashes
+
+on:
+  push:
+    paths:
+      - 'convert_hf_to_gguf.py'
+      - 'convert_hf_to_gguf_update.py'
+  pull_request:
+    paths:
+      - 'convert_hf_to_gguf.py'
+      - 'convert_hf_to_gguf_update.py'
+
+jobs:
+  pre-tokenizer-hashes:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install Python dependencies
+        run: |
+          python3 -m venv .venv
+          .venv/bin/pip install -r requirements/requirements-convert_hf_to_gguf_update.txt
+
+      - name: Update pre-tokenizer hashes
+        run: |
+          cp convert_hf_to_gguf.py /tmp
+          .venv/bin/python convert_hf_to_gguf_update.py --check-missing
+
+      - name: Check if committed pre-tokenizer hashes match the generated version
+        run: |
+          if ! diff -q convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py; then
+              echo "Model pre-tokenizer hashes (in convert_hf_to_gguf.py) do not match generated hashes (from convert_hf_to_gguf_update.py)."
+              echo "To fix: run ./convert_hf_to_gguf_update.py and commit the updated convert_hf_to_gguf.py along with your changes"
+              echo "Differences found:"
+              diff convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py || true
+              exit 1
+          fi
+          echo "Model pre-tokenizer hashes are up to date."

convert_hf_to_gguf.py

Lines changed: 3 additions & 3 deletions
@@ -702,6 +702,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
             # ref: https://huggingface.co/moonshotai/Kimi-K2-Base
             res = "kimi-k2"
+        if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
+            # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
+            res = "qwen2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -849,9 +852,6 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
             # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B
             res = "exaone4"
-        if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
-            # ref: https://huggingface.co/Qwen/Qwen3-Embedding-8B
-            res = "qwen2"

         if res is None:
             logger.warning("\n")
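For context: each chkhsh above is the SHA-256 digest of the token IDs the model's tokenizer produces for a fixed probe string (CHK_TXT in convert_hf_to_gguf_update.py), so an identical digest means identical pre-tokenizer behavior. That is why the Qwen3-Embedding repos resolve to the existing "qwen2" handling. A minimal sketch of the hashing scheme (the probe text is a stand-in; only the real CHK_TXT reproduces the committed digests):

from hashlib import sha256
from transformers import AutoTokenizer  # assumes transformers is installed

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
chk_txt = "placeholder probe text"  # stand-in for the real CHK_TXT
ids = tokenizer.encode(chk_txt)
chkhsh = sha256(str(ids).encode()).hexdigest()  # hash of the stringified token-ID list
print(chkhsh)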

convert_hf_to_gguf_update.py

Lines changed: 16 additions & 6 deletions
@@ -59,6 +59,10 @@ class TOKENIZER_TYPE(IntEnum):
     "--full", action="store_true",
     help="download full list of models - make sure you have access to all of them",
 )
+parser.add_argument(
+    "--check-missing", action="store_true",
+    help="only check for missing pre-tokenizer hashes",
+)
 parser.add_argument(
     "hf_token",
     help="optional HF token",
@@ -70,6 +74,10 @@ class TOKENIZER_TYPE(IntEnum):
 if hf_token is None:
     logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token")

+if args.check_missing and args.full:
+    logger.warning("Downloading full list of models requested, ignoring --check-missing!")
+    args.check_missing = False
+
 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
 # will be updated with time - contributions welcome
 CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
@@ -147,6 +155,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
+    {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
 ]


@@ -221,12 +230,13 @@ def get_existing_models(convert_py):
 all_models = models.copy()
 models = [model for model in all_models if model["name"] not in existing_models]

-logging.info(f"Downloading {len(models)} models...")
-for model in models:
-    try:
-        download_model(model)
-    except Exception as e:
-        logger.error(f"Failed to download model {model['name']}. Error: {e}")
+if not args.check_missing:
+    logging.info(f"Downloading {len(models)} models...")
+    for model in models:
+        try:
+            download_model(model)
+        except Exception as e:
+            logger.error(f"Failed to download model {model['name']}. Error: {e}")

 # generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
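The new --check-missing path leans on the existing filtering: get_existing_models() reads convert_hf_to_gguf.py, and any model already handled there is dropped from the work list, so with the flag set nothing is downloaded and only the leftovers are reported. A rough sketch of that filtering idea (the regex is a guess at the mechanism, not the script's actual code):

import re

def existing_model_names(convert_py_text: str) -> set[str]:
    # pre-tokenizer entries in convert_hf_to_gguf.py look like: res = "qwen2"
    return set(re.findall(r'res = "([^"]+)"', convert_py_text))

with open("convert_hf_to_gguf.py") as f:
    existing = existing_model_names(f.read())
models = [{"name": "qwen2"}, {"name": "new-tokenizer"}]  # stand-in entries
missing = [m for m in models if m["name"] not in existing]
print(missing)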

ggml/src/ggml-cuda/fattn.cu

Lines changed: 3 additions & 2 deletions
@@ -315,8 +315,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
-        (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
+    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
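Read as a dispatch rule: on pre-Ada GPUs the batch-size-1 mma path is taken whenever it applies, while on Ada Lovelace and newer it is now also taken for a single sequence when the GQA ratio exceeds 4 and the KV cache is long (K->ne[1] >= 8192), rather than only when Q->ne[3] > 1 as before. A hedged restatement of the predicate in Python (a reading aid mirroring the C++ above, not repo code; in ggml's layout Q->ne[2] is the head count and Q->ne[3] the sequence count):

def mma_faster_for_bs1(new_mma_avail: bool, gqa_opt_applies: bool,
                       needs_conversion: bool, is_ada_or_newer: bool,
                       n_seq: int, n_head: int, n_head_kv: int, kv_len: int) -> bool:
    # mma_faster_for_rtx4000 in the hunk above
    faster_on_new_gpus = n_seq > 1 or (n_head > 4 * n_head_kv and kv_len >= 8192)
    return (new_mma_avail and gqa_opt_applies and not needs_conversion
            and (not is_ada_or_newer or faster_on_new_gpus))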

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 4 deletions
@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
     ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());

+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // Handle src0
     src0_ptr = (const cuda_t *) src0->data;

@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
+
+        is_src1_cont_2 = true;
     }

     // Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;

-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+        const int64_t smb = ne12 == 1 ? s13 : s12;
+
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
+                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
+                beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
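The new sma/smb bookkeeping fixes the strided-batched case for permuted inputs: with a [0, 2, 1, 3] permutation and ne02 == 1, the dim-2 stride nb02 belongs to a size-1 dimension that is never stepped, so the distance between consecutive matrices must be read from dim 3 instead. A small numpy analogy (not repo code; numpy axes are ordered outermost-first, i.e. [ne3, ne2, ne1, ne0]):

import numpy as np

x = np.arange(2 * 1 * 3 * 4, dtype=np.float32).reshape(2, 1, 3, 4)
stride_elems = [s // x.itemsize for s in x.strides]
# the size-1 axis (ne2) is never stepped, so its stride is unusable as a
# batch stride; the distance between consecutive matrices comes from ne3:
sma = stride_elems[0]
for batch in range(x.shape[0]):
    m = x.reshape(-1)[batch * sma : batch * sma + 3 * 4].reshape(3, 4)
    assert np.array_equal(m, x[batch, 0])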

ggml/src/ggml-cuda/im2col.cu

Lines changed: 45 additions & 35 deletions
@@ -1,65 +1,75 @@
 #include "im2col.cuh"

+#define MIN(a, b) (a) < (b) ? (a) : (b)
+
+#define MAX_GRIDDIM_Z 65535
+
 template <typename T>
 static __global__ void im2col_kernel(
-        const float * x, T * dst, int64_t batch_offset,
-        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+        const float * x, T * dst,
+        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
-    if (i >= pelements) {
+    if (i >= IC_KH_KW) {
         return;
     }

-    const int64_t ksize = OW * KH;
-    const int64_t kx = i / ksize;
-    const int64_t kd = kx * ksize;
-    const int64_t ky = (i - kd) / OW;
-    const int64_t ix = i % OW;
+    const int64_t iic = i / (KH_KW);
+    const int64_t rem = i - iic * KH_KW;
+    const int64_t ikh = rem / KW;
+    const int64_t ikw = rem - ikh * KW;

-    const int64_t oh = blockIdx.y;
-    const int64_t batch = blockIdx.z / IC;
-    const int64_t ic = blockIdx.z % IC;
+    const int64_t iow = blockIdx.y;
+    for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
+        const int64_t in = iz / OH;
+        const int64_t ioh = iz - in * OH;

-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = oh * s1 + ky * d1 - p1;
+        const int64_t iiw = iow * s0 + ikw * d0 - p0;
+        const int64_t iih = ioh * s1 + ikh * d1 - p1;

-    const int64_t offset_dst =
-        ((batch * OH + oh) * OW + ix) * CHW +
-        (ic * (KW * KH) + ky * KW + kx);
+        const int64_t offset_dst =
+            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;

-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = 0.0f;
-    } else {
-        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
-        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+            dst[offset_dst] = 0.0f;
+        } else {
+            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
+            dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        }
     }
 }

+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 template <typename T>
 static void im2col_cuda(const float * x, T* dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
-    const int parallel_elements = OW * KW * KH;
-    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-    dim3 block_nums(num_blocks, OH, batch * IC);
-    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+    const int64_t IC_KH_KW = IC * KH * KW;
+    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    const int64_t N_OH = N * OH;
+    const int64_t KH_KW = KW*KH;
+    dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
+    im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE), 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
+            IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
+            s0, s1, p0, p1, d0, d1);
 }

 static void im2col_cuda_f16(const float * x, half * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {

-    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }

 static void im2col_cuda_f32(const float * x, float * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {

-    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }

 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -91,13 +101,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW = dst->ne[1];

-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch = src1->ne[is_2D ? 3 : 2];
-    const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t N = src1->ne[is_2D ? 3 : 2];
+    const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32

     if(dst->type == GGML_TYPE_F16) {
-        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     } else {
-        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     }
 }
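The rewrite keeps the output layout, im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW], but re-maps the grid: threads cover IC*KH*KW on x, OW on y, and N*OH on z, with a loop so batches larger than the 65535 grid-z limit still work. For readability, a naive CPU reference of the same mapping in numpy (a sketch, not repo code; output size follows the standard convolution formula):

import numpy as np

def im2col_ref(x, KH, KW, s0, s1, p0, p1, d0, d1):
    # Naive im2col: x is [N, IC, IH, IW]; returns [N, OH, OW, IC*KH*KW].
    # Out-of-bounds taps (from padding) are written as 0, as in the kernel.
    N, IC, IH, IW = x.shape
    OH = (IH + 2 * p1 - d1 * (KH - 1) - 1) // s1 + 1
    OW = (IW + 2 * p0 - d0 * (KW - 1) - 1) // s0 + 1
    dst = np.zeros((N, OH, OW, IC * KH * KW), dtype=x.dtype)
    for n in range(N):
        for oh in range(OH):
            for ow in range(OW):
                for ic in range(IC):
                    for kh in range(KH):
                        for kw in range(KW):
                            ih = oh * s1 + kh * d1 - p1  # same index math as the kernel
                            iw = ow * s0 + kw * d0 - p0
                            if 0 <= ih < IH and 0 <= iw < IW:
                                dst[n, oh, ow, ic * KH * KW + kh * KW + kw] = x[n, ic, ih, iw]
    return dst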

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 2 additions & 2 deletions
@@ -2046,8 +2046,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {

     backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version);
     backend_ctx->has_vector_subgroup_broadcast =
-        backend_ctx->adreno_cl_compiler_version.major >= 47 ||
-        backend_ctx->adreno_cl_compiler_version.major == 17;
+        (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) ||
+        (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17);
     GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n",
                   backend_ctx->has_vector_subgroup_broadcast ? "true" : "false");
ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 12 additions & 3 deletions
@@ -2688,6 +2688,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);

+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
     int64_t s12 = nb12 / type_size_src1;
@@ -2737,6 +2740,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         s11 = ne10;
         s12 = ne11 * s11;
         s13 = ne12 * s12;
+
+        is_src1_cont_2 = true;
     }

     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2852,12 +2857,16 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         else
 #endif
         {
-            if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+            if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+                // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+                const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+                const int64_t smb = ne12 == 1 ? s13 : s12;
+
                 // there is no broadcast and src0, src1 are contiguous across dims 2, 3
                 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                     oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
+                    src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
                     mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
             } else {
                 const int ne23 = ne12 * ne13;
requirements/requirements-convert_hf_to_gguf_update.txt

Lines changed: 0 additions & 6 deletions
@@ -1,7 +1 @@
 -r ./requirements-convert_legacy_llama.txt
---extra-index-url https://download.pytorch.org/whl/cpu
-torch~=2.2.1; platform_machine != "s390x"
-
-# torch s390x packages can only be found from nightly builds
---extra-index-url https://download.pytorch.org/whl/nightly
-torch>=0.0.0.dev0; platform_machine == "s390x"

src/llama-context.cpp

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ llama_context::llama_context(

     {
         const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
+        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;

         if (!supports_set_rows && !cparams.kv_unified) {
             LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
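The one-line change turns LLAMA_SET_ROWS into an override rather than a hard switch: when the variable is unset, supports_set_rows keeps its compiled-in default instead of being forced to false. A sketch of the pattern in Python (the default value below is illustrative, not taken from the source):

import os

def resolve_flag(env_name: str, default: bool) -> bool:
    val = os.getenv(env_name)
    if val is None:
        return default        # keep the compiled-in default (the fix)
    try:
        return int(val) != 0  # C atoi maps non-numeric input to 0; int() is stricter
    except ValueError:
        return False

supports_set_rows = resolve_flag("LLAMA_SET_ROWS", default=True)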
