
Commit f348628

Merge branch 'master' into perplexity
2 parents: 8a827f8 + d84635b


67 files changed, +4208 -978 lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion

@@ -1379,7 +1379,7 @@ jobs:
         id: pack_artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
         run: |
-          zip -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
+          zip --symlinks -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
 
       - name: Upload artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
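Note on the change above: the added `--symlinks` flag makes Info-ZIP store symbolic links as links instead of the files they point to. This matters for the packaged `llama.xcframework`, since macOS framework bundles inside it rely on internal symlinks (such as `Versions/Current`); following the links would duplicate files and lose the expected bundle layout when the archive is unzipped.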

CMakeLists.txt

Lines changed: 9 additions & 1 deletion

@@ -29,6 +29,8 @@ else()
     set(LLAMA_STANDALONE OFF)
 endif()
 
+option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)
+
 if (EMSCRIPTEN)
     set(BUILD_SHARED_LIBS_DEFAULT OFF)

@@ -145,7 +147,13 @@ endif()
 # 3rd-party
 #
 
-if (NOT TARGET ggml)
+if (LLAMA_USE_SYSTEM_GGML)
+    message(STATUS "Using system-provided libggml, skipping ggml build")
+    find_package(ggml REQUIRED)
+    add_library(ggml ALIAS ggml::ggml)
+endif()
+
+if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
     add_subdirectory(ggml)
     # ... otherwise assume ggml is added by a parent CMakeLists.txt
 endif()
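With the new option, a build against a preinstalled ggml could be configured along the lines of `cmake -B build -DLLAMA_USE_SYSTEM_GGML=ON` (an illustrative invocation, not from the diff); this assumes the system ggml ships a CMake package that exports the `ggml::ggml` target, which is what the `find_package(ggml REQUIRED)` / `add_library(ggml ALIAS ggml::ggml)` pair above expects.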

cmake/common.cmake

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
+include("ggml/cmake/common.cmake")
+
 function(llama_add_compile_flags)
     if (LLAMA_FATAL_WARNINGS)
         if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")

convert_hf_to_gguf.py

Lines changed: 216 additions & 32 deletions

@@ -908,6 +908,40 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
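As a rough illustration of the parsing in `_set_vocab_rwkv_world` above, a single vocab line of the assumed form `<id> <token as Python literal> <byte length>` (the line below is made up) round-trips like this:

    import ast

    line = "2 'the' 3"  # hypothetical rwkv_vocab_v20230424.txt entry
    parts = line.split(' ')
    # middle field is a Python literal (str or bytes), last field is its byte length
    token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
    token = token.encode("utf-8") if isinstance(token, str) else token
    assert len(token) == token_len
    token_text = repr(token)[2:-1]  # "b'the'" -> "the"; raw bytes keep their \xNN escapes
    print(token_text)               # prints: the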
@@ -1713,6 +1747,25 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("Mistral3ForConditionalGeneration")
+class Mistral3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = Model.load_hparams(kwargs["dir_model"])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("language_model.", "")
+        if "multi_modal_projector" in name or "vision_tower" in name:
+            return []
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("DeciLMForCausalLM")
 class DeciModel(Model):
     model_arch = gguf.MODEL_ARCH.DECI
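The `__init__` above flattens the nested `text_config` into the top-level hparams so the LLaMA conversion path can read fields such as `hidden_size` directly; a minimal sketch of that dict merge (the values are made up):

    hparams = {
        "model_type": "mistral3",
        "text_config": {"hidden_size": 5120, "num_hidden_layers": 40},
    }
    if "text_config" in hparams:
        hparams = {**hparams, **hparams["text_config"]}
    print(hparams["hidden_size"])  # 5120, now available at the root level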
@@ -3412,38 +3465,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6
 
     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()
 
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3587,168 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data)
 
 
+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
+            name = name.replace("_lora.lora.0.weight", "1.weight")
+            name = name.replace("_lora.lora.2.weight", "2.weight")
+            name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easist way to make llama happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
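For a sense of the fallback LoRA ranks produced by `calc_lora_rank` above when the hparams do not specify them, here is the formula evaluated standalone for a hypothetical hidden_size of 2048:

    def calc_lora_rank(hidden_size, exponent, multiplier):
        # round hidden_size**exponent * multiplier to the nearest multiple of 32 (minimum 32)
        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

    print(calc_lora_rank(2048, 0.5, 1.8))  # decay / iclr fallback -> 96
    print(calc_lora_rank(2048, 0.5, 1.3))  # value residual mix fallback -> 64
    print(calc_lora_rank(2048, 0.8, 0.6))  # gate fallback -> 256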

docs/backend/SYCL.md

Lines changed: 3 additions & 1 deletion

@@ -660,8 +660,9 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 |--------------------|---------------------------------------|---------------------------------------------|
 | GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
 | GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. |
-| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
+| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
 | GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
+| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
 | CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
 | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |

@@ -671,6 +672,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
 | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
 | GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
+| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
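Per the new table rows, SYCL Graph support is compiled in by default (GGML_SYCL_GRAPH=ON) but kept off at runtime because GGML_SYCL_DISABLE_GRAPH defaults to 1; setting GGML_SYCL_DISABLE_GRAPH=0 in the environment is how one would opt back in to graph execution.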
