Skip to content

Commit 775158c

Browse files
authored
Fix/loading gemma 3 1b (#2004)
1 parent 7baccd4 commit 775158c

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

litgpt/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,7 +1069,7 @@ def norm_class(self) -> Type:
 1069 1069          sliding_window_size=512,
 1070 1070          # 5 local layers for every global layer
 1071 1071          sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
 1072      -        intermediate_size=21504,
      1072 +        intermediate_size=6912,
 1073 1073          n_embd=1152,
 1074 1074          n_layer=26,
 1075 1075          n_head=4,

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,9 @@ def convert_hf_checkpoint(
 602 602      elif model_name.lower().startswith("gemma-2"):
 603 603          qkv_weights = {}
 604 604          copy_fn = partial(copy_weights_gemma_2, qkv_weights)
     605 +    elif model_name.lower().startswith("gemma-3"):
     606 +        qkv_weights = {}
     607 +        copy_fn = partial(copy_weights_gemma_3, qkv_weights)
 605 608      elif model_name.lower().startswith("phi"):
 606 609          # holder to reconstitute the split q, k, v
 607 610          qkv_weights = {}

0 commit comments

Comments (0)