
Commit 9689673

Authored by fmz (with a co-author)
Add JAIS model(s) (#8118)
* Add `JAIS` model(s)
* cleanup
* address review comments
* remove hack
* un-hardcode max-alibi-bias
* minor tweaks

Co-authored-by: fmz <[email protected]>
1 parent 023b880 commit 9689673

6 files changed: +288 -9 lines


convert-hf-to-gguf-update.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "poro-chat",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
     {"name": "viking",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
+    {"name": "jais",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
 ]
 
convert-hf-to-gguf.py

Lines changed: 93 additions & 0 deletions
@@ -490,6 +490,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
             # ref: https://huggingface.co/LumiOpen/Viking-7B
             res = "viking"
+        if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
+            # ref: https://huggingface.co/core42/jais-13b
+            res = "jais"
 
         if res is None:
             logger.warning("\n")
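The chkhsh strings above are fingerprints of a model's BPE pre-tokenizer: the converter tokenizes a fixed probe text and hashes the resulting token ids, and convert-hf-to-gguf-update.py regenerates these if-branches for every registered repo. A minimal sketch of the idea, assuming the transformers library is available; the helper name and probe text here are illustrative, not the script's verbatim code:

from hashlib import sha256
from transformers import AutoTokenizer  # assumption: transformers is installed

def pretokenizer_fingerprint(model_id: str, probe_text: str) -> str:
    # Tokenize a fixed probe string and hash the resulting token ids; models whose
    # BPE pre-tokenizers behave identically produce the same fingerprint.
    tok = AutoTokenizer.from_pretrained(model_id)
    ids = tok.encode(probe_text)
    return sha256(str(ids).encode()).hexdigest()

# Hypothetical usage (downloads the tokenizer; the real script uses a much longer,
# carefully chosen probe text covering digits, whitespace and multilingual text):
# print(pretokenizer_fingerprint("core42/jais-13b", "Hello, world! 123 éè"))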
@@ -2965,6 +2968,96 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("JAISLMHeadModel")
+class JaisModel(Model):
+    model_arch = gguf.MODEL_ARCH.JAIS
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # SwiGLU activation
+        assert self.hparams["activation_function"] == "swiglu"
+        # ALiBi position embedding
+        assert self.hparams["position_embedding_type"] == "alibi"
+
+        # Embeddings scale
+        self.embeddings_scale = 1.0
+        # note: for some JAIS flavors, output is tied to (same as) wte in the original model
+        self.output_is_wte = False
+        if 'mup_embeddings_scale' in self.hparams:
+            self.output_is_wte = True  # Hack (?)
+            self.embeddings_scale = self.hparams['mup_embeddings_scale']
+        elif 'embeddings_scale' in self.hparams:
+            self.embeddings_scale = self.hparams['embeddings_scale']
+        else:
+            assert False
+
+        self.width_scale = 1.0
+        if 'mup_output_alpha' in self.hparams:
+            assert 'mup_width_scale' in self.hparams
+            self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
+        elif 'width_scale' in self.hparams:
+            self.width_scale = self.hparams['width_scale']
+        else:
+            assert False
+
+        self.max_alibi_bias = 8.0
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        tensors: list[tuple[str, Tensor]] = []
+
+        # we don't need these
+        if name.endswith((".attn.bias")):
+            return tensors
+
+        if name.endswith(("relative_pe.slopes")):
+            # Calculate max ALiBi bias (this is the inverse of the ALiBi calculation)
+            # Some other models have max_alibi_bias spelled out explicitly in the hyperparams,
+            # but JAIS's PyTorch model simply precalculates the slope values and places them
+            # in relative_pes.slopes
+            n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
+            first_val = float(data_torch._data[0])
+            self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)
+
+            return tensors
+
+        if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):
+            data_torch = data_torch.transpose(1, 0)
+
+        new_name = self.map_tensor_name(name)
+
+        if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
+            tensors.append((new_name, data_torch * self.embeddings_scale))
+            if self.output_is_wte:
+                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
+        elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
+            assert not self.output_is_wte
+            tensors.append((new_name, data_torch * self.width_scale))
+        else:
+            tensors.append((new_name, data_torch))
+
+        return tensors
+
+    def write_tensors(self):
+        super().write_tensors()
+        self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
+
+
 
 ###### CONVERSION LOGIC ######
 
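Regarding the relative_pe.slopes handling above: ALiBi slopes are typically generated as powers of two derived from max_alibi_bias, so the converter can recover the bias from the first slope alone. A small worked sketch of that round trip, assuming the common slope schedule 2 ** (-max_alibi_bias * (h + 1) / n); the n_head value below is illustrative, not read from a JAIS config:

import math

# Forward direction: with n = 2 ** floor(log2(n_head)), head 0 gets
# slope = 2 ** (-max_alibi_bias / n) under the common ALiBi schedule.
def first_alibi_slope(n_head: int, max_alibi_bias: float) -> float:
    n = 2 ** math.floor(math.log2(n_head))
    return 2.0 ** (-max_alibi_bias / n)

# Inverse direction, mirroring what modify_tensors() does with relative_pe.slopes[0].
def recover_max_alibi_bias(first_slope: float, n_head: int) -> int:
    n = 2 ** math.floor(math.log2(n_head))
    return -round(math.log2(first_slope) * n)

n_head = 32  # illustrative head count
slope0 = first_alibi_slope(n_head, 8.0)
assert recover_max_alibi_bias(slope0, n_head) == 8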

gguf-py/gguf/constants.py

Lines changed: 14 additions & 0 deletions
@@ -164,6 +164,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK2    = auto()
     BITNET       = auto()
     T5           = auto()
+    JAIS         = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -288,6 +289,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
     MODEL_ARCH.BITNET:    "bitnet",
     MODEL_ARCH.T5:        "t5",
+    MODEL_ARCH.JAIS:      "jais",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -954,6 +956,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ENC_FFN_UP,
         MODEL_TENSOR.ENC_OUTPUT_NORM,
     ],
+    MODEL_ARCH.JAIS: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
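Once registered here, the rest of gguf-py can resolve the architecture string and its tensor set generically. A quick inspection sketch; it assumes the surrounding dictionaries keep their current gguf-py names (MODEL_ARCH_NAMES and MODEL_TENSORS, alongside the TENSOR_NAMES shown above):

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES

arch = MODEL_ARCH.JAIS
print(MODEL_ARCH_NAMES[arch])                # expected: "jais"
for tensor in MODEL_TENSORS[arch]:
    # e.g. TOKEN_EMBD -> "token_embd", FFN_GATE -> "blk.{bid}.ffn_gate"
    print(tensor.name, "->", TENSOR_NAMES[tensor])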

gguf-py/gguf/tensor_mapping.py

Lines changed: 10 additions & 9 deletions
@@ -10,7 +10,7 @@ class TensorNameMap:
     # Token embeddings
     MODEL_TENSOR.TOKEN_EMBD: (
         "gpt_neox.embed_in",           # gptneox
-        "transformer.wte",             # gpt2 gpt-j mpt refact qwen dbrx
+        "transformer.wte",             # gpt2 gpt-j mpt refact qwen dbrx jais
         "transformer.word_embeddings", # falcon
         "word_embeddings",             # bloom
         "model.embed_tokens",          # llama-hf
@@ -49,7 +49,7 @@ class TensorNameMap:
     # Output
     MODEL_TENSOR.OUTPUT: (
         "embed_out",                # gptneox
-        "lm_head",                  # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
+        "lm_head",                  # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
         "output",                   # llama-pth bloom internlm2
         "word_embeddings_for_head", # persimmon
         "lm_head.linear",           # phi2
@@ -58,7 +58,7 @@ class TensorNameMap:
     # Output norm
     MODEL_TENSOR.OUTPUT_NORM: (
         "gpt_neox.final_layer_norm", # gptneox
-        "transformer.ln_f",          # gpt2 gpt-j falcon
+        "transformer.ln_f",          # gpt2 gpt-j falcon jais
         "model.norm",                # llama-hf baichuan internlm2
         "norm",                      # llama-pth
         "transformer.norm_f",        # mpt dbrx
@@ -81,7 +81,7 @@ class TensorNameMap:
     # Attention norm
     MODEL_TENSOR.ATTN_NORM: (
         "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-        "transformer.h.{bid}.ln_1",              # gpt2 gpt-j refact qwen
+        "transformer.h.{bid}.ln_1",              # gpt2 gpt-j refact qwen jais
         "transformer.blocks.{bid}.norm_1",       # mpt
         "transformer.h.{bid}.input_layernorm",   # falcon7b
         "h.{bid}.input_layernorm",               # bloom
@@ -109,7 +109,7 @@ class TensorNameMap:
     # Attention query-key-value
     MODEL_TENSOR.ATTN_QKV: (
         "gpt_neox.layers.{bid}.attention.query_key_value",   # gptneox
-        "transformer.h.{bid}.attn.c_attn",                    # gpt2 qwen
+        "transformer.h.{bid}.attn.c_attn",                    # gpt2 qwen jais
         "transformer.blocks.{bid}.attn.Wqkv",                 # mpt
         "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv",  # dbrx
         "transformer.h.{bid}.self_attention.query_key_value", # falcon
@@ -160,7 +160,7 @@ class TensorNameMap:
     # Attention output
     MODEL_TENSOR.ATTN_OUT: (
         "gpt_neox.layers.{bid}.attention.dense",    # gptneox
-        "transformer.h.{bid}.attn.c_proj",          # gpt2 refact qwen
+        "transformer.h.{bid}.attn.c_proj",          # gpt2 refact qwen jais
         "transformer.blocks.{bid}.attn.out_proj",   # mpt
         "transformer.h.{bid}.self_attention.dense", # falcon
         "h.{bid}.self_attention.dense",             # bloom
@@ -202,7 +202,7 @@ class TensorNameMap:
     # Feed-forward norm
     MODEL_TENSOR.FFN_NORM: (
         "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-        "transformer.h.{bid}.ln_2",                        # gpt2 refact qwen
+        "transformer.h.{bid}.ln_2",                        # gpt2 refact qwen jais
         "h.{bid}.post_attention_layernorm",                # bloom
         "transformer.blocks.{bid}.norm_2",                 # mpt
         "model.layers.{bid}.post_attention_layernorm",     # llama-hf
@@ -239,7 +239,7 @@ class TensorNameMap:
     # Feed-forward up
     MODEL_TENSOR.FFN_UP: (
         "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
-        "transformer.h.{bid}.mlp.c_fc",            # gpt2
+        "transformer.h.{bid}.mlp.c_fc",            # gpt2 jais
         "transformer.blocks.{bid}.ffn.up_proj",    # mpt
         "transformer.h.{bid}.mlp.dense_h_to_4h",   # falcon
         "h.{bid}.mlp.dense_h_to_4h",               # bloom
@@ -285,6 +285,7 @@ class TensorNameMap:
         "model.layers.{bid}.mlp.gate_proj",        # llama-hf refact
         "layers.{bid}.feed_forward.w1",            # llama-pth
         "transformer.h.{bid}.mlp.w2",              # qwen
+        "transformer.h.{bid}.mlp.c_fc2",           # jais
         "model.layers.layers.{bid}.mlp.gate_proj", # plamo
         "model.layers.{bid}.feed_forward.w1",      # internlm2
         "encoder.layers.{bid}.mlp.fc12",           # nomic-bert
@@ -308,7 +309,7 @@ class TensorNameMap:
     # Feed-forward down
     MODEL_TENSOR.FFN_DOWN: (
         "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
-        "transformer.h.{bid}.mlp.c_proj",          # gpt2 refact qwen
+        "transformer.h.{bid}.mlp.c_proj",          # gpt2 refact qwen jais
         "transformer.blocks.{bid}.ffn.down_proj",  # mpt
         "transformer.h.{bid}.mlp.dense_4h_to_h",   # falcon
         "h.{bid}.mlp.dense_4h_to_h",               # bloom

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_SMAUG  = 14,
         LLAMA_VOCAB_PRE_TYPE_PORO   = 15,
         LLAMA_VOCAB_PRE_TYPE_VIKING = 16,
+        LLAMA_VOCAB_PRE_TYPE_JAIS   = 17,
     };
 
     // note: these values should be synchronized with ggml_rope
