@@ -170,7 +170,7 @@ def copy_weights_gemma_2(
170170 config : Config ,
171171 state_dict : Dict [str , torch .Tensor ],
172172 lit_weights : Dict [str , Union [torch .Tensor , NotYetLoadedTensor ]],
173- untie_weights : bool = False ,
173+ untie_weights : bool = True ,
174174 saver : Optional [incremental_save ] = None ,
175175) -> None :
176176 weight_map = {
@@ -219,7 +219,7 @@ def copy_weights_gemma_3(
219219 config : Config ,
220220 state_dict : Dict [str , torch .Tensor ],
221221 lit_weights : Dict [str , Union [torch .Tensor , NotYetLoadedTensor ]],
222- untie_weights : bool = False ,
222+ untie_weights : bool = True ,
223223 saver : Optional [incremental_save ] = None ,
224224) -> None :
225225 weight_map = {
@@ -557,6 +557,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
557557 copy_fn = partial (copy_weights_falcon , config )
558558 elif config .name .startswith ("Gemma-2" ):
559559 copy_fn = partial (copy_weights_gemma_2 , config )
560+ elif config .name .startswith ("Gemma-3" ):
561+ copy_fn = partial (copy_weights_gemma_3 , config )
560562 elif config .name .lower ().startswith ("phi" ):
561563 copy_fn = partial (copy_weights_phi , config )
562564 elif config .name .lower ().startswith (("qwen2.5" , "qwq" )):
0 commit comments