
Conversation

@peterdsharpe (Collaborator) commented Nov 20, 2025

PhysicsNeMo Pull Request

Adds a new CombinedOptimizer utility, which is useful for the increasingly popular "architecture-aware" optimizers, such as Muon.

This PR targets the v2.0 refactor branch, so it should only be merged after #1235.
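For context, the motivating pattern can be sketched as follows. This is an illustrative sketch only, not the PR's API: architecture-aware schemes like Muon apply one update rule to matrix-shaped weights and a different one to everything else, which naturally produces two optimizer instances that must be driven together (here SGD stands in for Muon).

```python
import torch
from torch import nn
from torch.optim import SGD, Adam

# Split parameters by shape, as a Muon-style scheme would.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
other_params = [p for p in model.parameters() if p.ndim < 2]

opt_matrices = SGD(matrix_params, lr=0.02, momentum=0.95)  # stand-in for Muon
opt_other = Adam(other_params, lr=1e-3)
optimizers = [opt_matrices, opt_other]

# A combined-optimizer utility would hide this fan-out behind a single
# step()/zero_grad() pair; here it is written out explicitly.
loss = model(torch.randn(3, 4)).pow(2).mean()
loss.backward()
for opt in optimizers:
    opt.step()
    opt.zero_grad()
```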

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

This commit introduces the CombinedOptimizer class, which allows users to combine multiple PyTorch optimizers into a single interface. The new class supports dynamic parameter addition, state management, and optional compilation of step functions for performance improvements. It also includes detailed documentation and examples for ease of use.
This commit introduces an `__init__.py` file for the optimizer module, providing a clear entry point for the optimizer utilities in PhysicsNeMo. Additionally, it enhances the `CombinedOptimizer` class by implementing a flag for initialization, allowing the addition of parameter groups during initialization. The `step` method is updated to return the loss value from the last optimizer, improving its usability. Documentation and comments have been added for clarity.
This commit introduces a comprehensive test suite for the CombinedOptimizer class, covering initialization, parameter aggregation, step execution, and state management. The tests ensure that the CombinedOptimizer correctly integrates multiple optimizers, handles closures, and maintains expected behavior during training. Additionally, it verifies the proper functioning of learning rate schedulers with the combined optimizer. This enhances the reliability and robustness of the optimizer's implementation.
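The core shape described in these commit messages can be sketched as below. This is a simplified stand-in under assumed semantics, not the PR's physicsnemo.optim code; the name TinyCombined and its exact behavior are hypothetical.

```python
import torch
from torch.optim import SGD, Adam

class TinyCombined:
    """Simplified sketch of the combined-optimizer idea: fan out
    step()/zero_grad()/state_dict() to several torch.optim.Optimizer
    instances behind one interface."""

    def __init__(self, optimizers):
        self.optimizers = list(optimizers)

    def zero_grad(self, set_to_none=True):
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=set_to_none)

    def step(self):
        for opt in self.optimizers:
            opt.step()

    def state_dict(self):
        # Serialize every inner optimizer's state under one key.
        return {"optimizers": [opt.state_dict() for opt in self.optimizers]}

    def load_state_dict(self, state):
        for opt, s in zip(self.optimizers, state["optimizers"]):
            opt.load_state_dict(s)

# Usage: each optimizer owns a disjoint parameter subset.
w = torch.nn.Parameter(torch.ones(3))
b = torch.nn.Parameter(torch.zeros(3))
combined = TinyCombined([SGD([w], lr=0.1), Adam([b], lr=0.1)])
(w * 2 + b).sum().backward()
combined.step()
combined.zero_grad()
```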
@greptile-apps (Contributor) bot commented Nov 20, 2025

Greptile Overview

Greptile Summary

Adds CombinedOptimizer utility that wraps multiple PyTorch optimizers into a unified interface for architecture-aware optimization strategies. The implementation provides proper state management, serialization, and learning rate scheduler compatibility.

Critical Issue Found:

  • The documentation claims closures are evaluated once (lines 53-55), but the implementation passes the closure directly to each optimizer, causing multiple evaluations. This leads to redundant forward/backward passes and conflicts with optimizers like LBFGS that may call the closure multiple times internally.
  • The test suite validates this incorrect behavior (test_combined_optimizer.py:158-174), which needs alignment with the intended design.

Recommendations:

  • Decide on the desired closure behavior (single evaluation vs. per-optimizer evaluation)
  • Update either the implementation or documentation to match
  • Adjust the corresponding test case to validate the correct behavior
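If single evaluation is the behavior chosen, one common pattern (sketched here under assumptions, not taken from the PR) is to run the real closure once up front and hand each inner optimizer a trivial closure that returns the cached loss:

```python
import torch

def step_all(optimizers, closure=None):
    """Sketch: evaluate `closure` exactly once, then give every optimizer a
    cheap closure returning the cached loss. Note this deliberately gives up
    LBFGS-style optimizers that must re-call the closure internally."""
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()
        cached = loss
        closure = lambda: cached  # no second forward/backward pass
    for opt in optimizers:
        if closure is None:
            opt.step()
        else:
            opt.step(closure)
    return loss

# The real closure runs once even though there are two optimizers.
p = torch.nn.Parameter(torch.ones(2))
q = torch.nn.Parameter(torch.ones(2))
opts = [torch.optim.SGD([p], lr=0.1), torch.optim.SGD([q], lr=0.1)]
calls = []

def closure():
    calls.append(1)
    for opt in opts:
        opt.zero_grad()
    loss = (p + q).sum()
    loss.backward()
    return loss

step_all(opts, closure)
```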

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| physicsnemo/optim/combined_optimizer.py | 2/5 | New CombinedOptimizer wrapper for multiple optimizers. Critical bug: docstring claims the closure is evaluated once, but the implementation evaluates it per optimizer. |
| test/optim/test_combined_optimizer.py | 3/5 | Comprehensive test coverage for CombinedOptimizer. Tests validate incorrect behavior (multiple closure calls) that contradicts the documentation. |

@greptile-apps bot left a comment

3 files reviewed, 2 comments

… that the closure is evaluated multiple times, matching individual optimizer behavior.
@coreyjadams (Collaborator) left a comment
Overall, looks good. I appreciate the extensive testing.

This adds a new package, physicsnemo.optim, which I think is overdue. As part of the refactor let's make sure to get it in the docs too :).

I left some comments but do have another question. What happens when users instantiate and use this and the parameter groups are not disjoint? Will it cause an error? Silent bugs? Should we include a check that each parameter group is completely disjoint?
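One way to answer this defensively is to fail fast when parameter sets overlap. The helper below is a hypothetical sketch of such a guard, not code from the PR:

```python
import torch

def check_disjoint(optimizers):
    """Hypothetical guard: raise if any parameter tensor appears in more
    than one of the given torch.optim.Optimizer instances."""
    seen = {}
    for i, opt in enumerate(optimizers):
        for group in opt.param_groups:
            for p in group["params"]:
                if id(p) in seen:
                    raise ValueError(
                        f"parameter appears in optimizers {seen[id(p)]} and {i}; "
                        "parameter sets must be disjoint"
                    )
                seen[id(p)] = i

# Disjoint sets pass silently; a shared parameter would raise ValueError.
a = torch.nn.Parameter(torch.ones(2))
b = torch.nn.Parameter(torch.ones(2))
check_disjoint([torch.optim.SGD([a], lr=0.1), torch.optim.Adam([b], lr=0.1)])
```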

torch.compile(opt.step, **torch_compile_kwargs) for opt in optimizers
]

def zero_grad(self, *args, **kwargs) -> None:
Collaborator:

Upstream interface accepts only set_to_none=True here, and not other parameters. I know if we passed it, that would succeed, but unless there are optimizers accepting other values for zero_grad I think we should stick to set_to_none.
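A signature matching this suggestion would forward only set_to_none, mirroring torch.optim.Optimizer.zero_grad. A minimal sketch (the function name is illustrative):

```python
import torch

def zero_grad_all(optimizers, set_to_none: bool = True) -> None:
    """Sketch: forward only `set_to_none`, the one argument that
    torch.optim.Optimizer.zero_grad accepts upstream."""
    for opt in optimizers:
        opt.zero_grad(set_to_none=set_to_none)

p = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([p], lr=0.1)
p.sum().backward()
zero_grad_all([opt], set_to_none=True)   # grads become None
p.sum().backward()
zero_grad_all([opt], set_to_none=False)  # grads become zero tensors
```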

def __init__(
self,
optimizers: Sequence[Optimizer],
torch_compile_kwargs: dict[str, Any] | None = None,
Collaborator:

This syntax is not upstream, right? I assume we can't put step into a compile wrapper since that will break the closure behavior?

Comment on lines +149 to +151
def step(
self, closure: Callable[[], torch.Tensor] | None = None
) -> torch.Tensor | None:
Collaborator:

Upstream returns float | None, not tensor, FYI, in both the closure and step:

https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.step.html#torch.optim.Optimizer.step

Comment on lines +174 to +176
res = step_fn(closure)
if res is not None:
loss = res
Collaborator:

I don't follow this logic - if there are multiple optimizers and a closure, are they expected to return the same value? Right now we overwrite with the value of the last non-None res.

return CombinedOptimizer(optimizers)


class TestInitialization:
Collaborator:

All of our other tests are more or less "functionals" and not class-based tests. I don't personally have any objection to this style, but I don't know what unforeseen consequences we might see with this. Any thoughts?

Comment on lines +230 to +252
def test_state_dict_structure(self, combined_optimizer):
state = combined_optimizer.state_dict()
assert "optimizers" in state
assert len(state["optimizers"]) == 2
assert isinstance(state["optimizers"][0], dict)

def test_load_state_dict(self, combined_optimizer, model):
# Save state
state = combined_optimizer.state_dict()

# Create new optimizer
opt1 = SGD(model.layer1.parameters(), lr=0.01)
opt2 = Adam(model.layer2.parameters(), lr=0.001)
new_combined = CombinedOptimizer([opt1, opt2])

# Load state
new_combined.load_state_dict(state)

# Verify equality (basic check)
# Note: strict equality of state dicts might fail due to weak refs or unrelated keys,
# so we check structure matches.
new_state = new_combined.state_dict()
assert len(new_state["optimizers"]) == len(state["optimizers"])
Collaborator:

Can we update this test to check save / restore is getting numerics correct too? Loading the wrong optimizer state or getting the numbers incorrect would be a pretty frustrating bug in training codes.
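A numerics round-trip along these lines could look as follows, sketched here with a plain torch optimizer rather than the PR's API. Note the deepcopy: the tensors inside an optimizer's state_dict can alias the live moment buffers, so a snapshot must be copied before taking further steps.

```python
import copy
import torch
from torch.optim import Adam

# Hypothetical test sketch: a restored optimizer must reproduce the
# reference trajectory exactly, not just have a matching dict structure.
torch.manual_seed(0)
p = torch.nn.Parameter(torch.ones(4))
opt = Adam([p], lr=0.1)

def one_step():
    opt.zero_grad()
    (p ** 2).sum().backward()
    opt.step()

one_step()                                     # populate Adam moment buffers
saved_state = copy.deepcopy(opt.state_dict())  # deepcopy: avoid aliasing live buffers
saved_param = p.detach().clone()

one_step()                                     # reference: one more step
reference = p.detach().clone()

with torch.no_grad():
    p.copy_(saved_param)                       # rewind parameters...
opt.load_state_dict(saved_state)               # ...and optimizer state, then replay
one_step()
```

If the restored replay lands on exactly the reference values, save/restore is numerically faithful.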
