Metrics for training #982
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

```diff
- MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
+ # MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
+ MODEL_NAME="Qwen/Qwen3-0.6B" # Change this to the desired model name
```
I think you should revert it to Qwen3-8B
I changed this file by mistake, not on purpose; I am reverting it.
```python
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trainium2.html#compute
HARDWARE_TFLOPS = {
    "trn1": {
        "fp32": 48 / 2,
```
why / 2?
The performance metrics are given per chip; we need them per core.
ok, consider adding a comment
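A hypothetical sketch of what the suggested comment could look like. The peak TFLOPS figures in the AWS Neuron docs are quoted per Trainium chip, and each Trainium1 chip has 2 NeuronCores, so dividing by 2 yields the per-core peak; the names and values below mirror the diff, but the exact wording is illustrative only.

```python
# Peak TFLOPS per *core*. The figures from the AWS Neuron docs are given
# per chip; each chip has 2 NeuronCores, so we divide by the core count.
CORES_PER_CHIP = 2

HARDWARE_TFLOPS = {
    "trn1": {
        "fp32": 48 / CORES_PER_CHIP,   # 48 TFLOPS fp32 per chip
        "bf16": 191 / CORES_PER_CHIP,  # 191 TFLOPS bf16 per chip
    },
    "trn2": {
        "fp32": 181 / CORES_PER_CHIP,  # 181 TFLOPS fp32 per chip
    },
}
```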
```python
N = collector.model_params
L, H, Q, T = collector.num_layers, collector.num_heads, collector.head_dim, collector.seq_length
flops_per_token = 6 * N + 12 * L * H * Q * T
```
can you explain better where this calculation comes from?
It is from the PaLM paper; I will add the reference in the code.
```python
N = collector.model_params
L, H, Q, T = collector.num_layers, collector.num_heads, collector.head_dim, collector.seq_length
flops_per_token = 6 * N + 12 * L * H * Q * T
```
same here (maybe you can even centralize it in a method)
Done
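The centralized helper the reviewer asked for might look roughly like the sketch below: the estimate follows the PaLM paper's training-FLOPs accounting, where `6 * N` covers the dense parameters (forward plus backward) and `12 * L * H * Q * T` covers attention over the sequence. The function name and signature are assumptions for illustration, not the PR's actual API.

```python
def flops_per_token(num_params, num_layers, num_heads, head_dim, seq_length):
    """PaLM-style estimate of training FLOPs per token.

    6 * num_params accounts for the matmul FLOPs of all dense parameters
    (2 forward + 4 backward per parameter per token); the second term adds
    the attention FLOPs, which scale with sequence length.
    """
    return 6 * num_params + 12 * num_layers * num_heads * head_dim * seq_length
```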
```python
(8, 1, 1),
(32, 8, 1),
(32, 1, 4),
(32, 8, 4),
```
couldn't we just test the last set of params? Otherwise this can make the CI even longer (it's 1h24 already!)
It is not that long, and I'd rather be sure that the metrics are properly computed. This test does not run any forward pass, backward pass, or optimizer step; it is a lightweight one.
| "bf16": 191 / 2, | ||
| }, | ||
| "trn2": { | ||
| "fp32": 181 / 2, | 
trn2 chips have 8 cores per device, which are by default grouped in pairs into 4 virtual devices.
Thanks, I have updated that!
It's huge! Before merging it, do you know why the distributed training CI is failing? Can we fix it, @michaelbenayoun?
lgtm, thanks Michael!
What does this PR do?
This PR introduces a training metrics collection system.
It adds `ThroughputPlugin`, `MFUPlugin`, `EfficiencyPlugin`, and `ComponentTimingPlugin`.

Core Metrics Implementation
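Putting the pieces discussed in the review together, an MFU computation presumably combines the per-token FLOPs estimate with the per-core peak TFLOPS. The sketch below is an assumption about how such a metric could be wired up, not the PR's actual `MFUPlugin` implementation; all names are hypothetical.

```python
def mfu(tokens_per_second, flops_per_token, peak_tflops_per_core, num_cores):
    """Model FLOPs Utilization: achieved training FLOPs over peak hardware FLOPs.

    tokens_per_second: measured training throughput.
    flops_per_token: PaLM-style estimate (6*N + 12*L*H*Q*T).
    peak_tflops_per_core: per-core peak from the hardware table (e.g. 191/2 for trn1 bf16).
    """
    achieved_flops = tokens_per_second * flops_per_token
    peak_flops = peak_tflops_per_core * 1e12 * num_cores
    return achieved_flops / peak_flops
```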