Address number of steps issue and be more explicit about the type of iteration#91

Closed
mkhona-nvidia wants to merge 2 commits into NVIDIA-NeMo:main from mkhona-nvidia:mkhona/ns_polar_express

Conversation

mkhona-nvidia (Contributor) commented Jan 27, 2026

An issue points out that the Polar Express paper recommends repeating the last set of (a, b, c) coefficients when the number of Newton-Schulz steps is greater than 8. We previously required the number of steps to be a multiple of the number of coefficient sets so that the iteration could cycle through them. While not incorrect, this deviated from the Polar Express recommendation, so this PR fixes that.

copy-pr-bot bot commented Jan 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: mikail <mkhona@nvidia.com>
mkhona-nvidia force-pushed the mkhona/ns_polar_express branch from be4c77c to 0001f6e on January 27, 2026 04:44
mkhona-nvidia requested a review from skyw on January 27, 2026 04:46
mkhona-nvidia changed the title from "Address number of steps issue and be more explicit" to "Address number of steps issue and be more explicit about the type of iteration" on Jan 27, 2026
greptile-apps bot commented Jan 27, 2026

Greptile Overview

Greptile Summary

This PR updates the Newton-Schulz iteration logic to handle coefficient sets differently based on the coefficient_type. The key changes align with the Polar Express paper recommendation:

  • polar_express: Now requires steps >= 8 and repeats the last coefficient set (a,b,c) for any steps beyond 8, rather than requiring steps to be a multiple of 8
  • simple: Repeats the single coefficient set for all steps
  • quintic, aol, custom: Still require steps to be a multiple of the number of coefficient sets and cycle through them

The implementation converts coefficient sets to lists to enable extension via concatenation, then adjusts the list based on the coefficient type before the iteration loop. This removes the modulo operator from the iteration loop (line 168) in favor of direct indexing.
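
The per-type handling summarized above can be sketched roughly as follows. This is a minimal illustration in plain Python, not the code from the PR: the helper name `expand_coefficient_sets` and the placeholder coefficient values in the usage lines are hypothetical, and only the branching mirrors the summary.

```python
from typing import List, Sequence, Tuple

Coeffs = Tuple[float, float, float]


def expand_coefficient_sets(
    coefficient_sets: Sequence[Coeffs], steps: int, coefficient_type: str
) -> List[Coeffs]:
    """Return one (a, b, c) tuple per step (hypothetical helper)."""
    sets = list(coefficient_sets)  # a list, so it can be extended by concatenation
    num_coeffs = len(sets)

    if coefficient_type == "simple":
        # A single coefficient set, repeated for every step.
        return sets * steps

    if coefficient_type == "polar_express":
        # Use every set once, then keep repeating the last one.
        if steps < num_coeffs:
            raise ValueError(f"steps must be >= {num_coeffs}, got {steps}")
        return sets + [sets[-1]] * (steps - num_coeffs)

    # Remaining types (quintic, aol, custom): cycle through the sets.
    if steps % num_coeffs != 0:
        raise ValueError(f"steps must be a multiple of {num_coeffs}, got {steps}")
    return sets * (steps // num_coeffs)


# Placeholder values only; these are NOT the real Polar Express coefficients.
placeholder = [(float(i), 0.0, 0.0) for i in range(8)]
schedule = expand_coefficient_sets(placeholder, 10, "polar_express")
# The iteration loop can then index directly: a, b, c = schedule[i]
```

With 10 steps and 8 sets, the last two entries of `schedule` are copies of the eighth set, so the loop body needs no modulo.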

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk
  • The changes are well-structured and align with the stated goal of following the Polar Express paper's recommendation. The logic is clear with appropriate validation and error handling. Score reduced by 1 point due to lack of test coverage for the new polar_express behavior with steps > 8, though the logic appears correct.
  • No files require special attention

Important Files Changed

Filename: emerging_optimizers/orthogonalized_optimizers/muon_utils.py
Overview: Updated coefficient set handling logic to support different iteration strategies per coefficient type, specifically implementing the Polar Express paper's recommendation to repeat the last coefficient set for steps > 8

Sequence Diagram

sequenceDiagram
    participant Caller
    participant newton_schulz
    participant _COEFFICIENT_SETS
    participant newton_schulz_step

    Caller->>newton_schulz: x, steps, coefficient_type
    
    alt coefficient_type in _COEFFICIENT_SETS
        newton_schulz->>_COEFFICIENT_SETS: Get coefficient_sets
        _COEFFICIENT_SETS-->>newton_schulz: list of (a,b,c) tuples
    else coefficient_type == "custom"
        Caller->>newton_schulz: custom_coefficient_sets
    end
    
    newton_schulz->>newton_schulz: Calculate num_coeffs = len(coefficient_sets)
    
    alt coefficient_type == "simple"
        newton_schulz->>newton_schulz: coefficient_sets = coefficient_sets * steps
    else coefficient_type == "polar_express"
        alt steps < num_coeffs
            newton_schulz-->>Caller: ValueError (steps must be >= num_coeffs)
        else steps > num_coeffs
            newton_schulz->>newton_schulz: Append last coefficient (steps - num_coeffs) times
        end
    else other types (quintic, aol, custom)
        alt steps % num_coeffs != 0
            newton_schulz-->>Caller: ValueError (steps must be multiple of num_coeffs)
        else
            newton_schulz->>newton_schulz: coefficient_sets = coefficient_sets * (steps // num_coeffs)
        end
    end
    
    newton_schulz->>newton_schulz: Normalize X (spectral norm <= 1)
    
    loop i in range(steps)
        newton_schulz->>newton_schulz: Get (a,b,c) = coefficient_sets[i]
        newton_schulz->>newton_schulz_step: X, a, b, c
        newton_schulz_step-->>newton_schulz: Updated X
    end
    
    newton_schulz-->>Caller: Orthogonalized tensor X

greptile-apps bot left a comment

1 file reviewed, 1 comment

    # For simple, repeat the same coefficients as many times as needed
    coefficient_sets = coefficient_sets * steps
elif coefficient_type == "polar_express":
    # For polar_express, steps must be >= 8, repeat last step if steps > 8

Hardcoded value 8 in comment could become incorrect if polar_express coefficient count changes

Suggested change
- # For polar_express, steps must be >= 8, repeat last step if steps > 8
+ # For polar_express, steps must be >= len(coefficient_sets), repeat last step if steps > len(coefficient_sets)

Signed-off-by: mikail <mkhona@nvidia.com>

greptile-apps bot left a comment

No files reviewed, no comments

skyw (Contributor) commented Jan 27, 2026

This shouldn't be done this way. The simple nested dict of lists of coefficients is already more complicated than it should be. Adding more control logic will hurt readability: it becomes very hard to know what will happen by looking at the API, and one has to read the code very carefully.

If supporting exactly what the paper does is deemed necessary, it would be time to abstract the coefficients into their own iterator.
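
One possible reading of the iterator suggestion, sketched with the standard library only. The function names below are hypothetical and do not exist in the repository; the idea is that each schedule decides for itself how to continue past its last coefficient set, so the caller needs neither a step-count check nor a copied list.

```python
import itertools
from typing import Iterator, List, Sequence, Tuple

Coeffs = Tuple[float, float, float]


def cycle_schedule(sets: Sequence[Coeffs]) -> Iterator[Coeffs]:
    """Cycle through the coefficient sets indefinitely (hypothetical)."""
    return itertools.cycle(sets)


def repeat_last_schedule(sets: Sequence[Coeffs]) -> Iterator[Coeffs]:
    """Yield each set once, then the last set forever (hypothetical).

    Assumes `sets` is non-empty; an empty sequence raises IndexError.
    """
    return itertools.chain(sets, itertools.repeat(sets[-1]))


def take(schedule: Iterator[Coeffs], steps: int) -> List[Coeffs]:
    """Materialize the first `steps` entries, e.g. for inspection."""
    return list(itertools.islice(schedule, steps))
```

An iteration loop could then read `for a, b, c in itertools.islice(schedule, steps): ...`, with the choice between cycling and repeating the last set made where the schedule is constructed.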

skyw (Contributor) commented Jan 27, 2026

There is also code logic that we previously decided to remove, such as copying coefficients to satisfy the loop length:

coefficient_sets = coefficient_sets * (steps // num_coeffs)

It will be a complete overhaul, closing.

skyw closed this on Jan 27, 2026