
Conversation

@joecummings (Member) commented on Sep 26, 2025

Context

This PR fixes an issue where the loss at the very first training step was astronomically high (>500). The problem was traced to a very high KL divergence term, which measures the difference in log-probabilities between the reference model and the training model. On the very first step this should be near zero, because the two models are identical at that point!

This suggested that the weights of the reference model and the training model were not actually the same. Comparing against a forward pass of the Hugging Face model confirmed the hypothesis.
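For intuition, here is a minimal sketch of how a per-token KL penalty is commonly computed from the two models' log-probabilities in RLHF-style trainers; the k3 estimator shown is one common choice, not necessarily this repo's exact loss code:

```python
import torch

def per_token_kl(trainer_logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    """Approximate per-token KL(trainer || reference) via the k3 estimator.

    Both inputs are log-probabilities of the sampled tokens, shape [batch, seq].
    """
    log_ratio = ref_logprobs - trainer_logprobs
    # exp(x) - x - 1 is >= 0 and equals 0 exactly when x == 0
    return torch.exp(log_ratio) - log_ratio - 1.0
```

With identical weights the log-ratio is ~0 and the estimator returns ~0 everywhere, which is why a first-step loss >500 pointed at a weight mismatch rather than at the loss itself.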

Fix

Load the reference model weights correctly through the TorchTitan APIs.
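The core of the change, quoted again in the review threads below, instantiates the engine from the job config and then loads the reference weights through its checkpointer:

```python
async def setup(self):
    engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
    self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
    self.engine.checkpointer.load()
```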

Before

[Screenshot 2025-09-26 at 3:45:25 PM: loss before the fix]

After

[Screenshot 2025-09-26 at 3:48:05 PM: loss after the fix]

To-dos

  • Confirm that loss is still reasonable in the distributed setting. Likely needs #184 (Weight loading working correctly with tp: use vllm builtin load_weights()) to land first.
  • When running the test script attached to this PR, the outputs from the Hugging Face model implementation and the TorchTitan model implementation are similar but not exactly the same. Some difference is expected because of the chosen RoPE implementation; however, it's worth investigating whether this difference is too large (see the parity-check sketch after this list).
  • Formulate a better plan for how to expose Titan "APIs". This debugging experience was a nightmare because there is no way to understand what's going on in Titan without going into that code base and then also into the experiments/forge folder. This is untenable.
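As a rough illustration of the kind of parity check described in the second to-do, here is a hedged sketch comparing forward-pass logits between the Hugging Face implementation and another implementation of the same checkpoint. The model name, prompt, tolerances, and the `titan_model` forward call are illustrative assumptions, not the attached script's actual contents:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16)
hf_model.eval()

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    hf_logits = hf_model(**inputs).logits  # reference logits: [1, seq, vocab]

# titan_logits = titan_model(inputs["input_ids"])  # repo-specific forward, omitted
# Small drift is expected from differing RoPE implementations; large drift would
# indicate a weight-loading bug like the one this PR fixes.
# torch.testing.assert_close(titan_logits, hf_logits, rtol=1e-2, atol=1e-2)
```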

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Sep 26, 2025
@joecummings force-pushed the compare-against-hf-trainer branch from 1148c66 to 1846b85 on September 26, 2025 at 19:34
@joecummings force-pushed the compare-against-hf-trainer branch from 1846b85 to c0c0a35 on September 26, 2025 at 19:35
@joecummings marked this pull request as ready for review on September 26, 2025 at 19:50
async def setup(self):
    engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
    self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
    self.engine.checkpointer.load()
Contributor:

I was actually staring at this on the trainer side... It's unclear at a glance how checkpointer.load is associated with loading the HF model weights.

Member (Author):

Yeah, it's not very clear without digging into the TorchTitan checkpointing code here: https://github.com/pytorch/torchtitan/blob/5b5d46856b400c8550989415bee91473aab4f921/torchtitan/components/checkpoint.py#L523

All the information is taken from the config and used to instantiate the CheckpointManager. The load call then only takes a "step", which in our case isn't needed because the reference model should be static every time.
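For reference, a simplified sketch of the load path being described; the attribute and argument names are assumptions distilled from the linked TorchTitan code, not its exact signature:

```python
# Simplified sketch, not TorchTitan's exact CheckpointManager API.
class CheckpointManager:
    def __init__(self, states: dict, job_config) -> None:
        self.states = states  # model (and optionally optimizer) states to restore
        self.initial_load_path = job_config.checkpoint.initial_load_path

    def load(self, step: int = -1) -> bool:
        """Restore weights into self.states.

        `step` selects a checkpoint from an ongoing training run; for a
        static reference model the default is fine, since the same initial
        checkpoint is loaded every time.
        """
        ...
```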

async def setup(self):
    engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
    self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
    self.engine.checkpointer.load()
Contributor:

🙃

@joecummings merged commit afdee53 into meta-pytorch:main on Sep 26, 2025. 5 checks passed.