
Conversation

@gabeweisz

Description

When using THD-format packed data with TransformerEngine, the user must specify, at JAX JIT time, the maximum number of segments that can be packed into a sequence. If Grain packs more segments than this limit allows, it can cause crashes or data corruption.

We previously updated Grain to allow limiting the number of segments packed into a sequence; this PR takes the appropriate value from the MaxText configuration and passes it to Grain.
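For illustration, segment-capped packing behaves roughly like the first-fit sketch below. This is plain Python, not Grain's actual API; the `max_segments_per_seq` name mirrors the MaxText config option, and the rest is an assumed simplification:

```python
def pack_first_fit(seq_lengths, bin_capacity, max_segments_per_seq=None):
    """Greedy first-fit packing: place each segment into the first packed
    sequence ("bin") with enough room, subject to an optional cap on the
    number of segments per bin (None means no cap)."""
    bins = []  # each entry: (tokens_used, segment_count)
    assignments = []  # bin index chosen for each input segment
    for length in seq_lengths:
        for i, (used, count) in enumerate(bins):
            if used + length <= bin_capacity and (
                max_segments_per_seq is None or count < max_segments_per_seq
            ):
                bins[i] = (used + length, count + 1)
                assignments.append(i)
                break
        else:
            # No existing bin fits; open a new one.
            bins.append((length, 1))
            assignments.append(len(bins) - 1)
    return assignments, bins
```

The cap is what keeps the packed batch compatible with a kernel compiled for a fixed maximum segment count: without it, a third short segment would be packed into a bin that still has token room, exceeding the compiled limit.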

Tests

We have had this fix in place in our AMD fork of MaxText for some time, but needed the Grain fix to be upstreamed before creating this PR.
We have tested this fix extensively internally and have customers using it in production.

MaxText does not currently have any tests that use packed batches, but I can create some if needed.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [ ] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • [ ] I have added necessary comments to my code, particularly in hard-to-understand areas.
  • [ ] I have run end-to-end tests and provided workload links above if applicable.
  • [ ] I have made or will make corresponding changes to the docs if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

Contributor

@yeandy yeandy left a comment


We may need to add `max_sequences_per_bin=config.max_segments_per_seq` in `make_hf_eval_iterator` too.

@gabeweisz
Author

> We may need to add `max_sequences_per_bin=config.max_segments_per_seq` in `make_hf_eval_iterator` too.

Done, thanks for the tip

@gabeweisz gabeweisz closed this Dec 4, 2025
@gabeweisz gabeweisz reopened this Dec 4, 2025
```python
use_sft=config.use_sft,
sft_train_on_completion_only=config.sft_train_on_completion_only,
chat_template_path=config.chat_template_path,
max_sequences_per_bin=config.max_segments_per_seq,
```
Copy link
Collaborator


According to this PR, `max_segments_per_seq` is only relevant for GPU packed attention, but this change applies it to TPU workloads as well.
To be cleaner, it's better to align the behavior across hardware and across different pipelines (the Grain pipeline's `FirstFitPackIterDataset` also has this parameter). We can set the default value to -1, which means no limit (passing None to `PackAndBatchOperation`).
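The proposed convention could be wired up with a small helper along these lines. This is a hypothetical sketch, not code from MaxText or Grain; only the -1-means-no-limit convention and the None argument come from the comment above:

```python
def resolve_max_sequences_per_bin(max_segments_per_seq: int):
    """Map the proposed MaxText config convention to the Grain-side
    argument: -1 means 'no limit', which Grain expresses as None."""
    if max_segments_per_seq == -1:
        return None
    if max_segments_per_seq <= 0:
        raise ValueError(
            "max_segments_per_seq must be a positive integer or -1 (no limit)"
        )
    return max_segments_per_seq
```

With a default of -1 in the config, TPU workloads that never set the option keep today's unlimited-packing behavior, while GPU THD users can opt into a hard cap.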

