@@ -256,13 +256,99 @@ def remove(self, name):
256256 if self ._auto_offload_enabled :
257257 self .enable_auto_cpu_offload (self ._auto_offload_device )
258258
259+ # YiYi TODO: looking into improving the search pattern
259260 def get (self , names : Union [str , List [str ]]):
261+ """
262+ Get components by name with simple pattern matching.
263+
264+ Args:
265+ names: Component name(s) or pattern(s)
266+ Patterns:
267+ - "unet" : exact match
268+ - "!unet" : everything except exact match "unet"
269+ - "base_*" : everything starting with "base_"
270+ - "!base_*" : everything NOT starting with "base_"
271+ - "*unet*" : anything containing "unet"
272+ - "!*unet*" : anything NOT containing "unet"
273+ - "refiner|vae|unet" : anything containing any of these terms
274+ - "!refiner|vae|unet" : anything NOT containing any of these terms
275+
276+ Returns:
277+ Single component if names is str and matches one component,
278+ dict of components if names matches multiple components or is a list
279+ """
260280 if isinstance (names , str ):
261- if names not in self .components :
281+ # Check if this is a "not" pattern
282+ is_not_pattern = names .startswith ('!' )
283+ if is_not_pattern :
284+ names = names [1 :] # Remove the ! prefix
285+
286+ # Handle OR patterns (containing |)
287+ if '|' in names :
288+ terms = names .split ('|' )
289+ matches = {
290+ name : comp for name , comp in self .components .items ()
291+ if any ((term in name ) != is_not_pattern for term in terms ) # Flip condition if not pattern
292+ }
293+ if is_not_pattern :
294+ logger .info (f"Getting components NOT containing any of { terms } : { list (matches .keys ())} " )
295+ else :
296+ logger .info (f"Getting components containing any of { terms } : { list (matches .keys ())} " )
297+
298+ # Exact match
299+ elif names in self .components :
300+ if is_not_pattern :
301+ matches = {
302+ name : comp for name , comp in self .components .items ()
303+ if name != names
304+ }
305+ logger .info (f"Getting all components except '{ names } ': { list (matches .keys ())} " )
306+ else :
307+ logger .info (f"Getting component: { names } " )
308+ return self .components [names ]
309+
310+ # Prefix match (ends with *)
311+ elif names .endswith ('*' ):
312+ prefix = names [:- 1 ]
313+ matches = {
314+ name : comp for name , comp in self .components .items ()
315+ if name .startswith (prefix ) != is_not_pattern # Flip condition if not pattern
316+ }
317+ if is_not_pattern :
318+ logger .info (f"Getting components NOT starting with '{ prefix } ': { list (matches .keys ())} " )
319+ else :
320+ logger .info (f"Getting components starting with '{ prefix } ': { list (matches .keys ())} " )
321+
322+ # Contains match (starts with *)
323+ elif names .startswith ('*' ):
324+ search = names [1 :- 1 ] if names .endswith ('*' ) else names [1 :]
325+ matches = {
326+ name : comp for name , comp in self .components .items ()
327+ if (search in name ) != is_not_pattern # Flip condition if not pattern
328+ }
329+ if is_not_pattern :
330+ logger .info (f"Getting components NOT containing '{ search } ': { list (matches .keys ())} " )
331+ else :
332+ logger .info (f"Getting components containing '{ search } ': { list (matches .keys ())} " )
333+
334+ else :
262335 raise ValueError (f"Component '{ names } ' not found in ComponentsManager" )
263- return self .components [names ]
336+
337+ if not matches :
338+ raise ValueError (f"No components found matching pattern '{ names } '" )
339+ return matches if len (matches ) > 1 else next (iter (matches .values ()))
340+
264341 elif isinstance (names , list ):
265- return {n : self .components [n ] for n in names }
342+ results = {}
343+ for name in names :
344+ result = self .get (name )
345+ if isinstance (result , dict ):
346+ results .update (result )
347+ else :
348+ results [name ] = result
349+ logger .info (f"Getting multiple components: { list (results .keys ())} " )
350+ return results
351+
266352 else :
267353 raise ValueError (f"Invalid type for names: { type (names )} " )
268354
@@ -431,18 +517,34 @@ def __repr__(self):
431517
432518 return output
433519
434- def add_from_pretrained (self , pretrained_model_name_or_path , ** kwargs ):
520+ def add_from_pretrained (self , pretrained_model_name_or_path , prefix : Optional [str ] = None , ** kwargs ):
521+ """
522+ Load components from a pretrained model and add them to the manager.
523+
524+ Args:
525+ pretrained_model_name_or_path (str): The path or identifier of the pretrained model
526+ prefix (str, optional): Prefix to add to all component names loaded from this model.
527+ If provided, components will be named as "{prefix}_{component_name}"
528+ **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
529+ """
435530 from ..pipelines .pipeline_utils import DiffusionPipeline
436531
437532 pipe = DiffusionPipeline .from_pretrained (pretrained_model_name_or_path , ** kwargs )
438533 for name , component in pipe .components .items ():
439- if name not in self .components and component is not None :
440- self .add (name , component )
441- elif name in self .components :
534+
535+ if component is None :
536+ continue
537+
538+ # Add prefix if specified
539+ component_name = f"{ prefix } _{ name } " if prefix else name
540+
541+ if component_name not in self .components :
542+ self .add (component_name , component )
543+ else :
442544 logger .warning (
443- f"Component '{ name } ' already exists in ComponentsManager and will not be added. To add it, either:\n "
444- f"1. remove the existing component with remove('{ name } ')\n "
445- f"2. Use a different name: add(' { name } _2', component )"
545+ f"Component '{ component_name } ' already exists in ComponentsManager and will not be added. To add it, either:\n "
546+ f"1. remove the existing component with remove('{ component_name } ')\n "
547+ f"2. Use a different prefix: add_from_pretrained(..., prefix=' { prefix } _2')"
446548 )
447549
448550def summarize_dict_by_value_and_parts (d : Dict [str , Any ]) -> Dict [str , Any ]:
0 commit comments