Use MaxText max_segments_per_seq config variable to control Grain batch packing #2774
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
When using THD format packed data with TransformerEngine, the user must specify the maximum number of segments that can be packed into a sequence at Jax JIT time. If grain packs more segments than allowed, then this can cause crashes or data corruption.
We have previously updated grain to allow limiting the number of segments to pack into a sequence, and this PR takes the appropriate value from the MaxText configuration and passes it to Grain
Tests
We have had this fix in place in our AMD fork of MaxText for some time, but needed to get the Grain fix upstreamed first before creating this PR.
We have tested this fix extensively internally and have customers using it in production.
MaxText does not currently have any tests that use packed batches, but I can create some if needed.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.