@@ -2353,22 +2353,25 @@ def _maybe_expand_transformer_param_shape_or_error_(
23532353 parent_module_name , _ , current_module_name = name .rpartition ("." )
23542354 parent_module = transformer .get_submodule (parent_module_name )
23552355
2356- # TODO: consider initializing this under meta device for optims.
2357- expanded_module = torch .nn .Linear (
2358- in_features , out_features , bias = bias , device = module_weight . device , dtype = module_weight .dtype
2359- )
2356+ with torch . device ( "meta" ):
2357+ expanded_module = torch .nn .Linear (
2358+ in_features , out_features , bias = bias , dtype = module_weight .dtype
2359+ )
23602360 # Only weights are expanded and biases are not.
23612361 new_weight = torch .zeros_like (
23622362 expanded_module .weight .data , device = module_weight .device , dtype = module_weight .dtype
23632363 )
23642364 slices = tuple (slice (0 , dim ) for dim in module_weight .shape )
23652365 new_weight [slices ] = module_weight
2366- expanded_module . weight . data . copy_ ( new_weight )
2366+ tmp_state_dict = { " weight" : new_weight }
23672367 if module_bias is not None :
2368- expanded_module .bias .data .copy_ (module_bias )
2368+ tmp_state_dict ["bias" ] = module_bias
2369+ expanded_module .load_state_dict (tmp_state_dict , strict = True , assign = True )
23692370
23702371 setattr (parent_module , current_module_name , expanded_module )
23712372
2373+ del tmp_state_dict
2374+
23722375 if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX :
23732376 attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX [current_module_name ]
23742377 new_value = int (expanded_module .weight .data .shape [1 ])
0 commit comments