@@ -177,99 +177,67 @@ def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -
177177    # `text_encoder/pytorch_model.bin.index.json` 
178178    non_variant_index_re  =  re .compile (rf"({ '|' .join (weight_prefixes )} { '|' .join (weight_suffixs )}  )
179179
180-     if  variant  is  not None :
181-         variant_weights  =  {f  for  f  in  filenames  if  variant_file_re .match (f .split ("/" )[- 1 ]) is  not None }
182-         variant_indexes  =  {f  for  f  in  filenames  if  variant_index_re .match (f .split ("/" )[- 1 ]) is  not None }
183-         variant_filenames  =  variant_weights  |  variant_indexes 
184-     else :
185-         variant_filenames  =  set ()
180+     def  filter_for_compatible_extensions (filenames , variant = None , use_safetensors = True ):
181+         def  is_safetensors (filename ):
182+             return  ".safetensors"  in  filename 
186183
187-     non_variant_weights  =  {f  for  f  in  filenames  if  non_variant_file_re .match (f .split ("/" )[- 1 ]) is  not None }
188-     non_variant_indexes  =  {f  for  f  in  filenames  if  non_variant_index_re .match (f .split ("/" )[- 1 ]) is  not None }
189-     non_variant_filenames  =  non_variant_weights  |  non_variant_indexes 
184+         def  is_not_safetensors (filename ):
185+             return  ".safetensors"  not  in filename 
190186
191-     def  find_component (filename ):
192-         if  not  len (filename .split ("/" )) ==  2 :
193-             return 
194-         component  =  filename .split ("/" )[0 ]
195-         return  component 
196- 
197-     def  convert_to_variant (filename ):
198-         if  "index"  in  filename :
199-             variant_filename  =  filename .replace ("index" , f"index.{ variant }  )
200-         elif  re .compile (f"^(.*?){ transformers_index_format }  ).match (filename ) is  not None :
201-             variant_filename  =  f"{ filename .split ('-' )[0 ]} { variant } { '-' .join (filename .split ('-' )[1 :])}  
187+         if  use_safetensors  and  is_safetensors_compatible (filenames ):
188+             extension_filter  =  is_safetensors 
202189        else :
203-             variant_filename  =  f"{ filename .split ('.' )[0 ]} { variant } { filename .split ('.' )[1 ]}  
204-         return  variant_filename 
205- 
206-     def  has_sharded_variant (filename , variant , variant_filenames ):
207-         component  =  find_component (filename )
208-         # If component exists check for sharded variant index filename 
209-         # If component doesn't exist check main dir for sharded variant index filename 
210-         component  =  component  +  "/"  if  component  else  "" 
211-         variant_index_re  =  re .compile (
212-             rf"{ component } { '|' .join (weight_prefixes )} { '|' .join (weight_suffixs )} { variant }  
213-         )
214-         return  any (f  for  f  in  variant_filenames  if  variant_index_re .match (f ) is  not None )
215- 
216-     def  has_non_sharded_variant (filename , variant , variant_filenames ):
217-         component  =  find_component (filename )
218-         component  =  component  +  "/"  if  component  else  "" 
219-         base_name  =  filename .split ("/" )[- 1 ]
220- 
221-         # Only apply to sharded files (those with the index format) 
222-         if  not  (non_variant_file_re .match (base_name ) or  non_variant_index_re .match (base_name )):
223-             return  False 
190+             extension_filter  =  is_not_safetensors 
224191
225-         # Check if there's a non-sharded variant in the same component 
226-         non_sharded_variants  =  [
227-             f 
228-             for  f  in  variant_filenames 
229-             if  f .startswith (component ) and  not  re .search (transformers_index_format , f .split ("/" )[- 1 ])
230-         ]
231-         return  any (non_sharded_variants )
232- 
233-     if  use_safetensors :
234-         # Keep only safetensors and index files 
235-         non_variant_filenames  =  {
236-             f 
237-             for  f  in  non_variant_filenames 
238-             if  f .endswith (".safetensors" ) or  non_variant_index_re .match (f .split ("/" )[- 1 ])
192+         tensor_files  =  {f  for  f  in  filenames  if  extension_filter (f )}
193+         non_variant_indexes  =  {
194+             f  for  f  in  filenames  if  non_variant_index_re .match (f .split ("/" )[- 1 ]) is  not None  and  extension_filter (f )
239195        }
240-         if  variant  is  not None :
241-             variant_filenames  =  {
242-                 f  for  f  in  variant_filenames  if  f .endswith (".safetensors" ) or  variant_index_re .match (f .split ("/" )[- 1 ])
243-             }
244-     else :
245-         # Exclude safetensors files but keep index files 
246-         non_variant_filenames  =  {
196+         variant_indexes  =  {
247197            f 
248-             for  f  in  non_variant_filenames 
249-             if  not  f . endswith ( ".safetensors" )  or   non_variant_index_re .match (f .split ("/" )[- 1 ])
198+             for  f  in  filenames 
199+             if  variant   is   not None   and   variant_index_re .match (f .split ("/" )[- 1 ])  is   not   None   and   extension_filter ( f )
250200        }
251-         if  variant  is  not None :
252-             variant_filenames  =  {
253-                 f 
254-                 for  f  in  variant_filenames 
255-                 if  not  f .endswith (".safetensors" ) or  variant_index_re .match (f .split ("/" )[- 1 ])
256-             }
257201
258-     # all variant filenames will be used by default 
259-     usable_filenames  =  set (variant_filenames )
202+         return  tensor_files  |  non_variant_indexes  |  variant_indexes 
260203
261-     for  filename  in  non_variant_filenames :
262-         if  convert_to_variant (filename ) in  variant_filenames :
263-             continue 
204+     def  filter_for_weights_and_indexes (filenames , file_re , index_re ):
205+         weights  =  {f  for  f  in  filenames  if  file_re .match (f .split ("/" )[- 1 ]) is  not None }
206+         indexes  =  {f  for  f  in  filenames  if  index_re .match (f .split ("/" )[- 1 ]) is  not None }
207+         filtered_filenames  =  weights  |  indexes 
264208
265-         # If a sharded variant exists skip adding to allowed patterns 
266-         if  has_sharded_variant (filename , variant , variant_filenames ):
267-             continue 
209+         return  filtered_filenames 
268210
269-         if  has_non_sharded_variant (filename , variant , variant_filenames ):
211+     # Group files by component 
212+     components  =  {}
213+     for  filename  in  filenames :
214+         if  not  len (filename .split ("/" )) ==  2 :
215+             components .setdefault ("" , []).append (filename )
270216            continue 
271217
272-         usable_filenames .add (filename )
218+         component , _  =  filename .split ("/" )
219+         components .setdefault (component , []).append (filename )
220+ 
221+     usable_filenames  =  set ()
222+     variant_filenames  =  set ()
223+     for  component , component_filenames  in  components .items ():
224+         component_filenames  =  filter_for_compatible_extensions (
225+             component_filenames , variant = variant , use_safetensors = use_safetensors 
226+         )
227+ 
228+         component_variants  =  set ()
229+         if  variant  is  not None :
230+             component_variants  =  filter_for_weights_and_indexes (component_filenames , variant_file_re , variant_index_re )
231+ 
232+         if  component_variants :
233+             variant_filenames .update (component_variants )
234+             usable_filenames .update (component_variants )
235+ 
236+         else :
237+             component_non_variants  =  filter_for_weights_and_indexes (
238+                 component_filenames , non_variant_file_re , non_variant_index_re 
239+             )
240+             usable_filenames .update (component_non_variants )
273241
274242    return  usable_filenames , variant_filenames 
275243
0 commit comments