
Commit 75d676e

Try to fix documentation build, add better docstrings to public optimizer api
1 parent c8b4511 commit 75d676e

2 files changed: +149, -7 lines

hfdocs/source/reference/optimizers.mdx

Lines changed: 8 additions & 2 deletions
@@ -6,22 +6,28 @@ This page contains the API reference documentation for learning rate optimizers
 
 ### Factory functions
 
-[[autodoc]] timm.optim.optim_factory.create_optimizer
-[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
+[[autodoc]] timm.optim.create_optimizer_v2
+[[autodoc]] timm.optim.list_optimizers
+[[autodoc]] timm.optim.get_optimizer_class
 
 ### Optimizer Classes
 
 [[autodoc]] timm.optim.adabelief.AdaBelief
 [[autodoc]] timm.optim.adafactor.Adafactor
+[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
 [[autodoc]] timm.optim.adahessian.Adahessian
 [[autodoc]] timm.optim.adamp.AdamP
 [[autodoc]] timm.optim.adamw.AdamW
+[[autodoc]] timm.optim.adopt.Adopt
 [[autodoc]] timm.optim.lamb.Lamb
 [[autodoc]] timm.optim.lars.Lars
+[[autodoc]] timm.optim.lion.Lion
 [[autodoc]] timm.optim.lookahead.Lookahead
 [[autodoc]] timm.optim.madgrad.MADGRAD
 [[autodoc]] timm.optim.nadam.Nadam
+[[autodoc]] timm.optim.nadamw.NadamW
 [[autodoc]] timm.optim.nvnovograd.NvNovoGrad
 [[autodoc]] timm.optim.radam.RAdam
 [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
 [[autodoc]] timm.optim.sgdp.SGDP
+[[autodoc]] timm.optim.sgdw.SGDW
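
With these factory functions now documented at their top-level timm.optim paths, the typical flow they enable looks roughly like this doctest-style sketch (assumes a constructed model; the names shown follow the docstring examples in the factory changes below):

>>> from timm.optim import list_optimizers, get_optimizer_class, create_optimizer_v2
>>> list_optimizers('adam*')   # discover registered optimizer names via wildcard
['adam', 'adamax', 'adamp', 'adamw', ...]
>>> AdamW = get_optimizer_class('adamw')   # raw optimizer class, no defaults bound
>>> opt = create_optimizer_v2(model, 'adamw', lr=1e-3, weight_decay=0.05)   # or build via the factory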

timm/optim/_optim_factory.py

Lines changed: 141 additions & 5 deletions
@@ -124,7 +124,7 @@ def register_foreach_default(self, name: str) -> None:
 
     def list_optimizers(
             self,
-            filter: str = '',
+            filter: Union[str, List[str]] = '',
             exclude_filters: Optional[List[str]] = None,
             with_description: bool = False
     ) -> List[Union[str, Tuple[str, str]]]:
@@ -141,14 +141,22 @@ def list_optimizers(
         names = sorted(self._optimizers.keys())
 
         if filter:
-            names = [n for n in names if fnmatch(n, filter)]
+            if isinstance(filter, str):
+                filters = [filter]
+            else:
+                filters = filter
+            filtered_names = set()
+            for f in filters:
+                filtered_names.update(n for n in names if fnmatch(n, f))
+            names = sorted(filtered_names)
 
         if exclude_filters:
             for exclude_filter in exclude_filters:
                 names = [n for n in names if not fnmatch(n, exclude_filter)]
 
         if with_description:
             return [(name, self._optimizers[name].description) for name in names]
+
         return names
 
     def get_optimizer_info(self, name: str) -> OptimInfo:
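
The filter change above lets the registry union matches across several wildcard patterns before the exclude pass runs. A brief illustrative doctest against the public wrapper, using only names that appear in the docstring examples further down:

>>> from timm.optim import list_optimizers
>>> list_optimizers(['la*', 'nla*'])   # two include patterns, matches merged
['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
>>> list_optimizers(['la*', 'nla*'], exclude_filters=['*c'])   # excludes applied after the union
['lamb', 'lars', 'nlars']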
@@ -718,11 +726,46 @@ def _register_default_optimizers() -> None:
 # Public API
 
 def list_optimizers(
-        filter: str = '',
+        filter: Union[str, List[str]] = '',
         exclude_filters: Optional[List[str]] = None,
         with_description: bool = False,
 ) -> List[Union[str, Tuple[str, str]]]:
     """List available optimizer names, optionally filtered.
+
+    List all registered optimizers, with optional filtering using wildcard patterns.
+    Optimizers can be filtered using include and exclude patterns, and can optionally
+    return descriptions with each optimizer name.
+
+    Args:
+        filter: Wildcard style filter string or list of filter strings
+            (e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
+            Adam variants and 8-bit optimizers). Empty string means no filtering.
+        exclude_filters: Optional list of wildcard patterns to exclude. For example,
+            ['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
+        with_description: If True, returns tuples of (name, description) instead of
+            just names. Descriptions provide brief explanations of optimizer characteristics.
+
+    Returns:
+        If with_description is False:
+            List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
+        If with_description is True:
+            List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])
+
+    Examples:
+        >>> list_optimizers()
+        ['adam', 'adamw', 'sgd', ...]
+
+        >>> list_optimizers(['la*', 'nla*'])  # List lamb & lars
+        ['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
+
+        >>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*'])  # Exclude bnb & apex adam optimizers
+        ['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']
+
+        >>> list_optimizers(with_description=True)  # Get descriptions
+        [('adabelief', 'Adapts learning rate based on gradient prediction error'),
+         ('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
+         ('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
+         ...]
     """
     return default_registry.list_optimizers(filter, exclude_filters, with_description)

@@ -731,7 +774,38 @@ def get_optimizer_class(
         name: str,
         bind_defaults: bool = False,
 ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
-    """Get optimizer class by name with any defaults applied.
+    """Get optimizer class by name with option to bind default arguments.
+
+    Retrieves the optimizer class or a partial function with default arguments bound.
+    This allows direct instantiation of optimizers with their default configurations
+    without going through the full factory.
+
+    Args:
+        name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
+        bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
+            If False, returns the raw optimizer class.
+
+    Returns:
+        If bind_defaults is False:
+            The optimizer class (e.g., torch.optim.Adam)
+        If bind_defaults is True:
+            A partial function with default arguments bound
+
+    Raises:
+        ValueError: If optimizer name is not found in registry
+
+    Examples:
+        >>> # Get raw optimizer class
+        >>> Adam = get_optimizer_class('adam')
+        >>> opt = Adam(model.parameters(), lr=1e-3)
+
+        >>> # Get optimizer with defaults bound
+        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
+        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
+
+        >>> # Get SGD with nesterov momentum default
+        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
+        >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
     """
     return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
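
For reference, the bind_defaults=True behaviour described here amounts to wrapping the optimizer class in functools.partial with the registered OptimInfo defaults. An illustrative equivalent for the 'sgd' case in the example above (the docstring notes nesterov=True is bound; assuming that is the only default involved):

>>> from functools import partial
>>> import torch
>>> SGDWithNesterov = partial(torch.optim.SGD, nesterov=True)   # roughly what bind_defaults=True would return for 'sgd'
>>> opt = SGDWithNesterov(model.parameters(), lr=0.1, momentum=0.9)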

@@ -748,7 +822,69 @@ def create_optimizer_v2(
         param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
         **kwargs: Any,
 ) -> optim.Optimizer:
-    """Create an optimizer instance using the default registry."""
+    """Create an optimizer instance via timm registry.
+
+    Creates and configures an optimizer with appropriate parameter groups and settings.
+    Supports automatic parameter group creation for weight decay and layer-wise learning
+    rates, as well as custom parameter grouping.
+
+    Args:
+        model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
+            If a model is provided, parameters will be automatically extracted and grouped
+            based on the other arguments.
+        opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
+            Use list_optimizers() to see available options.
+        lr: Learning rate. If None, will use the optimizer's default.
+        weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
+        momentum: Momentum factor for optimizers that support it. Only used if the
+            chosen optimizer accepts a momentum parameter.
+        foreach: Enable/disable foreach (multi-tensor) implementation if available.
+            If None, will use optimizer-specific defaults.
+        filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
+            weight decay applied. Only used when model_or_params is a model and
+            weight_decay > 0.
+        layer_decay: Optional layer-wise learning rate decay factor. If provided,
+            learning rates will be scaled by layer_decay^(max_depth - layer_depth).
+            Only used when model_or_params is a model.
+        param_group_fn: Optional function to create custom parameter groups.
+            If provided, other parameter grouping options will be ignored.
+        **kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).
+
+    Returns:
+        Configured optimizer instance.
+
+    Examples:
+        >>> # Basic usage with a model
+        >>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)
+
+        >>> # SGD with momentum and weight decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
+        ... )
+
+        >>> # Adam with layer-wise learning rate decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'adam', lr=1e-3, layer_decay=0.7
+        ... )
+
+        >>> # Custom parameter groups
+        >>> def group_fn(model):
+        ...     return [
+        ...         {'params': model.backbone.parameters(), 'lr': 1e-4},
+        ...         {'params': model.head.parameters(), 'lr': 1e-3}
+        ...     ]
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', param_group_fn=group_fn
+        ... )
+
+    Note:
+        Parameter group handling precedence:
+        1. If param_group_fn is provided, it will be used exclusively
+        2. If layer_decay is provided, layer-wise groups will be created
+        3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
+        4. Otherwise, all parameters will be in a single group
+    """
+
     return default_registry.create_optimizer(
         model_or_params,
         opt=opt,
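
As a quick check of the weight-decay grouping described in the Note, the parameter groups on the returned optimizer can be inspected. A small sketch, assuming a model with both 1d (bias/norm) and higher-rank parameters and the default filter_bias_and_bn=True; the exact group order shown is an assumption:

>>> opt = create_optimizer_v2(model, 'adamw', lr=1e-3, weight_decay=0.05)
>>> [g['weight_decay'] for g in opt.param_groups]   # expect a no-decay group for 1d params and a decay group
[0.0, 0.05]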
