Skip to content

Commit b404b69

Browse files
authored
[2/4] add gemma 3 1b (#2000)
1 parent 1f422b6 commit b404b69

File tree

2 files changed: +33 -1 lines changed

litgpt/config.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,6 +1058,38 @@ def norm_class(self) -> Type:
10581058
# Google Gemma 3
10591059
##################
10601060
gemma3 = [
1061+
# https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json
1062+
dict(
1063+
name="Gemma-3-1b-it",
1064+
hf_config=dict(org="google", name="gemma-3-1b-it"),
1065+
scale_embeddings=True,
1066+
attention_scores_scalar=256,
1067+
vocab_size=262144,
1068+
block_size=131072,
1069+
sliding_window_size=512,
1070+
# 5 local layers for every global layer
1071+
sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
1072+
intermediate_size=21504,
1073+
n_embd=1152,
1074+
n_layer=26,
1075+
n_head=4,
1076+
n_query_groups=1,
1077+
head_size=256,
1078+
rotary_percentage=1.0,
1079+
rope_adjustments=None,
1080+
parallel_residual=False,
1081+
bias=False,
1082+
norm_class_name="RMSNorm",
1083+
mlp_class_name="GemmaMLP",
1084+
gelu_approximate="tanh",
1085+
post_attention_norm=True,
1086+
post_mlp_norm=True,
1087+
norm_qk=True,
1088+
rope_base=1000000,
1089+
rope_local_base_freq=10000,
1090+
# 5 local layers for every global layer
1091+
rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
1092+
),
10611093
# https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json
10621094
dict(
10631095
name="Gemma-3-27b-it",

tests/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ def test_against_original_gemma_2(model_name, device, dtype):
802802

803803

804804
@torch.inference_mode()
805-
@pytest.mark.parametrize("model_name", ["gemma-3-27b-it"])
805+
@pytest.mark.parametrize("model_name", ["gemma-3-1b-it", "gemma-3-27b-it"])
806806
@pytest.mark.parametrize(
807807
("device", "dtype"),
808808
[

0 commit comments

Comments
 (0)