Adds CombinedOptimizer
#1241
base: v2.0-refactor
Changes from 5 commits
New file (`@@ -0,0 +1,5 @@`):

```python
"""Optimizer utilities for PhysicsNeMo."""

from physicsnemo.optim.combined_optimizer import CombinedOptimizer

__all__ = ["CombinedOptimizer"]
```
New file (`@@ -0,0 +1,282 @@`):

```python
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Sequence

import torch
from torch.optim import Optimizer


class CombinedOptimizer(Optimizer):
    r"""Combine multiple PyTorch optimizers into a single Optimizer-like interface.

    This wrapper allows you to use different optimizers for different parts of a model
    while presenting a unified interface compatible with PyTorch's training loops and
    learning rate schedulers. The ``param_groups`` from all contained optimizers are
    concatenated, enabling schedulers to operate transparently across all parameters.

    Parameters
    ----------
    optimizers : Sequence[torch.optim.Optimizer]
        Sequence of PyTorch Optimizer instances to combine. Each optimizer
        should already be configured with its own parameters and hyperparameters.
        Must contain at least one optimizer.
    torch_compile_kwargs : dict[str, Any], optional
        Optional dictionary of keyword arguments to pass to ``torch.compile()``
        when compiling each optimizer's step function. If None, step functions
        are not compiled. Compiling can improve performance but may affect
        serialization. Default is None.

    Raises
    ------
    ValueError
        If ``optimizers`` is empty.

    Notes
    -----
    * **Parameter Groups**: The ``param_groups`` attribute aggregates parameter
      groups from all underlying optimizers, making this wrapper compatible with
      learning rate schedulers.
    * **Closure Behavior**: When ``step()`` is called with a closure, the closure
      is passed to each underlying optimizer's ``step()``, so it may be evaluated
      once per optimizer. The loss returned by the last optimizer that returns one
      is the value returned to the caller.
    * **Dynamic Parameter Addition**: The ``add_param_group()`` method is not
      supported. To add parameters dynamically, add them to the individual
      optimizers before creating the CombinedOptimizer, or create a new instance.
    * **State Access**: The ``state`` attribute inherited from the base class may
      not accurately reflect the optimizer state. Access state through the
      individual optimizers in the ``optimizers`` attribute instead.
    * **Serialization**: The optimizer can be pickled and unpickled. When
      ``torch_compile_kwargs`` is provided, the compiled step functions are
      reconstructed during unpickling.

    Examples
    --------
    Combine Adam for the model backbone and SGD for the head:

    >>> import torch
    >>> import torch.nn as nn
    >>> from torch.optim import Adam, SGD
    >>> from physicsnemo.optim import CombinedOptimizer
    >>>
    >>> model = nn.Sequential(
    ...     nn.Linear(10, 20),  # backbone
    ...     nn.ReLU(),
    ...     nn.Linear(20, 2),  # head
    ... )
    >>> backbone_params = list(model[0].parameters())
    >>> head_params = list(model[2].parameters())
    >>>
    >>> opt1 = Adam(backbone_params, lr=1e-4)
    >>> opt2 = SGD(head_params, lr=1e-2, momentum=0.9)
    >>> combined_opt = CombinedOptimizer([opt1, opt2])
    >>>
    >>> # Use with a learning rate scheduler
    >>> scheduler = torch.optim.lr_scheduler.StepLR(combined_opt, step_size=10)
    >>>
    >>> # Standard training loop
    >>> for epoch in range(100):
    ...     combined_opt.zero_grad()
    ...     loss = model(torch.randn(32, 10)).sum()
    ...     loss.backward()
    ...     combined_opt.step()
    ...     scheduler.step()
    """

    def __init__(
        self,
        optimizers: Sequence[Optimizer],
        torch_compile_kwargs: dict[str, Any] | None = None,
```
> **Collaborator:** This syntax is not upstream, right? I assume we can't put
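If the concern above is that the PEP 604 ``X | Y`` union syntax is unavailable on the project's minimum supported Python version, two common workarounds are sketched below. This is a hedged suggestion only, and the ``TorchCompileKwargs`` alias is hypothetical, not part of this diff.

```python
# Hedged sketch, not part of this PR.
# Option 1: postpone annotation evaluation (PEP 563) so "dict[str, Any] | None"
# is never evaluated at runtime; this must be the first import in the module.
from __future__ import annotations

# Option 2: spell the same annotation with typing.Optional instead.
from typing import Any, Dict, Optional

TorchCompileKwargs = Optional[Dict[str, Any]]  # hypothetical alias for the parameter type
```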
```python
    ):
        if not optimizers:
            raise ValueError("`optimizers` must contain at least one optimizer.")

        self.optimizers = optimizers
        self._torch_compile_kwargs = torch_compile_kwargs

        ### Aggregate parameter groups from all optimizers
        # We pass an empty defaults dict because hyperparameters are managed by
        # the individual optimizers, not this wrapper.
        param_groups = [g for opt in optimizers for g in opt.param_groups]

        # Flag to allow add_param_group during initialization
        self._initializing = True
        try:
            super().__init__(param_groups, defaults={})
        finally:
            self._initializing = False

        ### Setup step functions (optionally compiled)
        if torch_compile_kwargs is None:
            self.step_fns: list[Callable] = [opt.step for opt in optimizers]
        else:
            self.step_fns: list[Callable] = [
                torch.compile(opt.step, **torch_compile_kwargs) for opt in optimizers
            ]

    def zero_grad(self, *args, **kwargs) -> None:
```
> **Collaborator:** Upstream interface accepts only
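If the point above is that the upstream ``torch.optim.Optimizer.zero_grad`` takes only ``set_to_none`` (default ``True``), a narrower signature could look like the method-fragment sketch below; this is a suggestion-style sketch, not part of the diff.

```python
    def zero_grad(self, set_to_none: bool = True) -> None:
        # Sketch: mirror the upstream zero_grad signature and forward the single
        # supported keyword to every wrapped optimizer.
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=set_to_none)
```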
| r"""Clear the gradients of all optimized parameters. | ||
|
|
||
| This method delegates to the ``zero_grad()`` method of each underlying | ||
| optimizer, passing through all arguments and keyword arguments. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| *args | ||
| Positional arguments to pass to each optimizer's ``zero_grad()``. | ||
| **kwargs | ||
| Keyword arguments to pass to each optimizer's ``zero_grad()``. | ||
| Common kwargs include ``set_to_none`` (bool). | ||
| """ | ||
| for opt in self.optimizers: | ||
| opt.zero_grad(*args, **kwargs) | ||
|
|
||
| def step( | ||
| self, closure: Callable[[], torch.Tensor] | None = None | ||
| ) -> torch.Tensor | None: | ||
|
Comment on lines
+149
to
+151
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Upstream returns float | None, not tensor, FYI, in both the closure and |
||
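Following that remark, a signature aligned with the reviewer's description of the upstream typing (``float`` for both the closure result and the return value) might look like this fragment; hedged sketch only, not part of the diff.

```python
    def step(
        self, closure: Callable[[], float] | None = None
    ) -> float | None:
        ...
```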
| r"""Perform a single optimization step. | ||
|
|
||
| This method calls the ``step()`` method of each underlying optimizer. If a | ||
| closure is provided, it is passed to each optimizer. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| closure : Callable[[], torch.Tensor], optional | ||
| Optional callable that reevaluates the model and returns the loss. | ||
| If provided, it will be passed to each optimizer's step function. | ||
| Default is None. | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor or None | ||
| The loss value returned by the last optimizer, or None if no closure was provided. | ||
| """ | ||
| loss = None | ||
| for step_fn in self.step_fns: | ||
| if closure is None: | ||
| step_fn() | ||
| else: | ||
| res = step_fn(closure) | ||
| if res is not None: | ||
| loss = res | ||
|
> **Collaborator** (on lines +174 to +176): I don't follow this logic - if there are multiple optimizers and a closure, are they expected to return the same value? Right now we overwrite with the value of the last non-None
```python
        return loss
```
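One way to resolve the ambiguity raised above (sketch only, not part of this diff) is to evaluate the closure exactly once, step every optimizer on the resulting gradients, and return that single loss. This assumes none of the wrapped optimizers (e.g. ``LBFGS``) need to re-evaluate the closure internally.

```python
    def step(
        self, closure: Callable[[], torch.Tensor] | None = None
    ) -> torch.Tensor | None:
        # Sketch: run the closure a single time so there is exactly one loss,
        # then step each optimizer without re-running it.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for step_fn in self.step_fns:
            step_fn()
        return loss
```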
```python
    def add_param_group(self, param_group: dict[str, Any]) -> None:
        r"""Add a param group to the Optimizer's param_groups.

        This method is not supported for CombinedOptimizer as it would require
        logic to determine which underlying optimizer should handle the new group.

        Parameters
        ----------
        param_group : dict[str, Any]
            The parameter group to add.

        Raises
        ------
        NotImplementedError
            Always raises NotImplementedError unless called during initialization.
        """
        if getattr(self, "_initializing", False):
            super().add_param_group(param_group)
            return

        raise NotImplementedError(
            "CombinedOptimizer does not support add_param_group() after initialization, "
            "since it is ambiguous which optimizer should handle the new group.\n"
            "Add parameters to the underlying optimizers before creating the CombinedOptimizer."
        )

    def state_dict(self) -> dict[str, Any]:
        r"""Return the state of all optimizers as a dictionary.

        The returned dictionary contains the state dictionaries of all underlying
        optimizers, allowing the combined optimizer to be checkpointed and restored.

        Returns
        -------
        dict[str, Any]
            A dictionary with a single key ``"optimizers"`` mapping to a list of
            state dictionaries, one for each underlying optimizer in order.

        Examples
        --------
        >>> combined_opt = CombinedOptimizer([opt1, opt2])
        >>> state = combined_opt.state_dict()
        >>> # state = {"optimizers": [opt1.state_dict(), opt2.state_dict()]}
        """
        return {"optimizers": [opt.state_dict() for opt in self.optimizers]}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        r"""Load the state of all optimizers from a dictionary.

        This method restores the state of each underlying optimizer from the provided
        state dictionary. The state dictionary must have been created by
        ``state_dict()`` from a CombinedOptimizer with the same number of optimizers.

        Parameters
        ----------
        state_dict : dict[str, Any]
            A dictionary containing optimizer states, as returned by
            ``state_dict()``. Must contain an ``"optimizers"`` key mapping to
            a list of state dictionaries.

        Raises
        ------
        ValueError
            If the number of optimizers in ``state_dict`` does not match
            the number of optimizers in this instance.
        KeyError
            If ``state_dict`` does not contain the expected structure.

        Notes
        -----
        After loading state, the ``param_groups`` attribute is refreshed to
        reflect any changes in the underlying optimizers.
        """
        ### Validate state dict structure
        if "optimizers" not in state_dict:
            raise KeyError(
                "Expected state_dict to contain 'optimizers' key, "
                f"but got keys: {list(state_dict.keys())}"
            )

        optimizer_states = state_dict["optimizers"]
        if len(optimizer_states) != len(self.optimizers):
            raise ValueError(
                f"State dict contains {len(optimizer_states)} optimizer(s), "
                f"but this CombinedOptimizer has {len(self.optimizers)} optimizer(s). "
                "Cannot load state from a different optimizer configuration."
            )

        ### Load state into each underlying optimizer
        for opt, sd in zip(self.optimizers, optimizer_states):
            opt.load_state_dict(sd)

        ### Refresh param_groups to reflect any changes
        self.param_groups = [g for opt in self.optimizers for g in opt.param_groups]

    def __repr__(self) -> str:
        r"""Return a string representation of the CombinedOptimizer.

        Returns
        -------
        str
            A string showing the optimizer types being combined.
        """
        optimizer_types = [opt.__class__.__name__ for opt in self.optimizers]
        return f"CombinedOptimizer({', '.join(optimizer_types)})"
```