@@ -557,6 +557,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
557557 variant = kwargs .pop ("variant" , None )
558558 use_safetensors = kwargs .pop ("use_safetensors" , None )
559559 quantization_config = kwargs .pop ("quantization_config" , None )
560+ dduf_reader = kwargs .pop ("dduf_reader" , None )
560561
561562 allow_pickle = False
562563 if use_safetensors is None :
@@ -649,6 +650,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
649650 revision = revision ,
650651 subfolder = subfolder ,
651652 user_agent = user_agent ,
653+ dduf_reader = dduf_reader ,
652654 ** kwargs ,
653655 )
654656 # no in-place modification of the original config.
@@ -724,6 +726,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
724726 "revision" : revision ,
725727 "user_agent" : user_agent ,
726728 "commit_hash" : commit_hash ,
729+ "dduf_reader" : dduf_reader ,
727730 }
728731 index_file = _fetch_index_file (** index_file_kwargs )
729732 # In case the index file was not found we still have to consider the legacy format.
@@ -759,7 +762,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
759762
760763 model = load_flax_checkpoint_in_pytorch_model (model , model_file )
761764 else :
762- if is_sharded :
765+ # if the checkpoint is sharded, we already have the index file, so skip re-fetching the shards here
766+ if is_sharded and not dduf_reader :
763767 sharded_ckpt_cached_folder , sharded_metadata = _get_checkpoint_shard_files (
764768 pretrained_model_name_or_path ,
765769 index_file ,
@@ -790,6 +794,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
790794 subfolder = subfolder ,
791795 user_agent = user_agent ,
792796 commit_hash = commit_hash ,
797+ dduf_reader = dduf_reader ,
793798 )
794799
795800 except IOError as e :
@@ -813,6 +818,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
813818 subfolder = subfolder ,
814819 user_agent = user_agent ,
815820 commit_hash = commit_hash ,
821+ dduf_reader = dduf_reader ,
816822 )
817823
818824 if low_cpu_mem_usage :
@@ -837,7 +843,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
837843 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
838844 elif is_quant_method_bnb :
839845 param_device = torch .cuda .current_device ()
840- state_dict = load_state_dict (model_file , variant = variant )
846+ state_dict = load_state_dict (model_file , variant = variant , dduf_reader = dduf_reader )
841847 model ._convert_deprecated_attention_blocks (state_dict )
842848
843849 # move the params from meta device to cpu
@@ -937,7 +943,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
937943 else :
938944 model = cls .from_config (config , ** unused_kwargs )
939945
940- state_dict = load_state_dict (model_file , variant = variant )
946+ state_dict = load_state_dict (model_file , variant = variant , dduf_reader = dduf_reader )
941947 model ._convert_deprecated_attention_blocks (state_dict )
942948
943949 model , missing_keys , unexpected_keys , mismatched_keys , error_msgs = cls ._load_pretrained_model (
0 commit comments