
Commit f9f3dcb

fix: gemma-3 checkpoint conversion from litgpt to hf (#2195)

Co-authored-by: Bhimraj Yadav <bhimrajyadav977@gmail.com>
1 parent: dbddd23

File tree

2 files changed: +6 −4 lines

litgpt/scripts/convert_lit_checkpoint.py

Lines changed: 4 additions & 2 deletions

@@ -170,7 +170,7 @@ def copy_weights_gemma_2(
     config: Config,
     state_dict: Dict[str, torch.Tensor],
     lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
-    untie_weights: bool = False,
+    untie_weights: bool = True,
     saver: Optional[incremental_save] = None,
 ) -> None:
     weight_map = {
@@ -219,7 +219,7 @@ def copy_weights_gemma_3(
     config: Config,
     state_dict: Dict[str, torch.Tensor],
     lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
-    untie_weights: bool = False,
+    untie_weights: bool = True,
     saver: Optional[incremental_save] = None,
 ) -> None:
     weight_map = {
@@ -557,6 +557,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
         copy_fn = partial(copy_weights_falcon, config)
     elif config.name.startswith("Gemma-2"):
         copy_fn = partial(copy_weights_gemma_2, config)
+    elif config.name.startswith("Gemma-3"):
+        copy_fn = partial(copy_weights_gemma_3, config)
     elif config.name.lower().startswith("phi"):
         copy_fn = partial(copy_weights_phi, config)
     elif config.name.lower().startswith(("qwen2.5", "qwq")):
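
Why the default flip: Gemma ties the output head to the token embedding, so the exported Hugging Face state dict should not carry a separate copy of lm_head.weight; the HF model re-ties the head to the embedding when it loads. Changing the default to untie_weights=True makes the bare copy_weights_gemma_* call used by convert_lit_checkpoint do this automatically. A minimal sketch of the tie-skipping pattern, not the exact litgpt implementation (copy_weights_tied and the name map are illustrative):

from typing import Dict

import torch

# Illustrative litgpt -> HF name map, reduced to the two weights that matter here.
LIT_TO_HF = {
    "transformer.wte.weight": "model.embed_tokens.weight",
    "lm_head.weight": "lm_head.weight",
}


def copy_weights_tied(
    state_dict: Dict[str, torch.Tensor],
    lit_weights: Dict[str, torch.Tensor],
    untie_weights: bool = True,
) -> None:
    # Copy litgpt tensors into an HF-style state dict (sketch only).
    for name, tensor in lit_weights.items():
        if untie_weights and name == "lm_head.weight":
            # The head is tied to transformer.wte.weight; drop the duplicate
            # so the HF model re-ties the two when loading the checkpoint.
            continue
        state_dict[LIT_TO_HF[name]] = tensor


# With the new default, a bare call behaves like the old untie_weights=True:
sd: Dict[str, torch.Tensor] = {}
copy_weights_tied(
    sd,
    {"transformer.wte.weight": torch.zeros(2, 2), "lm_head.weight": torch.zeros(2, 2)},
)
assert "lm_head.weight" not in sd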

tests/convert/test_lit_checkpoint.py

Lines changed: 2 additions & 2 deletions

@@ -501,7 +501,7 @@ def test_against_original_gemma_2(model_name, device, dtype):
     ours_model.lm_head.weight = ours_model.transformer.wte.weight
     ours_state_dict = ours_model.state_dict()
     theirs_state_dict = {}
-    copy_weights_gemma_2(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)
+    copy_weights_gemma_2(ours_config, theirs_state_dict, ours_state_dict)
     theirs_model = Gemma2ForCausalLM(theirs_config).to(device)
     keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
     assert not keys.unexpected_keys
@@ -574,7 +574,7 @@ def test_against_original_gemma_3(model_name, device, dtype):
     ours_model.lm_head.weight = ours_model.transformer.wte.weight
     ours_state_dict = ours_model.state_dict()
     theirs_state_dict = {}
-    copy_weights_gemma_3(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)
+    copy_weights_gemma_3(ours_config, theirs_state_dict, ours_state_dict)
     theirs_model = Gemma3ForCausalLM(theirs_config).to(device)
     keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
     assert not keys.unexpected_keys
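
The test updates drop the explicit untie_weights=True argument, so both tests now exercise the new default. End to end, the added Gemma-3 branch means a litgpt Gemma-3 checkpoint converts via the name-based dispatch. A hedged usage sketch: the function and its signature come from the hunk above, while the checkpoint paths are placeholders:

from pathlib import Path

from litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint

# Placeholder paths: a litgpt Gemma-3 checkpoint and the output directory.
checkpoint_dir = Path("checkpoints/google/gemma-3-1b-it")
output_dir = Path("converted/gemma-3-1b-it")

# config.name starts with "Gemma-3", so the new branch routes the copy to
# copy_weights_gemma_3 and writes the HF-layout weights to output_dir.
convert_lit_checkpoint(checkpoint_dir, output_dir)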
