@@ -143,7 +143,7 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
             layer_id, layer_type = name.split("_", 1)
             layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
             rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
-
+
             lora_args = {}
             lora_args["alpha"] = param
             lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
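For context, `_from_kohya` translates Kohya-style LoRA keys into the pipeline's internal naming before collecting the `alpha`/`up`/`down` tensors. A minimal sketch of that translation, using a hypothetical `clip_attn_rename_dict` entry since the real mapping is defined elsewhere in the file:

```python
# Hypothetical entry for illustration; the actual dict lives elsewhere in the file.
clip_attn_rename_dict = {"self_attn.q_proj": "attn.to_q"}

name = "0_self_attn_q_proj"                # e.g. the tail of a Kohya LoRA key
layer_id, layer_type = name.split("_", 1)  # -> "0", "self_attn_q_proj"
layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
# rename == "encoders.0.attn.to_q"
```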
@@ -517,7 +517,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
         if config.use_fbcache:
             dit = FluxDiTFBCache.from_state_dict(
                 state_dicts.model,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
                 in_channel=config.control_type.get_in_channel(),
                 attn_kwargs=attn_kwargs,
@@ -526,7 +526,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
         else:
             dit = FluxDiT.from_state_dict(
                 state_dicts.model,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
                 in_channel=config.control_type.get_in_channel(),
                 attn_kwargs=attn_kwargs,
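The substantive change in both branches routes the initial weight load to CPU when FSDP is enabled. This follows the common FSDP pattern of materializing the full state dict on CPU and letting FSDP move each rank's shard onto its GPU, so no single rank ever holds the whole unsharded model in device memory. A minimal sketch of the pattern, assuming standard `torch.distributed.fsdp` usage (the wrapping code itself is not part of this diff):

```python
# Sketch only: FluxDiT/config names follow the diff; the FSDP wrapping
# step is an assumption about how the loaded model is used downstream.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_dit(state_dict, config, init_device):
    # Load full weights on CPU when FSDP will shard them; otherwise
    # load directly onto the requested device.
    device = "cpu" if config.use_fsdp else init_device
    model = FluxDiT.from_state_dict(state_dict, device=device, dtype=config.model_dtype)
    if config.use_fsdp:
        # device_id tells FSDP which GPU receives this rank's shard,
        # so the unsharded model never occupies a single GPU.
        model = FSDP(model, device_id=torch.cuda.current_device())
    return model
```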