@@ -554,6 +554,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
554554 loaded_sub_model = passed_class_obj [name ]
555555
556556 else :
557+ sub_model_dtype = (
558+ torch_dtype .get (name , torch_dtype .get ("_" , torch .float32 ))
559+ if isinstance (torch_dtype , dict )
560+ else torch_dtype
561+ )
557562 loaded_sub_model = _load_empty_model (
558563 library_name = library_name ,
559564 class_name = class_name ,
@@ -562,7 +567,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
562567 is_pipeline_module = is_pipeline_module ,
563568 pipeline_class = pipeline_class ,
564569 name = name ,
565- torch_dtype = torch_dtype ,
570+ torch_dtype = sub_model_dtype ,
566571 cached_folder = kwargs .get ("cached_folder" , None ),
567572 force_download = kwargs .get ("force_download" , None ),
568573 proxies = kwargs .get ("proxies" , None ),
@@ -578,7 +583,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
578583 # Obtain a sorted dictionary for mapping the model-level components
579584 # to their sizes.
580585 module_sizes = {
581- module_name : compute_module_sizes (module , dtype = torch_dtype )["" ]
586+ module_name : compute_module_sizes (
587+ module ,
588+ dtype = torch_dtype .get (module_name , torch_dtype .get ("_" , torch .float32 ))
589+ if isinstance (torch_dtype , dict )
590+ else torch_dtype ,
591+ )["" ]
582592 for module_name , module in init_empty_modules .items ()
583593 if isinstance (module , torch .nn .Module )
584594 }
0 commit comments