
Commit db6b08d

[3/4] feat: add gemma 3 4b (#2001)

1 parent 05d83a7

2 files changed: +33 -1 lines changed


litgpt/config.py (32 additions, 0 deletions)
@@ -1090,6 +1090,38 @@ def norm_class(self) -> Type:
         # 5 local layers for every global layer
         rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
     ),
+    # https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
+    dict(
+        name="Gemma-3-4b-it",
+        hf_config=dict(org="google", name="gemma-3-4b-it"),
+        scale_embeddings=True,
+        attention_scores_scalar=256,
+        vocab_size=262144,
+        block_size=131072,
+        sliding_window_size=1024,
+        # 5 local layers for every global layer
+        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],
+        intermediate_size=10240,
+        n_embd=2560,
+        n_layer=34,
+        n_head=8,
+        n_query_groups=4,
+        head_size=256,
+        rotary_percentage=1.0,
+        rope_adjustments=dict(factor=8.0),
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="GemmaMLP",
+        gelu_approximate="tanh",
+        post_attention_norm=True,
+        post_mlp_norm=True,
+        norm_qk=True,
+        rope_base=1000000,
+        rope_local_base_freq=10000,
+        # 5 local layers for every global layer
+        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],
+    ),
     # https://huggingface.co/google/gemma-3-12b-it/blob/main/config.json
     dict(
         name="Gemma-3-12b-it",

tests/test_model.py (1 addition, 1 deletion)
@@ -802,7 +802,7 @@ def test_against_original_gemma_2(model_name, device, dtype):


 @torch.inference_mode()
-@pytest.mark.parametrize("model_name", ["gemma-3-1b-it", "gemma-3-12b-it", "gemma-3-27b-it"])
+@pytest.mark.parametrize("model_name", ["gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"])
 @pytest.mark.parametrize(
     ("device", "dtype"),
     [
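A quick way to sanity-check the new entry locally, assuming an install of litgpt that includes this commit (litgpt exposes Config.from_name for looking up registered config names):

from litgpt import Config

# Resolve the entry added to litgpt/config.py above by its registered name.
cfg = Config.from_name("Gemma-3-4b-it")
print(cfg.n_layer, cfg.n_embd, cfg.head_size)  # expected: 34 2560 256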
