Address number of steps issue and be more explicit about the type of iteration #91
mkhona-nvidia wants to merge 2 commits into NVIDIA-NeMo:main
Signed-off-by: mikail <mkhona@nvidia.com>
Force-pushed from be4c77c to 0001f6e.
Greptile Overview
Greptile Summary
This PR updates the Newton-Schulz iteration logic to handle coefficient sets differently based on the coefficient type.
The implementation converts coefficient sets to lists to enable extension via concatenation, then adjusts the list based on the coefficient type before the iteration loop. This removes the modulo operator from the iteration loop (line 168) in favor of direct indexing.
Confidence Score: 4/5
Sequence Diagram
sequenceDiagram
    participant Caller
    participant newton_schulz
    participant _COEFFICIENT_SETS
    participant newton_schulz_step
    Caller->>newton_schulz: x, steps, coefficient_type
    alt coefficient_type in _COEFFICIENT_SETS
        newton_schulz->>_COEFFICIENT_SETS: Get coefficient_sets
        _COEFFICIENT_SETS-->>newton_schulz: list of (a,b,c) tuples
    else coefficient_type == "custom"
        Caller->>newton_schulz: custom_coefficient_sets
    end
    newton_schulz->>newton_schulz: Calculate num_coeffs = len(coefficient_sets)
    alt coefficient_type == "simple"
        newton_schulz->>newton_schulz: coefficient_sets = coefficient_sets * steps
    else coefficient_type == "polar_express"
        alt steps < num_coeffs
            newton_schulz-->>Caller: ValueError (steps must be >= num_coeffs)
        else steps > num_coeffs
            newton_schulz->>newton_schulz: Append last coefficient (steps - num_coeffs) times
        end
    else other types (quintic, aol, custom)
        alt steps % num_coeffs != 0
            newton_schulz-->>Caller: ValueError (steps must be multiple of num_coeffs)
        else
            newton_schulz->>newton_schulz: coefficient_sets = coefficient_sets * (steps // num_coeffs)
        end
    end
    newton_schulz->>newton_schulz: Normalize X (spectral norm <= 1)
    loop i in range(steps)
        newton_schulz->>newton_schulz: Get (a,b,c) = coefficient_sets[i]
        newton_schulz->>newton_schulz_step: X, a, b, c
        newton_schulz_step-->>newton_schulz: Updated X
    end
    newton_schulz-->>Caller: Orthogonalized tensor X
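The branching in the diagram can be sketched as a standalone helper. This is a minimal illustration reconstructed from the diagram only, not the code in the PR: the function name `expand_coefficient_sets` and the dummy coefficient values are hypothetical, and the real `newton_schulz` may differ in details.

```python
from typing import List, Sequence, Tuple

Coeffs = Tuple[float, float, float]


def expand_coefficient_sets(
    coefficient_sets: Sequence[Coeffs], steps: int, coefficient_type: str
) -> List[Coeffs]:
    """Hypothetical helper: build a list that can be indexed directly by step."""
    sets = list(coefficient_sets)  # a list, so it can be extended by concatenation
    num_coeffs = len(sets)

    if coefficient_type == "simple":
        # Repeat the same coefficients as many times as needed.
        return sets * steps

    if coefficient_type == "polar_express":
        if steps < num_coeffs:
            raise ValueError(f"steps must be >= {num_coeffs}, got {steps}")
        # Use every set once, then repeat the last set for the extra steps.
        return sets + [sets[-1]] * (steps - num_coeffs)

    # Other types (quintic, aol, custom): cycle through the sets in whole rounds.
    if steps % num_coeffs != 0:
        raise ValueError(f"steps must be a multiple of {num_coeffs}, got {steps}")
    return sets * (steps // num_coeffs)


# Dummy placeholder values, NOT the real Polar Express coefficients.
dummy_sets = [(float(i), 0.0, 0.0) for i in range(8)]
expanded = expand_coefficient_sets(dummy_sets, 10, "polar_express")
```

With the list expanded up front, the loop body can read `coefficient_sets[i]` directly, which is what allows the modulo operator to be dropped.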
        # For simple, repeat the same coefficients as many times as needed
        coefficient_sets = coefficient_sets * steps
    elif coefficient_type == "polar_express":
        # For polar_express, steps must be >= 8, repeat last step if steps > 8
Hardcoded value 8 in comment could become incorrect if polar_express coefficient count changes
Suggested change:
    - # For polar_express, steps must be >= 8, repeat last step if steps > 8
    + # For polar_express, steps must be >= len(coefficient_sets), repeat last step if steps > len(coefficient_sets)
Signed-off-by: mikail <mkhona@nvidia.com>
This shouldn't be done this way. The simple nested dict of lists of coefficients is already more complicated than it should be. Adding more control logic will hurt readability: it becomes very hard to know what will happen by looking at the API, and one has to read the code very carefully. If supporting exactly what the paper does is deemed necessary, that would be the time to abstract the coefficients into their own iterator.
There is also code logic that we previously decided to remove, such as copying coefficients to satisfy the loop length: `coefficient_sets = coefficient_sets * (steps // num_coeffs)`. It will be a complete overhaul, so I am closing this.
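One possible shape for the iterator abstraction mentioned in this review, shown only as a rough sketch: the function names below are hypothetical and not part of the repository, and the coefficient values are placeholders. Each schedule is an unbounded iterator, so the caller takes exactly `steps` items and no list copying is needed.

```python
import itertools
from typing import Iterable, Iterator, List, Sequence, Tuple

Coeffs = Tuple[float, float, float]


def cyclic_schedule(sets: Sequence[Coeffs]) -> Iterator[Coeffs]:
    """Yield the sets in order, wrapping around indefinitely."""
    return itertools.cycle(sets)


def repeat_last_schedule(sets: Sequence[Coeffs]) -> Iterator[Coeffs]:
    """Yield each set once, then keep yielding the last one (requires a non-empty sequence)."""
    return itertools.chain(sets, itertools.repeat(sets[-1]))


def take(schedule: Iterable[Coeffs], steps: int) -> List[Coeffs]:
    """Materialize the first `steps` coefficient tuples from a schedule."""
    return list(itertools.islice(schedule, steps))


# Placeholder values for illustration only.
s0, s1, s2 = (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (3.0, 0.0, 0.0)
cyclic = take(cyclic_schedule([s0, s1, s2]), 5)
repeat_last = take(repeat_last_schedule([s0, s1, s2]), 5)
```

In a design like this, the iteration loop could consume `itertools.islice(schedule, steps)` directly, and the choice between cycling and repeating the last set would live in the schedule rather than in branching inside the loop.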
The issue points out that the Polar Express paper recommends that, when the number of `newton-schulz` steps is greater than 8, we repeat the last set of `(a,b,c)` coefficients. We previously required that the number of steps provided be a multiple of the number of coefficient sets so that we could run the iteration cyclically. While not incorrect, this deviated from Polar Express' recommendation, so this PR fixes that.
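To make the difference concrete, the sketch below compares which coefficient-set index each step would use under the two behaviors, assuming 8 coefficient sets as described above. The helper names are hypothetical and exist only for this illustration.

```python
from typing import List


def cyclic_indices(steps: int, num_coeffs: int) -> List[int]:
    """Previous behavior: wrap around (steps had to be a multiple of num_coeffs)."""
    return [i % num_coeffs for i in range(steps)]


def repeat_last_indices(steps: int, num_coeffs: int) -> List[int]:
    """Polar Express recommendation: stay on the last set once the sets run out."""
    return [min(i, num_coeffs - 1) for i in range(steps)]


print(cyclic_indices(16, 8))       # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
print(repeat_last_indices(10, 8))  # [0, 1, 2, 3, 4, 5, 6, 7, 7, 7]
```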