
Commit 1389753

Merge branch 'master' into imatrix
2 parents c3ede42 + a0535ff commit 1389753


81 files changed: +5708 additions, -2415 deletions

.github/workflows/build.yml

Lines changed: 13 additions & 5 deletions
@@ -664,7 +664,7 @@ jobs:
           ./build-xcframework.sh

   windows-msys2:
-    runs-on: windows-latest
+    runs-on: windows-2025

     strategy:
       fail-fast: false
@@ -714,7 +714,7 @@ jobs:
           cmake --build build --config ${{ matrix.build }} -j $(nproc)

   windows-latest-cmake:
-    runs-on: windows-latest
+    runs-on: windows-2025

     env:
       OPENBLAS_VERSION: 0.3.23
@@ -725,16 +725,22 @@ jobs:
       matrix:
         include:
           - build: 'cpu-x64 (static)'
+            arch: 'x64'
             defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF'
           - build: 'openblas-x64'
+            arch: 'x64'
             defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"'
           - build: 'vulkan-x64'
+            arch: 'x64'
             defines: '-DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON'
           - build: 'llvm-arm64'
+            arch: 'arm64'
             defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON'
           - build: 'llvm-arm64-opencl-adreno'
+            arch: 'arm64'
             defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
           # - build: 'kompute-x64'
+          #   arch: 'x64'
           #   defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON'

     steps:
@@ -805,6 +811,8 @@ jobs:
       - name: libCURL
         id: get_libcurl
        uses: ./.github/actions/windows-setup-curl
+        with:
+          architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}

       - name: Build
         id: cmake_build
@@ -825,7 +833,7 @@ jobs:

       - name: Test
         id: cmake_test
-        if: ${{ matrix.build != 'llvm-arm64' && matrix.build != 'llvm-arm64-opencl-adreno' }}
+        if: ${{ matrix.arch == 'x64' }}
         run: |
           cd build
           ctest -L main -C Release --verbose --timeout 900
@@ -930,7 +938,7 @@ jobs:
           cmake --build build --config Release

   windows-latest-cmake-sycl:
-    runs-on: windows-latest
+    runs-on: windows-2022

     defaults:
       run:
@@ -964,7 +972,7 @@ jobs:

   windows-latest-cmake-hip:
     if: ${{ github.event.inputs.create_release != 'true' }}
-    runs-on: windows-latest
+    runs-on: windows-2022

     steps:
       - name: Clone

.github/workflows/release.yml

Lines changed: 6 additions & 6 deletions
@@ -235,7 +235,7 @@ jobs:
           name: llama-bin-ubuntu-vulkan-x64.zip

   windows-cpu:
-    runs-on: windows-latest
+    runs-on: windows-2025

     strategy:
       matrix:
@@ -271,7 +271,7 @@ jobs:
         env:
           CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
         run: |
-          call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }}
+          call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch == 'x64' && 'x64' || 'amd64_arm64' }}
           cmake -S . -B build -G "Ninja Multi-Config" ^
             -D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^
             -DGGML_NATIVE=OFF ^
@@ -288,7 +288,7 @@ jobs:
           CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
         run: |
           Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
-          Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
+          Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.44.35112\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
           7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*

       - name: Upload artifacts
@@ -298,7 +298,7 @@ jobs:
           name: llama-bin-win-cpu-${{ matrix.arch }}.zip

   windows:
-    runs-on: windows-latest
+    runs-on: windows-2025

     env:
       OPENBLAS_VERSION: 0.3.23
@@ -448,7 +448,7 @@ jobs:
           name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip

   windows-sycl:
-    runs-on: windows-latest
+    runs-on: windows-2022

     defaults:
       run:
@@ -520,7 +520,7 @@ jobs:
           name: llama-bin-win-sycl-x64.zip

   windows-hip:
-    runs-on: windows-latest
+    runs-on: windows-2022

     strategy:
       matrix:

convert_hf_to_gguf.py

Lines changed: 163 additions & 5 deletions
@@ -310,6 +310,8 @@ def prepare_tensors(self):
                             gguf.MODEL_TENSOR.POSNET_NORM2,
                             gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                             gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
+                            gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
+                            gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -320,7 +322,11 @@ def prepare_tensors(self):
                     self.match_model_tensor_name(new_name, key, bid)
                     for key in (
                         gguf.MODEL_TENSOR.TOKEN_EMBD,
+                        gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
                         gguf.MODEL_TENSOR.OUTPUT,
+                        gguf.MODEL_TENSOR.ALTUP_ROUTER,
+                        gguf.MODEL_TENSOR.LAUREL_L,
+                        gguf.MODEL_TENSOR.LAUREL_R,
                     )
                 ):
                     if self.ftype in (
@@ -921,13 +927,20 @@ def _create_vocab_sentencepiece(self):
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))

-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+        vocab_size = self.find_hparam([
+            "vocab_size_per_layer_input", # gemma3n
+            "vocab_size",
+        ], optional=True) or tokenizer.vocab_size()

         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

         for token_id in range(tokenizer.vocab_size()):
+            if token_id >= vocab_size:
+                logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
+                break
+
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
             score = tokenizer.GetScore(token_id)
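
Aside: the vocab handling above pre-fills vocab_size placeholder tokens and now skips tokenizer ids past that size. A minimal standalone sketch of that behaviour (Python), using made-up sizes rather than any real model's hparams:

# Illustrative sketch (made-up sizes): pre-fill vocab_size slots with "[PAD<i>]"
# placeholders, overwrite them with real tokenizer pieces, and drop any tokenizer
# ids at or beyond vocab_size, mirroring the warning/break added in the diff.
vocab_size = 6                                                                  # e.g. vocab_size_per_layer_input (assumed)
tokenizer_pieces = ["<pad>", "<bos>", "<eos>", "he", "llo", "wor", "ld", "!"]   # 8 ids (assumed)

tokens = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
for token_id, piece in enumerate(tokenizer_pieces):
    if token_id >= vocab_size:
        break  # ids 6 and 7 are ignored, as in the diff
    tokens[token_id] = piece.encode("utf-8")

print(tokens)  # [b'<pad>', b'<bos>', b'<eos>', b'he', b'llo', b'wor']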
@@ -2730,6 +2743,52 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("Ernie4_5_ForCausalLM")
+class Ernie4_5Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.ERNIE4_5
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        num_heads = self.hparams["num_attention_heads"]
+        num_kv_heads = self.hparams["num_key_value_heads"]
+        head_dim = self.hparams["head_dim"]
+
+        if "ernie." in name:
+            name = name.replace("ernie.", "model.")
+        # split the qkv weights
+        # qkv_proj shape: [(num_heads + 2 * num_kv_heads) * head_dim, hidden_size]
+        if "qkv_proj" in name:
+            name_q = name.replace("qkv_proj.weight", "q_proj.weight")
+            name_k = name.replace("qkv_proj.weight", "k_proj.weight")
+            name_v = name.replace("qkv_proj.weight", "v_proj.weight")
+            total_q_dim = num_heads * head_dim
+            total_k_dim = num_kv_heads * head_dim
+            total_v_dim = num_kv_heads * head_dim
+            q_proj_weight, k_proj_weight, v_proj_weight = data_torch.split([total_q_dim, total_k_dim, total_v_dim], dim=0)
+            return [
+                (self.map_tensor_name(name_q), q_proj_weight),
+                (self.map_tensor_name(name_k), k_proj_weight),
+                (self.map_tensor_name(name_v), v_proj_weight)
+            ]
+        # split the up_gate_proj into gate and up
+        # up_gate_proj shape: [2 * intermediate_size, hidden_size]
+        if "up_gate_proj" in name:
+            name_up = name.replace("up_gate_proj.weight", "up_proj.weight")
+            name_gate = name.replace("up_gate_proj.weight", "gate_proj.weight")
+            dim_half = data_torch.shape[0] // 2
+            gate_proj_weight, up_proj_weight = data_torch.split(dim_half, dim=0)
+            return [
+                (self.map_tensor_name(name_gate), gate_proj_weight),
+                (self.map_tensor_name(name_up), up_proj_weight)
+            ]
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register(
     "Qwen2VLModel",
     "Qwen2VLForConditionalGeneration",
@@ -4217,6 +4276,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
+    norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -4238,9 +4298,8 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
         self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
@@ -4252,7 +4311,7 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused

-        if name.startswith("language_model."):
+        if "language_model." in name:
             name = name.replace("language_model.", "")

         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@@ -4267,8 +4326,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         # ref code in Gemma3RMSNorm
         # output = output * (1.0 + self.weight.float())
+        # note: this is not the case on gemma3n
         if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift

         return [(self.map_tensor_name(name), data_torch)]

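Aside: the norm_shift fold above works because Gemma3's RMSNorm scales by (1 + weight), so adding the shift to the exported weight lets a generic rmsnorm(x) * weight kernel reproduce it, while Gemma3n (shift 0.0) keeps the weight as-is. A small sketch (Python) under that assumption; rmsnorm here is a hand-rolled stand-in, not llama.cpp's kernel:

# Sketch: folding norm_shift into the exported weight makes a generic
# `rmsnorm(x) * w_stored` kernel reproduce Gemma3's `rmsnorm(x) * (1 + w)`.
import torch

def rmsnorm(x, eps: float = 1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

norm_shift = 1.0                     # 1.0 for Gemma3, 0.0 for Gemma3n (per the diff)
x, w = torch.randn(4, 8), torch.randn(8)

reference = rmsnorm(x) * (1.0 + w)   # Gemma3RMSNorm reference behaviour
w_stored = w + norm_shift            # what modify_tensors writes out
generic = rmsnorm(x) * w_stored      # generic kernel with the folded weight
print(torch.allclose(reference, generic))  # True
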
@@ -4325,6 +4385,104 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [] # skip other tensors


+@ModelBase.register("Gemma3nForConditionalGeneration")
+class Gemma3NModel(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3N
+    norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
+
+    _altup_proj: list[Tensor] = []
+    _altup_unembd: list[Tensor] = []
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
+        self._altup_proj = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+        self._altup_unembd = [
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+            torch.Tensor(), # to be replaced
+        ]
+
+    def set_vocab(self):
+        with open(self.dir_model / "chat_template.jinja") as f:
+            # quick hack to make sure chat template is added
+            self.gguf_writer.add_chat_template(f.read())
+        super().set_vocab()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
+        self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
+        self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
+        self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
+
+        activation_sparsity_scale = []
+        for s in self.hparams["activation_sparsity_pattern"]:
+            normal_dist = torch.distributions.normal.Normal(0, 1)
+            std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
+            activation_sparsity_scale.append(std_multiplier.item())
+        self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
+
+        sliding_window_pattern = []
+        for t in self.hparams["layer_types"]:
+            sliding_window_pattern.append(t == "sliding_attention")
+        self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+    def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
+        has_all = all(m.numel() > 0 for m in matrices)
+        if not has_all:
+            return None
+        else:
+            return torch.stack(matrices, dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith("_scale"):
+            name = name + ".weight"
+
+        # TODO: implement self.prediction_coefs.weight.clamp_(...)
+
+        if "language_model." not in name:
+            return [] # skip non-language model tensors
+
+        if "altup_unembed_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_unembd[0] = data_torch
+            elif ".1." in name:
+                self._altup_unembd[1] = data_torch
+            elif ".2." in name:
+                self._altup_unembd[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_unembd)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
+            else:
+                return []
+
+        if "altup_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_proj[0] = data_torch
+            elif ".1." in name:
+                self._altup_proj[1] = data_torch
+            elif ".2." in name:
+                self._altup_proj[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_proj)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_projections.weight"), out)]
+            else:
+                return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Starcoder2ForCausalLM")
 class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2
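
Aside: the activation_sparsity_scale loop in Gemma3NModel.set_gguf_parameters above maps each target sparsity fraction to a standard-normal quantile (inverse CDF), which the runtime can use as a std multiplier when thresholding activations. A standalone sketch (Python) with an illustrative pattern, not gemma3n's actual hparams:

# Sketch of the sparsity-fraction -> std-multiplier conversion.
# The pattern values below are illustrative only.
import torch

activation_sparsity_pattern = [0.95, 0.95, 0.0]
normal_dist = torch.distributions.normal.Normal(0, 1)
scales = [normal_dist.icdf(torch.tensor(s, dtype=torch.float32)).item()
          for s in activation_sparsity_pattern]
print(scales)  # approx [1.645, 1.645, -inf]; a fraction of 0.0 effectively disables sparsification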

docs/backend/SYCL.md

Lines changed: 1 addition & 1 deletion
@@ -757,7 +757,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | Name | Value | Function |
 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
 | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
-| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
+| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) |
 | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
 | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
