@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
104104 extension is replaced with ".safetensors"
105105 """
106106 passed_components = passed_components or []
107- if folder_names is not None :
107+ if folder_names :
108108 filenames = {f for f in filenames if os .path .split (f )[0 ] in folder_names }
109109
110110 # extract all components of the pipeline and their associated files
@@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
141141 return True
142142
143143
144- def variant_compatible_siblings (filenames , variant = None ) -> Union [List [os .PathLike ], str ]:
144+ def filter_model_files (filenames ):
145+ """Filter model repo files for just files/folders that contain model weights"""
146+ weight_names = [
147+ WEIGHTS_NAME ,
148+ SAFETENSORS_WEIGHTS_NAME ,
149+ FLAX_WEIGHTS_NAME ,
150+ ONNX_WEIGHTS_NAME ,
151+ ONNX_EXTERNAL_WEIGHTS_NAME ,
152+ ]
153+
154+ if is_transformers_available ():
155+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
156+
157+ allowed_extensions = [wn .split ("." )[- 1 ] for wn in weight_names ]
158+
159+ return [f for f in filenames if any (f .endswith (extension ) for extension in allowed_extensions )]
160+
161+
162+ def variant_compatible_siblings (filenames , variant = None , ignore_patterns = None ) -> Union [List [os .PathLike ], str ]:
145163 weight_names = [
146164 WEIGHTS_NAME ,
147165 SAFETENSORS_WEIGHTS_NAME ,
@@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
169187 variant_index_re = re .compile (
170188 rf"({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.index\.{ variant } \.json$"
171189 )
190+ legacy_variant_file_re = re .compile (rf".*-{ transformers_index_format } \.{ variant } \.[a-z]+$" )
191+ legacy_variant_index_re = re .compile (
192+ rf"({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.{ variant } \.index\.json$"
193+ )
172194
173195 # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
174196 non_variant_file_re = re .compile (
@@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
177199 # `text_encoder/pytorch_model.bin.index.json`
178200 non_variant_index_re = re .compile (rf"({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.index\.json" )
179201
180- if variant is not None :
181- variant_weights = {f for f in filenames if variant_file_re .match (f .split ("/" )[- 1 ]) is not None }
182- variant_indexes = {f for f in filenames if variant_index_re .match (f .split ("/" )[- 1 ]) is not None }
183- variant_filenames = variant_weights | variant_indexes
184- else :
185- variant_filenames = set ()
202+ def filter_for_compatible_extensions (filenames , ignore_patterns = None ):
203+ if not ignore_patterns :
204+ return filenames
205+
206+ # ignore patterns uses glob style patterns e.g *.safetensors but we're only
207+ # interested in the extension name
208+ return {f for f in filenames if not any (f .endswith (pat .lstrip ("*." )) for pat in ignore_patterns )}
209+
210+ def filter_with_regex (filenames , pattern_re ):
211+ return {f for f in filenames if pattern_re .match (f .split ("/" )[- 1 ]) is not None }
212+
213+ # Group files by component
214+ components = {}
215+ for filename in filenames :
216+ if not len (filename .split ("/" )) == 2 :
217+ components .setdefault ("" , []).append (filename )
218+ continue
186219
187- non_variant_weights = {f for f in filenames if non_variant_file_re .match (f .split ("/" )[- 1 ]) is not None }
188- non_variant_indexes = {f for f in filenames if non_variant_index_re .match (f .split ("/" )[- 1 ]) is not None }
189- non_variant_filenames = non_variant_weights | non_variant_indexes
220+ component , _ = filename .split ("/" )
221+ components .setdefault (component , []).append (filename )
190222
191- # all variant filenames will be used by default
192- usable_filenames = set (variant_filenames )
223+ usable_filenames = set ()
224+ variant_filenames = set ()
225+ for component , component_filenames in components .items ():
226+ component_filenames = filter_for_compatible_extensions (component_filenames , ignore_patterns = ignore_patterns )
227+
228+ component_variants = set ()
229+ component_legacy_variants = set ()
230+ component_non_variants = set ()
231+ if variant is not None :
232+ component_variants = filter_with_regex (component_filenames , variant_file_re )
233+ component_variant_index_files = filter_with_regex (component_filenames , variant_index_re )
234+
235+ component_legacy_variants = filter_with_regex (component_filenames , legacy_variant_file_re )
236+ component_legacy_variant_index_files = filter_with_regex (component_filenames , legacy_variant_index_re )
237+
238+ if component_variants or component_legacy_variants :
239+ variant_filenames .update (
240+ component_variants | component_variant_index_files
241+ if component_variants
242+ else component_legacy_variants | component_legacy_variant_index_files
243+ )
193244
194- def convert_to_variant (filename ):
195- if "index" in filename :
196- variant_filename = filename .replace ("index" , f"index.{ variant } " )
197- elif re .compile (f"^(.*?){ transformers_index_format } " ).match (filename ) is not None :
198- variant_filename = f"{ filename .split ('-' )[0 ]} .{ variant } -{ '-' .join (filename .split ('-' )[1 :])} "
199245 else :
200- variant_filename = f" { filename . split ( '.' )[ 0 ] } . { variant } . { filename . split ( '.' )[ 1 ] } "
201- return variant_filename
246+ component_non_variants = filter_with_regex ( component_filenames , non_variant_file_re )
247+ component_variant_index_files = filter_with_regex ( component_filenames , non_variant_index_re )
202248
203- def find_component (filename ):
204- if not len (filename .split ("/" )) == 2 :
205- return
206- component = filename .split ("/" )[0 ]
207- return component
208-
209- def has_sharded_variant (component , variant , variant_filenames ):
210- # If component exists check for sharded variant index filename
211- # If component doesn't exist check main dir for sharded variant index filename
212- component = component + "/" if component else ""
213- variant_index_re = re .compile (
214- rf"{ component } ({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.index\.{ variant } \.json$"
215- )
216- return any (f for f in variant_filenames if variant_index_re .match (f ) is not None )
249+ usable_filenames .update (component_non_variants | component_variant_index_files )
217250
218- for filename in non_variant_filenames :
219- if convert_to_variant (filename ) in variant_filenames :
220- continue
251+ usable_filenames .update (variant_filenames )
221252
222- component = find_component (filename )
223- # If a sharded variant exists skip adding to allowed patterns
224- if has_sharded_variant (component , variant , variant_filenames ):
225- continue
253+ if len (variant_filenames ) == 0 and variant is not None :
254+ error_message = f"You are trying to load model files of the `variant={ variant } `, but no such modeling files are available. "
255+ raise ValueError (error_message )
226256
227- usable_filenames .add (filename )
257+ if len (variant_filenames ) > 0 and usable_filenames != variant_filenames :
258+ logger .warning (
259+ f"\n A mixture of { variant } and non-{ variant } filenames will be loaded.\n Loaded { variant } filenames:\n "
260+ f"[{ ', ' .join (variant_filenames )} ]\n Loaded non-{ variant } filenames:\n "
261+ f"[{ ', ' .join (usable_filenames - variant_filenames )} \n If this behavior is not "
262+ f"expected, please check your folder structure."
263+ )
228264
229265 return usable_filenames , variant_filenames
230266
@@ -922,18 +958,13 @@ def _get_custom_components_and_folders(
922958 f"{ candidate_file } as defined in `model_index.json` does not exist in { pretrained_model_name } and is not a module in 'diffusers/pipelines'."
923959 )
924960
925- if len (variant_filenames ) == 0 and variant is not None :
926- error_message = f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
927- raise ValueError (error_message )
928-
929961 return custom_components , folder_names
930962
931963
932964def _get_ignore_patterns (
933965 passed_components ,
934966 model_folder_names : List [str ],
935967 model_filenames : List [str ],
936- variant_filenames : List [str ],
937968 use_safetensors : bool ,
938969 from_flax : bool ,
939970 allow_pickle : bool ,
@@ -964,33 +995,13 @@ def _get_ignore_patterns(
964995 if not use_onnx :
965996 ignore_patterns += ["*.onnx" , "*.pb" ]
966997
967- safetensors_variant_filenames = {f for f in variant_filenames if f .endswith (".safetensors" )}
968- safetensors_model_filenames = {f for f in model_filenames if f .endswith (".safetensors" )}
969- if len (safetensors_variant_filenames ) > 0 and safetensors_model_filenames != safetensors_variant_filenames :
970- logger .warning (
971- f"\n A mixture of { variant } and non-{ variant } filenames will be loaded.\n Loaded { variant } filenames:\n "
972- f"[{ ', ' .join (safetensors_variant_filenames )} ]\n Loaded non-{ variant } filenames:\n "
973- f"[{ ', ' .join (safetensors_model_filenames - safetensors_variant_filenames )} \n If this behavior is not "
974- f"expected, please check your folder structure."
975- )
976-
977998 else :
978999 ignore_patterns = ["*.safetensors" , "*.msgpack" ]
9791000
9801001 use_onnx = use_onnx if use_onnx is not None else is_onnx
9811002 if not use_onnx :
9821003 ignore_patterns += ["*.onnx" , "*.pb" ]
9831004
984- bin_variant_filenames = {f for f in variant_filenames if f .endswith (".bin" )}
985- bin_model_filenames = {f for f in model_filenames if f .endswith (".bin" )}
986- if len (bin_variant_filenames ) > 0 and bin_model_filenames != bin_variant_filenames :
987- logger .warning (
988- f"\n A mixture of { variant } and non-{ variant } filenames will be loaded.\n Loaded { variant } filenames:\n "
989- f"[{ ', ' .join (bin_variant_filenames )} ]\n Loaded non-{ variant } filenames:\n "
990- f"[{ ', ' .join (bin_model_filenames - bin_variant_filenames )} \n If this behavior is not expected, please check "
991- f"your folder structure."
992- )
993-
9941005 return ignore_patterns
9951006
9961007
0 commit comments