1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `CombinedOptimizer` to `physicsnemo.optim` to combine multiple optimizers into a single interface.
- Added `mixture_of_experts` weather example in `physicsnemo.examples.weather`.
  **⚠️ Warning:** It uses the experimental DiT model, which is subject to future API changes.
  Added some modifications to the DiT architecture in `physicsnemo.experimental.models.dit`.
5 changes: 5 additions & 0 deletions physicsnemo/optim/__init__.py
@@ -0,0 +1,5 @@
"""Optimizer utilities for PhysicsNeMo."""

from physicsnemo.optim.combined_optimizer import CombinedOptimizer

__all__ = ["CombinedOptimizer"]
284 changes: 284 additions & 0 deletions physicsnemo/optim/combined_optimizer.py
@@ -0,0 +1,284 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Sequence

import torch
from torch.optim import Optimizer


class CombinedOptimizer(Optimizer):
r"""Combine multiple PyTorch optimizers into a single Optimizer-like interface.

This wrapper allows you to use different optimizers for different parts of a model
while presenting a unified interface compatible with PyTorch's training loops and
learning rate schedulers. The ``param_groups`` from all contained optimizers are
concatenated, enabling schedulers to operate transparently across all parameters.

Parameters
----------
optimizers : Sequence[torch.optim.Optimizer]
Sequence of PyTorch Optimizer instances to combine. Each optimizer
should already be configured with its own parameters and hyperparameters.
Must contain at least one optimizer.
torch_compile_kwargs : dict[str, Any], optional
Optional dictionary of keyword arguments to pass to ``torch.compile()``
when compiling each optimizer's step function. If None, step functions
are not compiled. Compiling can improve performance but may affect
serialization. Default is None.

Raises
------
ValueError
If ``optimizers`` is empty.

Notes
-----
* **Parameter Groups**: The ``param_groups`` attribute aggregates parameter
groups from all underlying optimizers, making this wrapper compatible with
learning rate schedulers.
* **Closure Behavior**: When ``step()`` is called with a closure, the closure
is passed to each underlying optimizer sequentially. This results in the
closure being evaluated multiple times (at least once per optimizer), which
triggers multiple forward and backward passes. This behavior matches calling
``step(closure)`` on each optimizer individually.
* **Dynamic Parameter Addition**: The ``add_param_group()`` method is not
supported. To add parameters dynamically, add them to the individual
optimizers before creating the CombinedOptimizer, or create a new instance.
* **State Access**: The ``state`` attribute inherited from the base class may
not accurately reflect the optimizer state. Access state through the
individual optimizers in the ``optimizers`` attribute instead.
* **Serialization**: The optimizer can be pickled and unpickled. When
``torch_compile_kwargs`` is provided, the compiled step functions are
reconstructed during unpickling.

Examples
--------
Combine Adam for model backbone and SGD for the head:

>>> import torch
>>> import torch.nn as nn
>>> from torch.optim import Adam, SGD
>>> from physicsnemo.optim import CombinedOptimizer
>>>
>>> model = nn.Sequential(
... nn.Linear(10, 20), # backbone
... nn.ReLU(),
... nn.Linear(20, 2), # head
... )
>>> backbone_params = list(model[0].parameters())
>>> head_params = list(model[2].parameters())
>>>
>>> opt1 = Adam(backbone_params, lr=1e-4)
>>> opt2 = SGD(head_params, lr=1e-2, momentum=0.9)
>>> combined_opt = CombinedOptimizer([opt1, opt2])
>>>
>>> # Use with a learning rate scheduler
>>> scheduler = torch.optim.lr_scheduler.StepLR(combined_opt, step_size=10)
>>>
>>> # Standard training loop
>>> for epoch in range(100):
... combined_opt.zero_grad()
... loss = model(torch.randn(32, 10)).sum()
... loss.backward()
... combined_opt.step()
... scheduler.step()
"""

def __init__(
self,
optimizers: Sequence[Optimizer],
torch_compile_kwargs: dict[str, Any] | None = None,
> **Review comment (Collaborator):** This syntax is not upstream, right? I assume we can't put `step` into a compile wrapper since that will break the closure behavior?
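For context on what this argument feeds into, a minimal usage sketch; the model, optimizers, and compile options here are illustrative, not taken from the PR:

```python
import torch.nn as nn
from torch.optim import SGD, Adam

from physicsnemo.optim import CombinedOptimizer

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

# Each wrapped optimizer's `step` is passed through
# torch.compile(opt.step, **torch_compile_kwargs) inside __init__.
combined = CombinedOptimizer(
    [Adam(model[0].parameters(), lr=1e-4), SGD(model[2].parameters(), lr=1e-2)],
    torch_compile_kwargs={"mode": "reduce-overhead"},  # illustrative torch.compile option
)
```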

):
if not optimizers:
raise ValueError("`optimizers` must contain at least one optimizer.")

self.optimizers = optimizers
self._torch_compile_kwargs = torch_compile_kwargs

### Aggregate parameter groups from all optimizers
# We pass an empty defaults dict because hyperparameters are managed by
# the individual optimizers, not this wrapper.
param_groups = [g for opt in optimizers for g in opt.param_groups]

# Flag to allow add_param_group during initialization
self._initializing = True
try:
super().__init__(param_groups, defaults={})
finally:
self._initializing = False

### Setup step functions (optionally compiled)
if torch_compile_kwargs is None:
self.step_fns: list[Callable] = [opt.step for opt in optimizers]
else:
self.step_fns: list[Callable] = [
torch.compile(opt.step, **torch_compile_kwargs) for opt in optimizers
]

def zero_grad(self, *args, **kwargs) -> None:
> **Review comment (Collaborator):** Upstream interface accepts only `set_to_none=True` here, and not other parameters. I know if we passed it, that would succeed, but unless there are optimizers accepting other values for `zero_grad` I think we should stick to `set_to_none`.
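A minimal sketch of the narrower signature the reviewer is suggesting, matching the upstream `torch.optim.Optimizer.zero_grad(set_to_none=True)` interface; this is a proposed alternative, not the code in this PR:

```python
def zero_grad(self, set_to_none: bool = True) -> None:
    """Clear gradients of all optimized parameters (upstream-style signature)."""
    for opt in self.optimizers:
        opt.zero_grad(set_to_none=set_to_none)
```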

r"""Clear the gradients of all optimized parameters.

This method delegates to the ``zero_grad()`` method of each underlying
optimizer, passing through all arguments and keyword arguments.

Parameters
----------
*args
Positional arguments to pass to each optimizer's ``zero_grad()``.
**kwargs
Keyword arguments to pass to each optimizer's ``zero_grad()``.
Common kwargs include ``set_to_none`` (bool).
"""
for opt in self.optimizers:
opt.zero_grad(*args, **kwargs)

def step(
self, closure: Callable[[], torch.Tensor] | None = None
) -> torch.Tensor | None:
> **Review comment (Collaborator) on lines +149 to +151:** Upstream returns `float | None`, not a tensor, FYI, in both the closure and `step`: https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.step.html#torch.optim.Optimizer.step
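For comparison, a sketch of the same signature with the upstream typing the reviewer links to; only the annotations change, the delegation logic stays as written below:

```python
from typing import Callable

def step(self, closure: Callable[[], float] | None = None) -> float | None:
    """Upstream-style typing: the closure returns a float loss, and so does step()."""
```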

r"""Perform a single optimization step.

This method calls the ``step()`` method of each underlying optimizer. If a
closure is provided, it is passed to each optimizer.

Parameters
----------
closure : Callable[[], torch.Tensor], optional
Optional callable that reevaluates the model and returns the loss.
If provided, it will be passed to each optimizer's step function.
Default is None.

Returns
-------
torch.Tensor or None
The loss value returned by the last optimizer, or None if no closure was provided.
"""
loss = None
for step_fn in self.step_fns:
if closure is None:
step_fn()
else:
res = step_fn(closure)
if res is not None:
loss = res
> **Review comment (Collaborator) on lines +174 to +176:** I don't follow this logic: if there are multiple optimizers and a closure, are they expected to return the same value? Right now we overwrite with the value of the last non-None `res`.
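A small sketch of the behavior being discussed, using two hypothetical optimizers on a toy model: the closure runs once per wrapped optimizer, and the returned `loss` is the last non-None result:

```python
import torch
import torch.nn as nn
from torch.optim import SGD, Adam

from physicsnemo.optim import CombinedOptimizer

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
combined = CombinedOptimizer(
    [Adam(model[0].parameters(), lr=1e-3), SGD(model[2].parameters(), lr=1e-2)]
)
x, y = torch.randn(16, 4), torch.randn(16, 1)

def closure():
    combined.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    return loss

# The closure is re-evaluated by each wrapped optimizer's step (twice here:
# once for Adam, once for SGD), and the value returned by combined.step() is
# the last non-None result, i.e. the loss from the second evaluation.
loss = combined.step(closure)
```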


return loss

def add_param_group(self, param_group: dict[str, Any]) -> None:
r"""Add a param group to the Optimizer's param_groups.

This method is not supported for CombinedOptimizer as it would require
logic to determine which underlying optimizer should handle the new group.

Parameters
----------
param_group : dict[str, Any]
The parameter group to add.

Raises
------
NotImplementedError
Always raises NotImplementedError unless called during initialization.
"""
if getattr(self, "_initializing", False):
super().add_param_group(param_group)
return

raise NotImplementedError(
"CombinedOptimizer does not support add_param_group() after initialization, "
"since it is ambiguous which optimizer should handle the new group.\n"
"Add parameters to the underlying optimizers before creating the CombinedOptimizer."
)

def state_dict(self) -> dict[str, Any]:
r"""Return the state of all optimizers as a dictionary.

The returned dictionary contains the state dictionaries of all underlying
optimizers, allowing the combined optimizer to be checkpointed and restored.

Returns
-------
dict[str, Any]
A dictionary with a single key ``"optimizers"`` mapping to a list of
state dictionaries, one for each underlying optimizer in order.

Examples
--------
>>> combined_opt = CombinedOptimizer([opt1, opt2])
>>> state = combined_opt.state_dict()
>>> # state = {"optimizers": [opt1.state_dict(), opt2.state_dict()]}
"""
return {"optimizers": [opt.state_dict() for opt in self.optimizers]}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
r"""Load the state of all optimizers from a dictionary.

This method restores the state of each underlying optimizer from the provided
state dictionary. The state dictionary must have been created by
``state_dict()`` from a CombinedOptimizer with the same number of optimizers.

Parameters
----------
state_dict : dict[str, Any]
A dictionary containing optimizer states, as returned by
``state_dict()``. Must contain an ``"optimizers"`` key mapping to
a list of state dictionaries.

Raises
------
ValueError
If the number of optimizers in ``state_dict`` does not match
the number of optimizers in this instance.
KeyError
If ``state_dict`` does not contain the expected structure.

Notes
-----
After loading state, the ``param_groups`` attribute is refreshed to
reflect any changes in the underlying optimizers.
"""
### Validate state dict structure
if "optimizers" not in state_dict:
raise KeyError(
"Expected state_dict to contain 'optimizers' key, "
f"but got keys: {list(state_dict.keys())}"
)

optimizer_states = state_dict["optimizers"]
if len(optimizer_states) != len(self.optimizers):
raise ValueError(
f"State dict contains {len(optimizer_states)} optimizer(s), "
f"but this CombinedOptimizer has {len(self.optimizers)} optimizer(s). "
"Cannot load state from a different optimizer configuration."
)

### Load state into each underlying optimizer
for opt, sd in zip(self.optimizers, optimizer_states):
opt.load_state_dict(sd)

### Refresh param_groups to reflect any changes
self.param_groups = [g for opt in self.optimizers for g in opt.param_groups]

def __repr__(self) -> str:
r"""Return a string representation of the CombinedOptimizer.

Returns
-------
str
A string showing the optimizer types being combined.
"""
optimizer_types = [opt.__class__.__name__ for opt in self.optimizers]
return f"CombinedOptimizer({', '.join(optimizer_types)})"