@@ -124,7 +124,7 @@ def register_foreach_default(self, name: str) -> None:
 
     def list_optimizers(
             self,
-            filter: str = '',
+            filter: Union[str, List[str]] = '',
             exclude_filters: Optional[List[str]] = None,
             with_description: bool = False
     ) -> List[Union[str, Tuple[str, str]]]:
@@ -141,14 +141,22 @@ def list_optimizers(
         names = sorted(self._optimizers.keys())
 
         if filter:
-            names = [n for n in names if fnmatch(n, filter)]
+            if isinstance(filter, str):
+                filters = [filter]
+            else:
+                filters = filter
+            filtered_names = set()
+            for f in filters:
+                filtered_names.update(n for n in names if fnmatch(n, f))
+            names = sorted(filtered_names)
 
         if exclude_filters:
             for exclude_filter in exclude_filters:
                 names = [n for n in names if not fnmatch(n, exclude_filter)]
 
         if with_description:
             return [(name, self._optimizers[name].description) for name in names]
+
         return names
 
     def get_optimizer_info(self, name: str) -> OptimInfo:
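A note on the new filtering behaviour above: include patterns are OR'd together (a name is kept if it matches any of them) before the exclude patterns are applied. Below is a standalone sketch of the same semantics, using a made-up helper name and example optimizer names rather than the registry itself:

from fnmatch import fnmatch
from typing import List, Union

def match_names(names: List[str], filter: Union[str, List[str]], exclude_filters: List[str]) -> List[str]:
    # Hypothetical helper mirroring the registry logic above, for illustration only.
    filters = [filter] if isinstance(filter, str) else filter
    if any(filters):
        # Union of matches across all include patterns.
        names = sorted({n for n in names if any(fnmatch(n, f) for f in filters)})
    for ex in exclude_filters:
        # Each exclude pattern then strips its matches from the result.
        names = [n for n in names if not fnmatch(n, ex)]
    return names

# match_names(['adam', 'adamw', 'lamb', 'sgd'], ['adam*', 'la*'], ['*w']) -> ['adam', 'lamb']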
@@ -718,11 +726,46 @@ def _register_default_optimizers() -> None:
 # Public API
 
 def list_optimizers(
-        filter: str = '',
+        filter: Union[str, List[str]] = '',
        exclude_filters: Optional[List[str]] = None,
        with_description: bool = False,
 ) -> List[Union[str, Tuple[str, str]]]:
     """List available optimizer names, optionally filtered.
+
+    List all registered optimizers, with optional filtering using wildcard patterns.
+    Optimizers can be filtered using include and exclude patterns, and can optionally
+    return descriptions with each optimizer name.
+
+    Args:
+        filter: Wildcard style filter string or list of filter strings
+            (e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
+            Adam variants and 8-bit optimizers). Empty string means no filtering.
+        exclude_filters: Optional list of wildcard patterns to exclude. For example,
+            ['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
+        with_description: If True, returns tuples of (name, description) instead of
+            just names. Descriptions provide brief explanations of optimizer characteristics.
+
+    Returns:
+        If with_description is False:
+            List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
+        If with_description is True:
+            List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])
+
+    Examples:
+        >>> list_optimizers()
+        ['adam', 'adamw', 'sgd', ...]
+
+        >>> list_optimizers(['la*', 'nla*'])  # List lamb & lars
+        ['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
+
+        >>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*'])  # Exclude bnb & apex adam optimizers
+        ['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']
+
+        >>> list_optimizers(with_description=True)  # Get descriptions
+        [('adabelief', 'Adapts learning rate based on gradient prediction error'),
+         ('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
+         ('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
+         ...]
     """
     return default_registry.list_optimizers(filter, exclude_filters, with_description)
 
@@ -731,7 +774,38 @@ def get_optimizer_class(
         name: str,
         bind_defaults: bool = False,
 ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
-    """Get optimizer class by name with any defaults applied.
+    """Get optimizer class by name with option to bind default arguments.
+
+    Retrieves the optimizer class or a partial function with default arguments bound.
+    This allows direct instantiation of optimizers with their default configurations
+    without going through the full factory.
+
+    Args:
+        name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
+        bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
+            If False, returns the raw optimizer class.
+
+    Returns:
+        If bind_defaults is False:
+            The optimizer class (e.g., torch.optim.Adam)
+        If bind_defaults is True:
+            A partial function with default arguments bound
+
+    Raises:
+        ValueError: If optimizer name is not found in registry
+
+    Examples:
+        >>> # Get raw optimizer class
+        >>> Adam = get_optimizer_class('adam')
+        >>> opt = Adam(model.parameters(), lr=1e-3)
+
+        >>> # Get optimizer with defaults bound
+        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
+        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
+
+        >>> # Get SGD with nesterov momentum default
+        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
+        >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
     """
     return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
 
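For context, bind_defaults=True effectively hands back the optimizer class with its registry defaults pre-applied, much like a functools.partial. A minimal sketch of that idea follows; the helper name and the defaults dict are illustrative, not timm's actual OptimInfo handling:

from functools import partial
from typing import Any, Dict, Type

import torch.optim as optim

def bind_defaults_to_class(opt_cls: Type[optim.Optimizer], defaults: Dict[str, Any]):
    # Hypothetical helper: returns a constructor with default kwargs pre-bound,
    # which is roughly what get_optimizer_class(..., bind_defaults=True) returns.
    return partial(opt_cls, **defaults)

SGDNesterov = bind_defaults_to_class(optim.SGD, {'nesterov': True})
# opt = SGDNesterov(model.parameters(), lr=0.1, momentum=0.9)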
@@ -748,7 +822,69 @@ def create_optimizer_v2(
         param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
         **kwargs: Any,
 ) -> optim.Optimizer:
-    """Create an optimizer instance using the default registry."""
+    """Create an optimizer instance via timm registry.
+
+    Creates and configures an optimizer with appropriate parameter groups and settings.
+    Supports automatic parameter group creation for weight decay and layer-wise learning
+    rates, as well as custom parameter grouping.
+
+    Args:
+        model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
+            If a model is provided, parameters will be automatically extracted and grouped
+            based on the other arguments.
+        opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
+            Use list_optimizers() to see available options.
+        lr: Learning rate. If None, will use the optimizer's default.
+        weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
+        momentum: Momentum factor for optimizers that support it. Only used if the
+            chosen optimizer accepts a momentum parameter.
+        foreach: Enable/disable foreach (multi-tensor) implementation if available.
+            If None, will use optimizer-specific defaults.
+        filter_bias_and_bn: If True, bias and norm layer parameters (all 1d params) will not have
+            weight decay applied. Only used when model_or_params is a model and
+            weight_decay > 0.
+        layer_decay: Optional layer-wise learning rate decay factor. If provided,
+            learning rates will be scaled by layer_decay^(max_depth - layer_depth).
+            Only used when model_or_params is a model.
+        param_group_fn: Optional function to create custom parameter groups.
+            If provided, other parameter grouping options will be ignored.
+        **kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).
+
+    Returns:
+        Configured optimizer instance.
+
+    Examples:
+        >>> # Basic usage with a model
+        >>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)
+
+        >>> # SGD with momentum and weight decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
+        ... )
+
+        >>> # Adam with layer-wise learning rate decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'adam', lr=1e-3, layer_decay=0.7
+        ... )
+
+        >>> # Custom parameter groups
+        >>> def group_fn(model):
+        ...     return [
+        ...         {'params': model.backbone.parameters(), 'lr': 1e-4},
+        ...         {'params': model.head.parameters(), 'lr': 1e-3}
+        ...     ]
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', param_group_fn=group_fn
+        ... )
+
+    Note:
+        Parameter group handling precedence:
+        1. If param_group_fn is provided, it will be used exclusively
+        2. If layer_decay is provided, layer-wise groups will be created
+        3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
+        4. Otherwise, all parameters will be in a single group
+    """
+
     return default_registry.create_optimizer(
         model_or_params,
         opt=opt,
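As a companion to the precedence note in the docstring above, the weight-decay grouping in step 3 amounts to splitting 1d parameters (biases, norm weights) into a no-decay group. A simplified sketch of that behaviour, with a made-up helper name rather than timm's internal grouping function:

import torch.nn as nn

def simple_weight_decay_groups(model: nn.Module, weight_decay: float):
    # Simplified stand-in for the grouping create_optimizer_v2 performs when
    # weight_decay > 0 and filter_bias_and_bn is True.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith('.bias'):
            no_decay.append(param)  # biases and norm weights: no weight decay
        else:
            decay.append(param)     # everything else gets the requested decay
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]

# optimizer = torch.optim.SGD(simple_weight_decay_groups(model, 1e-4), lr=0.1, momentum=0.9)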