24 changes: 24 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/muon_utils.py
@@ -196,15 +196,35 @@ def newton_schulz_tp(
        tp_group: The process group for communication if input is distributed.
        mode: The mode to use for the Newton-Schulz iteration.
    """

    # Placeholder for the 3D conv1d case
    original_3d_shape = None

    if partition_dim is None:
        # Fallback path for non-TP params.
        # Handle the 3D conv1d case
        if x.dim() == 3:
Contributor:

This is not the right place to add this logic. The function still handles 2D input; the reshape logic should live outside of this function, probably in the Megatron-inherited Muon.

Contributor Author:

The issue is the same as unfusing QKV: in the optimizer state, the param shape will always be 3D. We would need to add handling in the OrthogonalizedOptimizer class.

Contributor:

No, all of our functions handle 2D. The logic for turning a 3D (or ND, for that matter) tensor into 2D should live outside of them.
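
A minimal sketch of the wrapper the reviewer appears to be suggesting. The helper name is hypothetical and the import path is inferred from the file under review; only the newton_schulz call itself comes from this diff:

    # Import path inferred from the file being reviewed; adjust as needed.
    from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz

    def orthogonalize_nd(x, steps, coefficient_type):
        """Flatten any leading dims so newton_schulz only ever sees 2D input."""
        if x.dim() <= 2:
            return newton_schulz(x, steps, coefficient_type)
        original_shape = x.shape
        output = newton_schulz(x.reshape(-1, x.size(-1)), steps, coefficient_type)
        return output.reshape(original_shape)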

            original_3d_shape = x.shape
            x = x.reshape(-1, x.size(-1))
            output = newton_schulz(x, steps, coefficient_type)
            return output.reshape(original_3d_shape)
        return newton_schulz(x, steps, coefficient_type)

    kwargs: Any = {
        "steps": steps,
        "coefficient_type": coefficient_type,
    }

    is_3d_conv1d = x.dim() == 3

    if is_3d_conv1d:
        # merge all input channels into the last dimension
        original_3d_shape = x.shape
        x = x.reshape(-1, x.size(-1))

    if mode == "duplicated":
        x_shards = [torch.empty_like(x) for _ in range(tp_group.size())]
        torch.distributed.all_gather(x_shards, x, tp_group)
@@ -223,6 +243,10 @@
    else:
        raise ValueError(f"Invalid mode: {mode}")

    if is_3d_conv1d:
        # reshape back to the original 3D shape, separating the orthogonalized channels
        output = output.reshape(original_3d_shape)

    return output
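
For illustration, a small standalone example of the flatten/restore round trip this hunk performs (the shapes are hypothetical; the real conv1d weight layout is not shown in this diff):

    import torch

    # Hypothetical 3D conv1d-style weight.
    w = torch.randn(8, 4, 16)
    original_3d_shape = w.shape
    w2d = w.reshape(-1, w.size(-1))  # merge leading dims -> shape (32, 16)
    # ... the 2D tensor is what the Newton-Schulz iteration operates on ...
    restored = w2d.reshape(original_3d_shape)
    assert restored.shape == (8, 4, 16)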

