
Conversation


@michaelbenayoun michaelbenayoun commented Sep 29, 2025

What does this PR do?

This PR introduces a training metrics collection system.

  • Plugin-based architecture: Modular system with ThroughputPlugin, MFUPlugin, EfficiencyPlugin, and ComponentTimingPlugin
  • Moving window statistics: Configurable window size for real-time metrics calculation
  • Hardware detection: Automatic TRN1/TRN2 platform detection with correct peak FLOPS values
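The moving-window statistics idea can be sketched as follows (a minimal illustration; `MovingWindow` is a hypothetical name, not one of the PR's actual plugin classes):

```python
from collections import deque


class MovingWindow:
    """Keep the last `window_size` samples and report window statistics.

    Illustrative sketch only; the PR's real plugins may differ.
    """

    def __init__(self, window_size: int = 10):
        # deque with maxlen drops the oldest sample automatically.
        self.samples = deque(maxlen=window_size)

    def record(self, value: float) -> None:
        self.samples.append(value)

    def mean(self) -> float:
        return sum(self.samples) / len(self.samples) if self.samples else 0.0


# Example: per-step tokens/second samples averaged over a window of 3.
window = MovingWindow(window_size=3)
for tps in [1000.0, 1100.0, 1200.0, 1300.0]:
    window.record(tps)
print(window.mean())  # mean of the last 3 samples: 1200.0
```

With a configurable `window_size`, the reported value tracks recent steps rather than the whole run, which is what makes the metrics "real-time".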

Core Metrics Implementation

  • Throughput metrics: Tokens/second calculation with proper data parallel scaling
  • Model FLOPS Utilization (MFU): System-wide MFU calculation using the PaLM paper formula: 6N + 12·L·H·Q·T FLOPS per token
  • Training efficiency: Breakdown of time spent on forward/backward/optimizer vs overhead
  • Component timing: Individual timing for forward pass, backward pass, optimizer step, and total step
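The MFU computation described above can be sketched as follows. The formula is the one the PR cites from the PaLM paper; the concrete numbers below are illustrative placeholders, not measurements:

```python
def flops_per_token(N: int, L: int, H: int, Q: int, T: int) -> int:
    """PaLM-style FLOPS-per-token estimate: 6N for the dense parameters
    plus 12*L*H*Q*T for attention (layers * heads * head_dim * seq_len)."""
    return 6 * N + 12 * L * H * Q * T


def mfu(tokens_per_second: float, flops_per_tok: int, peak_flops: float) -> float:
    """Achieved FLOPS divided by the hardware's aggregate peak FLOPS."""
    return tokens_per_second * flops_per_tok / peak_flops


# Illustrative shapes for an ~8B-parameter model; peak_flops assumes a
# hypothetical 95.5 peak TFLOPS per core across 32 cores.
fpt = flops_per_token(N=8_000_000_000, L=36, H=32, Q=128, T=4096)
print(mfu(tokens_per_second=10_000, flops_per_tok=fpt, peak_flops=32 * 95.5e12))
```

The result is a ratio between 0 and 1: the fraction of the system's theoretical peak that the training loop actually sustains.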

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@michaelbenayoun michaelbenayoun marked this pull request as ready for review October 15, 2025 11:57
Comment on lines 16 to 18
MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
# MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
MODEL_NAME="Qwen/Qwen3-0.6B" # Change this to the desired model name
Collaborator
I think you should revert it to Qwen3-8B

Member Author

The changes to this file were a mistake, not intentional; I am reverting them.

# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trainium2.html#compute
HARDWARE_TFLOPS = {
"trn1": {
"fp32": 48 / 2,
Collaborator

why / 2?

Member Author

The performance figures are given per chip; we need them per core.

Collaborator

ok, consider adding a comment
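A comment along these lines would capture why the division is there (a sketch; the figures echo the snippet under review and should be checked against the Neuron docs linked in the code):

```python
# Peak TFLOPS figures in the AWS Neuron hardware docs are quoted PER CHIP.
# Each Trainium (trn1) chip exposes 2 NeuronCores, and the metrics collector
# works per core, hence the division.
CORES_PER_CHIP = 2

HARDWARE_TFLOPS = {
    "trn1": {
        "fp32": 48 / CORES_PER_CHIP,
        "bf16": 191 / CORES_PER_CHIP,
    },
}
```

Naming the divisor makes the intent self-documenting instead of leaving a bare `/ 2` in the table.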


N = collector.model_params
L, H, Q, T = collector.num_layers, collector.num_heads, collector.head_dim, collector.seq_length
flops_per_token = 6 * N + 12 * L * H * Q * T
Collaborator

can you explain better where this calculation comes from?

Member Author

It is from the PaLM paper; I will add the reference in the code.


N = collector.model_params
L, H, Q, T = collector.num_layers, collector.num_heads, collector.head_dim, collector.seq_length
flops_per_token = 6 * N + 12 * L * H * Q * T
Collaborator

same here (maybe you can even centralize it in a method)

Member Author

Done
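Centralizing the duplicated calculation could look like this (a sketch: the attribute names follow the snippets above, while `MetricsCollector` and the method name are hypothetical):

```python
class MetricsCollector:
    """Minimal stand-in for the PR's collector, holding model shape info."""

    def __init__(self, model_params: int, num_layers: int, num_heads: int,
                 head_dim: int, seq_length: int):
        self.model_params = model_params
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.seq_length = seq_length

    def flops_per_token(self) -> int:
        # PaLM-style estimate: 6N + 12*L*H*Q*T FLOPS per token.
        N = self.model_params
        L, H, Q, T = self.num_layers, self.num_heads, self.head_dim, self.seq_length
        return 6 * N + 12 * L * H * Q * T
```

Both call sites would then use `collector.flops_per_token()` instead of repeating the formula.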

Comment on lines +33 to +36
(8, 1, 1),
(32, 8, 1),
(32, 1, 4),
(32, 8, 4),
Collaborator

couldn't we just test the last set of params? Otherwise this can make the CI even longer (it's 1h24 already!)

Member Author

It is not that long, and I'd rather be sure that the metrics are properly computed. This test does not run any forward pass, backward pass, or optimizer step; it is a lightweight one.
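The four tuples read like (world_size, tensor_parallel_size, pipeline_parallel_size) configurations; that meaning is an assumption, since the test file defines it. A minimal sketch of the kind of invariant such a lightweight test can check without any forward/backward pass:

```python
# Assumed tuple meaning: (world_size, tp_size, pp_size). In pytest this list
# would typically feed @pytest.mark.parametrize.
CONFIGS = [
    (8, 1, 1),
    (32, 8, 1),
    (32, 1, 4),
    (32, 8, 4),
]


def data_parallel_size(world_size: int, tp_size: int, pp_size: int) -> int:
    # The data-parallel degree is what remains after tensor and pipeline parallelism.
    assert world_size % (tp_size * pp_size) == 0, "sizes must divide the world size"
    return world_size // (tp_size * pp_size)


for ws, tp, pp in CONFIGS:
    dp = data_parallel_size(ws, tp, pp)
    assert dp * tp * pp == ws
    print((ws, tp, pp), "-> dp =", dp)
```

Covering all four tuples exercises pure tensor parallelism, pure pipeline parallelism, and their combination, which is the author's rationale for keeping the full set despite CI time.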

"bf16": 191 / 2,
},
"trn2": {
"fp32": 181 / 2,
Collaborator

Trn2 chips have 8 cores per device, which by default are grouped in pairs into 4 virtual devices.

Member Author

Thanks, I have updated that!

@JingyaHuang (Collaborator) left a comment

It's huge! Before merging it, do you know why the distributed training CI is failing? Can we fix it, @michaelbenayoun?

@JingyaHuang (Collaborator) left a comment

lgtm, thanks Michael!

@michaelbenayoun michaelbenayoun merged commit 1dfe4da into main Oct 31, 2025
5 checks passed
@michaelbenayoun michaelbenayoun deleted the metrics branch October 31, 2025 14:59