5 changes: 5 additions & 0 deletions physicsnemo/optim/__init__.py
@@ -0,0 +1,5 @@
"""Optimizer utilities for PhysicsNeMo."""

from physicsnemo.optim.combined_optimizer import CombinedOptimizer

__all__ = ["CombinedOptimizer"]
282 changes: 282 additions & 0 deletions physicsnemo/optim/combined_optimizer.py
@@ -0,0 +1,282 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Sequence

import torch
from torch.optim import Optimizer


class CombinedOptimizer(Optimizer):
r"""Combine multiple PyTorch optimizers into a single Optimizer-like interface.

This wrapper allows you to use different optimizers for different parts of a model
while presenting a unified interface compatible with PyTorch's training loops and
learning rate schedulers. The ``param_groups`` from all contained optimizers are
concatenated, enabling schedulers to operate transparently across all parameters.

Parameters
----------
optimizers : Sequence[torch.optim.Optimizer]
Sequence of PyTorch Optimizer instances to combine. Each optimizer
should already be configured with its own parameters and hyperparameters.
Must contain at least one optimizer.
torch_compile_kwargs : dict[str, Any], optional
Optional dictionary of keyword arguments to pass to ``torch.compile()``
when compiling each optimizer's step function. If None, step functions
are not compiled. Compiling can improve performance but may affect
serialization. Default is None.

Raises
------
ValueError
If ``optimizers`` is empty.

Notes
-----
* **Parameter Groups**: The ``param_groups`` attribute aggregates parameter
groups from all underlying optimizers, making this wrapper compatible with
learning rate schedulers.
* **Closure Behavior**: When ``step()`` is called with a closure, the closure
is passed to each underlying optimizer's ``step()``, so it may be re-evaluated
once per optimizer. The last non-None loss returned by the optimizers is returned.
* **Dynamic Parameter Addition**: The ``add_param_group()`` method is not
supported. To add parameters dynamically, add them to the individual
optimizers before creating the CombinedOptimizer, or create a new instance.
* **State Access**: The ``state`` attribute inherited from the base class may
not accurately reflect the optimizer state. Access state through the
individual optimizers in the ``optimizers`` attribute instead.
* **Serialization**: The optimizer can be pickled and unpickled. When
``torch_compile_kwargs`` is provided, the compiled step functions are
reconstructed during unpickling.

Examples
--------
Combine Adam for the model backbone and SGD for the head:

>>> import torch
>>> import torch.nn as nn
>>> from torch.optim import Adam, SGD
>>> from physicsnemo.optim import CombinedOptimizer
>>>
>>> model = nn.Sequential(
... nn.Linear(10, 20), # backbone
... nn.ReLU(),
... nn.Linear(20, 2), # head
... )
>>> backbone_params = list(model[0].parameters())
>>> head_params = list(model[2].parameters())
>>>
>>> opt1 = Adam(backbone_params, lr=1e-4)
>>> opt2 = SGD(head_params, lr=1e-2, momentum=0.9)
>>> combined_opt = CombinedOptimizer([opt1, opt2])
>>>
>>> # Use with a learning rate scheduler
>>> scheduler = torch.optim.lr_scheduler.StepLR(combined_opt, step_size=10)
>>>
>>> # Standard training loop
>>> for epoch in range(100):
... combined_opt.zero_grad()
... loss = model(torch.randn(32, 10)).sum()
... loss.backward()
... combined_opt.step()
... scheduler.step()
"""

def __init__(
self,
optimizers: Sequence[Optimizer],
torch_compile_kwargs: dict[str, Any] | None = None,
Collaborator:
This syntax is not upstream, right? I assume we can't put step into a compile wrapper since that will break the closure behavior?

):
if not optimizers:
raise ValueError("`optimizers` must contain at least one optimizer.")

self.optimizers = optimizers
self._torch_compile_kwargs = torch_compile_kwargs

### Aggregate parameter groups from all optimizers
# We pass an empty defaults dict because hyperparameters are managed by
# the individual optimizers, not this wrapper.
param_groups = [g for opt in optimizers for g in opt.param_groups]

# Flag to allow add_param_group during initialization
self._initializing = True
try:
super().__init__(param_groups, defaults={})
finally:
self._initializing = False

### Setup step functions (optionally compiled)
if torch_compile_kwargs is None:
self.step_fns: list[Callable] = [opt.step for opt in optimizers]
else:
self.step_fns: list[Callable] = [
torch.compile(opt.step, **torch_compile_kwargs) for opt in optimizers
]

def zero_grad(self, *args, **kwargs) -> None:
Collaborator:
Upstream interface accepts only set_to_none=True here, and not other parameters. I know if we passed it, that would succeed, but unless there are optimizers accepting other values for zero_grad I think we should stick to set_to_none.
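
A minimal sketch of the suggested change, mirroring the upstream ``Optimizer.zero_grad(set_to_none=True)`` signature; illustrative only, not part of the diff:

def zero_grad(self, set_to_none: bool = True) -> None:
    # Match the upstream signature instead of forwarding arbitrary *args/**kwargs.
    for opt in self.optimizers:
        opt.zero_grad(set_to_none=set_to_none)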

r"""Clear the gradients of all optimized parameters.

This method delegates to the ``zero_grad()`` method of each underlying
optimizer, passing through all arguments and keyword arguments.

Parameters
----------
*args
Positional arguments to pass to each optimizer's ``zero_grad()``.
**kwargs
Keyword arguments to pass to each optimizer's ``zero_grad()``.
Common kwargs include ``set_to_none`` (bool).
"""
for opt in self.optimizers:
opt.zero_grad(*args, **kwargs)

def step(
self, closure: Callable[[], torch.Tensor] | None = None
) -> torch.Tensor | None:
Comment on lines +149 to +151
Collaborator:
Upstream returns float | None, not tensor, FYI, in both the closure and step:

https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.step.html#torch.optim.Optimizer.step
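
Restated as code, the upstream-aligned annotations would look roughly like this (a sketch of the reviewer's suggestion, not the current diff):

def step(
    self, closure: Callable[[], float] | None = None
) -> float | None:
    ...  # body unchanged from the PR; only the annotations differ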

r"""Perform a single optimization step.

This method calls the ``step()`` method of each underlying optimizer. If a
closure is provided, it is passed to each optimizer.

Parameters
----------
closure : Callable[[], torch.Tensor], optional
Optional callable that reevaluates the model and returns the loss.
If provided, it will be passed to each optimizer's step function.
Default is None.

Returns
-------
torch.Tensor or None
The last non-None loss returned by an underlying optimizer, or None if no
closure was provided.
"""
loss = None
for step_fn in self.step_fns:
if closure is None:
step_fn()
else:
res = step_fn(closure)
if res is not None:
loss = res
Comment on lines +174 to +176
Collaborator:
I don't follow this logic - if there are multiple optimizers and a closure, are they expected to return the same value? Right now we overwrite with the value of the last non-None res.
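
For comparison, a hypothetical sketch of the "evaluate once" behavior the class docstring describes: run the closure a single time and hand each optimizer a trivial closure returning the cached loss. This is an illustration only, not the PR's current code, and it would change semantics for optimizers such as LBFGS that call the closure repeatedly inside ``step()``:

def step(self, closure=None):
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()  # evaluate the user closure exactly once
    for step_fn in self.step_fns:
        # each optimizer receives a closure that simply returns the cached loss
        step_fn() if closure is None else step_fn(lambda: loss)
    return loss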


return loss

def add_param_group(self, param_group: dict[str, Any]) -> None:
r"""Add a param group to the Optimizer's param_groups.

This method is not supported for CombinedOptimizer as it would require
logic to determine which underlying optimizer should handle the new group.

Parameters
----------
param_group : dict[str, Any]
The parameter group to add.

Raises
------
NotImplementedError
Always raises NotImplementedError unless called during initialization.
"""
if getattr(self, "_initializing", False):
super().add_param_group(param_group)
return

raise NotImplementedError(
"CombinedOptimizer does not support add_param_group() after initialization, "
"since it is ambiguous which optimizer should handle the new group.\n"
"Add parameters to the underlying optimizers before creating the CombinedOptimizer."
)

def state_dict(self) -> dict[str, Any]:
r"""Return the state of all optimizers as a dictionary.

The returned dictionary contains the state dictionaries of all underlying
optimizers, allowing the combined optimizer to be checkpointed and restored.

Returns
-------
dict[str, Any]
A dictionary with a single key ``"optimizers"`` mapping to a list of
state dictionaries, one for each underlying optimizer in order.

Examples
--------
>>> combined_opt = CombinedOptimizer([opt1, opt2])
>>> state = combined_opt.state_dict()
>>> # state = {"optimizers": [opt1.state_dict(), opt2.state_dict()]}
"""
return {"optimizers": [opt.state_dict() for opt in self.optimizers]}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
r"""Load the state of all optimizers from a dictionary.

This method restores the state of each underlying optimizer from the provided
state dictionary. The state dictionary must have been created by
``state_dict()`` from a CombinedOptimizer with the same number of optimizers.

Parameters
----------
state_dict : dict[str, Any]
A dictionary containing optimizer states, as returned by
``state_dict()``. Must contain an ``"optimizers"`` key mapping to
a list of state dictionaries.

Raises
------
ValueError
If the number of optimizers in ``state_dict`` does not match
the number of optimizers in this instance.
KeyError
If ``state_dict`` does not contain the expected structure.

Notes
-----
After loading state, the ``param_groups`` attribute is refreshed to
reflect any changes in the underlying optimizers.
"""
### Validate state dict structure
if "optimizers" not in state_dict:
raise KeyError(
"Expected state_dict to contain 'optimizers' key, "
f"but got keys: {list(state_dict.keys())}"
)

optimizer_states = state_dict["optimizers"]
if len(optimizer_states) != len(self.optimizers):
raise ValueError(
f"State dict contains {len(optimizer_states)} optimizer(s), "
f"but this CombinedOptimizer has {len(self.optimizers)} optimizer(s). "
"Cannot load state from a different optimizer configuration."
)

### Load state into each underlying optimizer
for opt, sd in zip(self.optimizers, optimizer_states):
opt.load_state_dict(sd)

### Refresh param_groups to reflect any changes
self.param_groups = [g for opt in self.optimizers for g in opt.param_groups]

def __repr__(self) -> str:
r"""Return a string representation of the CombinedOptimizer.

Returns
-------
str
A string showing the optimizer types being combined.
"""
optimizer_types = [opt.__class__.__name__ for opt in self.optimizers]
return f"CombinedOptimizer({', '.join(optimizer_types)})"