diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py
index 79bd4f2..b2460d6 100644
--- a/emerging_optimizers/orthogonalized_optimizers/muon_utils.py
+++ b/emerging_optimizers/orthogonalized_optimizers/muon_utils.py
@@ -196,8 +196,18 @@ def newton_schulz_tp(
         tp_group: The process group for communication if input is distributed.
         mode: The mode to use for the Newton-Schulz iteration.
     """
+
+    # Placeholder for 3D conv1d case
+    original_3d_shape = None
+
     if partition_dim is None:  # Fallback path for non TP params.
+        # Handle 3D conv1d case
+        if x.dim() == 3:
+            original_3d_shape = x.shape
+            x = x.reshape(-1, x.size(-1))
+            output = newton_schulz(x, steps, coefficient_type)
+            return output.reshape(original_3d_shape)
         return newton_schulz(x, steps, coefficient_type)
 
     kwargs: Any = {
@@ -205,6 +215,16 @@ def newton_schulz_tp(
         "coefficient_type": coefficient_type,
     }
 
+    if x.dim() == 3:
+        is_3d_conv1d = True
+    else:
+        is_3d_conv1d = False
+
+    if is_3d_conv1d:
+        # merge all input channels into the last dimension
+        original_3d_shape = x.shape
+        x = x.reshape(-1, x.size(-1))
+
     if mode == "duplicated":
         x_shards = [torch.empty_like(x) for _ in range(tp_group.size())]
         torch.distributed.all_gather(x_shards, x, tp_group)
@@ -223,6 +243,10 @@ def newton_schulz_tp(
     else:
         raise ValueError(f"Invalid mode: {mode}")
 
+    if is_3d_conv1d:
+        # reshape back to the original 3D shape, separate orthogonalized channels
+        output = output.reshape(original_3d_shape)
+
     return output
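
For reference, a minimal standalone sketch of the pattern this patch applies to 3D conv1d weights: flatten the leading dimensions into rows, orthogonalize the resulting 2D matrix, then restore the original shape. The `newton_schulz_2d` helper below is a hypothetical stand-in using the commonly cited quintic Newton-Schulz coefficients, not the repository's actual `newton_schulz` implementation.

```python
import torch


def newton_schulz_2d(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Hypothetical stand-in for the repo's newton_schulz: orthogonalize a 2D matrix
    # with the quintic iteration commonly used for Muon (coefficients are an assumption).
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (g.norm() + 1e-7)  # normalize so the iteration converges
    transposed = x.size(0) > x.size(1)
    if transposed:
        x = x.T
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * (s @ s)) @ x
    return x.T if transposed else x


# A 3D conv1d-style weight: (out_channels, in_channels, kernel_width).
w = torch.randn(8, 4, 3)

# Same reshape pattern as the patch: merge the leading dims, orthogonalize, restore.
original_3d_shape = w.shape
w_2d = w.reshape(-1, w.size(-1))            # (8 * 4, 3)
ortho = newton_schulz_2d(w_2d)
ortho = ortho.reshape(original_3d_shape)    # back to (8, 4, 3)
print(ortho.shape)                          # torch.Size([8, 4, 3])
```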