
Conversation

@casteryh (Contributor) commented Sep 22, 2025

Previously, the test was bogus because

  • it simply repeated the logic of the implementation (so it only checked the implementation against itself)
  • it would pass even if you commented out update_weights and push_weights, since all the weights were identical from start to end (a behavior-checking version is sketched below)
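
For reference, a behavior-checking round trip looks roughly like the sketch below. This is a minimal sketch rather than the code in this PR: `model_state_dict`, `scale_all_params_`, and the exact `push_weights`/`update_weights` signatures are assumptions, and only `_test_validate_model_params(validate_fn)` mirrors the hook discussed further down in this thread.

```python
import torch

SCALE = 1.5  # any factor != 1.0 makes a skipped or stale update detectable

async def weight_sync_roundtrip(trainer, policy, policy_version: int):
    # Compute the expected post-update values independently of the loader,
    # before touching the trainer weights. (Fused vLLM params such as qkv_proj
    # would additionally need the same q/k/v concatenation mapping the loader uses.)
    expected = {
        name: p.detach().clone() * SCALE
        for name, p in trainer.model_state_dict().items()  # hypothetical helper
    }

    # Perturb the trainer weights so they no longer match the initial checkpoint.
    trainer.scale_all_params_(SCALE)  # hypothetical helper

    # Push to torchstore, then pull into the policy.
    await trainer.push_weights(policy_version)
    await policy.update_weights(policy_version)

    # If update_weights were a no-op, the policy would still hold the original
    # (unscaled) weights and this check would fail.
    def validate_fn(loaded_sd: dict[str, torch.Tensor]) -> None:
        for name, got in loaded_sd.items():
            torch.testing.assert_close(got, expected[name], msg=name)

    await policy._test_validate_model_params(validate_fn)
```

The key property is that the expected values are computed independently of the loader, so a skipped or partial update cannot silently pass.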

@meta-cla bot added the CLA Signed label Sep 22, 2025
@casteryh marked this pull request as ready for review September 22, 2025 23:00
@casteryh changed the title from "[WIP] integration test for weight sync that actually tests behavior" to "integration test for weight sync that actually tests behavior" Sep 22, 2025
@casteryh (Contributor, Author) commented:

Changed to draft status; either the weight loading is buggy or the test is buggy.

@casteryh changed the title integration test for weight sync that actually tests behavior Sep 23, 2025
@casteryh marked this pull request as draft September 23, 2025 02:00
@casteryh (Contributor, Author) commented Sep 23, 2025

For tp_size=2, the following is failing.
See P1959792717 for full logs:

ERROR:forge.actors.policy:Validation failed for the following params: ['model.embed_tokens.weight', 'model.layers.0.self_attn.qkv_proj.weight', 'model.layers.0.mlp.gate_up_proj.weight', 'model.layers.1.self_attn.qkv_proj.weight', 'model.layers.1.mlp.gate_up_proj.weight', 'model.layers.2.self_attn.qkv_proj.weight', 'model.layers.2.mlp.gate_up_proj.weight', 'model.layers.3.self_attn.qkv_proj.weight', 'model.layers.3.mlp.gate_up_proj.weight', 'model.layers.4.self_attn.qkv_proj.weight', 'model.layers.4.mlp.gate_up_proj.weight', 'model.layers.5.self_attn.qkv_proj.weight', 'model.layers.5.mlp.gate_up_proj.weight', 'model.layers.6.self_attn.qkv_proj.weight', 'model.layers.6.mlp.gate_up_proj.weight', 'model.layers.7.self_attn.qkv_proj.weight', 'model.layers.7.mlp.gate_up_proj.weight', 'model.layers.8.self_attn.qkv_proj.weight', 'model.layers.8.mlp.gate_up_proj.weight', 'model.layers.9.self_attn.qkv_proj.weight', 'model.layers.9.mlp.gate_up_proj.weight', 'model.layers.10.self_attn.qkv_proj.weight', 'model.layers.10.mlp.gate_up_proj.weight', 'model.layers.11.self_attn.qkv_proj.weight', 'model.layers.11.mlp.gate_up_proj.weight', 'model.layers.12.self_attn.qkv_proj.weight', 'model.layers.12.mlp.gate_up_proj.weight', 'model.layers.13.self_attn.qkv_proj.weight', 'model.layers.13.mlp.gate_up_proj.weight', 'model.layers.14.self_attn.qkv_proj.weight', 'model.layers.14.mlp.gate_up_proj.weight', 'model.layers.15.self_attn.qkv_proj.weight', 'model.layers.15.mlp.gate_up_proj.weight', 'model.layers.16.self_attn.qkv_proj.weight', 'model.layers.16.mlp.gate_up_proj.weight', 'model.layers.17.self_attn.qkv_proj.weight', 'model.layers.17.mlp.gate_up_proj.weight', 'model.layers.18.self_attn.qkv_proj.weight', 'model.layers.18.mlp.gate_up_proj.weight', 'model.layers.19.self_attn.qkv_proj.weight', 'model.layers.19.mlp.gate_up_proj.weight', 'model.layers.20.self_attn.qkv_proj.weight', 'model.layers.20.mlp.gate_up_proj.weight', 'model.layers.21.self_attn.qkv_proj.weight', 'model.layers.21.mlp.gate_up_proj.weight', 'model.layers.22.self_attn.qkv_proj.weight', 'model.layers.22.mlp.gate_up_proj.weight', 'model.layers.23.self_attn.qkv_proj.weight', 'model.layers.23.mlp.gate_up_proj.weight', 'model.layers.24.self_attn.qkv_proj.weight', 'model.layers.24.mlp.gate_up_proj.weight', 'model.layers.25.self_attn.qkv_proj.weight', 'model.layers.25.mlp.gate_up_proj.weight', 'model.layers.26.self_attn.qkv_proj.weight', 'model.layers.26.mlp.gate_up_proj.weight', 'model.layers.27.self_attn.qkv_proj.weight', 'model.layers.27.mlp.gate_up_proj.weight']

Verified params = {'model.layers.18.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.23.mlp.down_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.14.self_attn.k_norm.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.24.self_attn.q_norm.weight', 'model.layers.6.self_attn.q_norm.weight', 'model.layers.21.post_attention_layernorm.weight', 'model.layers.22.mlp.down_proj.weight', 'model.layers.23.self_attn.k_norm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.3.self_attn.k_norm.weight', 'model.layers.27.self_attn.k_norm.weight', 'model.layers.1.self_attn.q_norm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.11.self_attn.k_norm.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.9.self_attn.q_norm.weight', 'model.layers.4.self_attn.k_norm.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.22.post_attention_layernorm.weight', 'model.layers.24.self_attn.k_norm.weight', 'model.layers.1.self_attn.k_norm.weight', 'model.layers.18.post_attention_layernorm.weight', 'model.layers.25.self_attn.q_norm.weight', 'model.layers.5.self_attn.k_norm.weight', 'model.layers.20.input_layernorm.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.25.input_layernorm.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.8.self_attn.q_norm.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.21.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.25.self_attn.k_norm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.26.self_attn.q_norm.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.26.self_attn.k_norm.weight', 'model.layers.20.mlp.down_proj.weight', 'model.layers.24.post_attention_layernorm.weight', 'model.layers.21.self_attn.k_norm.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.24.input_layernorm.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.k_norm.weight', 'model.layers.10.self_attn.q_norm.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.19.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.11.self_attn.q_norm.weight', 'model.layers.0.self_attn.q_norm.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.norm.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.26.mlp.down_proj.weight', 'model.layers.27.input_layernorm.weight', 'model.layers.2.self_attn.q_norm.weight', 
'model.layers.27.mlp.down_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.15.self_attn.q_norm.weight', 'model.layers.18.self_attn.k_norm.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.13.self_attn.k_norm.weight', 'model.layers.25.post_attention_layernorm.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.12.self_attn.q_norm.weight', 'model.layers.20.post_attention_layernorm.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.24.mlp.down_proj.weight', 'model.layers.15.self_attn.k_norm.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.18.input_layernorm.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_norm.weight', 'model.layers.18.self_attn.q_norm.weight', 'model.layers.2.self_attn.k_norm.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.self_attn.k_norm.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.21.self_attn.q_norm.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.26.post_attention_layernorm.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.26.input_layernorm.weight', 'model.layers.19.post_attention_layernorm.weight', 'model.layers.27.self_attn.o_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_norm.weight', 'model.layers.4.self_attn.q_norm.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.19.mlp.down_proj.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.9.self_attn.k_norm.weight', 'model.layers.10.self_attn.k_norm.weight', 'model.layers.27.self_attn.q_norm.weight', 'model.layers.0.self_attn.k_norm.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.22.self_attn.q_norm.weight', 'model.layers.23.self_attn.q_norm.weight', 'model.layers.17.self_attn.q_norm.weight', 'model.layers.21.mlp.down_proj.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.13.self_attn.q_norm.weight', 'model.layers.5.self_attn.q_norm.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.20.self_attn.k_norm.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.6.self_attn.k_norm.weight', 'model.layers.20.self_attn.q_norm.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.7.self_attn.q_norm.weight', 'model.layers.3.self_attn.q_norm.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.16.self_attn.k_norm.weight', 'model.layers.19.self_attn.k_norm.weight', 'model.layers.22.input_layernorm.weight', 'model.layers.22.self_attn.k_norm.weight', 'model.layers.23.input_layernorm.weight', 'model.layers.7.self_attn.k_norm.weight', 'model.layers.23.post_attention_layernorm.weight', 'model.layers.17.self_attn.k_norm.weight', 
'model.layers.25.mlp.down_proj.weight', 'model.layers.16.self_attn.q_norm.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.27.post_attention_layernorm.weight', 'model.layers.2.self_attn.o_proj.weight'}

@casteryh (Contributor, Author) commented:

For no parallelism at all, only model.embed_tokens.weight is failing.

Validation failed: [('model.embed_tokens.weight', AssertionError('current param model.embed_tokens.weight does not match expected value; previous param (torch.Size([151936, 2048]))= tensor([[-0.0127, 0.0195, 0.0117, ..., 0.0157, -0.0469, -0.0013],\n [ 0.0249, -0.0055, -0.0674, ..., 0.0037, 0.0403, -0.0165],\n [-0.0173, -0.0284, -0.0530, ..., -0.0305, -0.0074, 0.0649],\n ...,\n [ 0.0060, 0.0131, 0.0190, ..., 0.0020, -0.0014, -0.0055],\n [ 0.0060, 0.0131, 0.0190, ..., 0.0020, -0.0014, -0.0055],\n [ 0.0060, 0.0131, 0.0190, ..., 0.0020, -0.0014, -0.0055]],\n dtype=torch.bfloat16); expected = tensor([[-0.0190, 0.0293, 0.0176, ..., 0.0237, -0.0703, -0.0020],\n [ 0.0374, -0.0083, -0.1011, ..., 0.0056, 0.0605, -0.0247],\n [-0.0260, -0.0427, -0.0796, ..., -0.0459, -0.0111, 0.0977],\n ...,\n [ 0.0090, 0.0197, 0.0286, ..., 0.0031, -0.0021, -0.0082],\n [ 0.0090, 0.0197, 0.0286, ..., 0.0031, -0.0021, -0.0082],\n [ 0.0090, 0.0197, 0.0286, ..., 0.0031, -0.0021, -0.0082]],\n dtype=torch.bfloat16) vs got = torch.Size([151936, 2048]) tensor([[-0.0286, 0.0439, 0.0264, ..., 0.0354, -0.1055, -0.0030],\n [ 0.0562, -0.0125, -0.1514, ..., 0.0083, 0.0908, -0.0371],\n [-0.0391, -0.0640, -0.1191, ..., -0.0688, -0.0166, 0.1465],\n ...,\n [ 0.0136, 0.0295, 0.0430, ..., 0.0046, -0.0032, -0.0123],\n [ 0.0136, 0.0295, 0.0430, ..., 0.0046, -0.0032, -0.0123],\n [ 0.0136, 0.0295, 0.0430, ..., 0.0046, -0.0032, -0.0123]],\n dtype=torch.bfloat16)'))]
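
For what it's worth, the printed values are consistent with a scale factor being applied one extra time: reading off the first element from the log above, expected ≈ previous × 1.5 and got ≈ expected × 1.5. A quick sanity check on those numbers:

```python
# First element of model.embed_tokens.weight, copied from the log above.
previous, expected, got = -0.0127, -0.0190, -0.0286

print(round(expected / previous, 2))  # 1.5
print(round(got / expected, 2))       # 1.51
```

That got = expected × 1.5 pattern would fit a multiplication being applied twice, which lines up with the DTensor multiplication bug mentioned later in the thread.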

Comment on lines -248 to -250
2. Initializes RLTrainer, make the weights available in torchstore.
3. Initializes Policy, and calls update_weights() to load weights from torchstore.
4. Validate the policy weights against source of truth.
Contributor commented:

Looks like, because the new policy version has the same weights, we won't know for sure whether the "update" logic is correct or not.
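
In other words, with unchanged weights the check is vacuous. A rough, self-contained illustration of the failure mode (hypothetical names, not code from this PR):

```python
import torch

# Stand-in for a policy weight whose "expected" value equals its initial value.
policy_param = torch.randn(4, 4)
expected = policy_param.clone()  # same as the starting weights

# await policy.update_weights(new_version)  # commenting this out changes nothing below

assert torch.equal(policy_param, expected)  # still passes; the update was never exercised
```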

Comment on lines 404 to 408
async def _test_validate_model_params(self, validate_fn):
"""Validate updated model params using validate_fn."""
logger.info("[Policy] start validating model parameters post update")
return await self.policy_worker._test_validate_model_params.call(validate_fn)

Contributor commented:

This is really smart! Having the validate_fn as an argument greatly simplifies the logic of validating the weights on each worker.
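
As a rough illustration (not the exact hook in this PR; the real argument shape may differ), the kind of validate_fn each worker could run against its own loaded weights, producing failure tuples like the ones in the logs above:

```python
import torch

def make_validate_fn(expected_sd: dict[str, torch.Tensor]):
    """Build a validate_fn that a PolicyWorker runs over its own view of the model."""
    def validate_fn(loaded_sd: dict[str, torch.Tensor]) -> None:
        failures = []
        for name, got in loaded_sd.items():
            try:
                torch.testing.assert_close(got, expected_sd[name])
            except (AssertionError, KeyError) as e:
                failures.append((name, e))
        if failures:
            raise AssertionError(f"Validation failed: {failures}")
    return validate_fn
```

Passing the callable in keeps the expected values and the comparison policy in the test, while the worker only supplies its local tensors.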

"post_attention_layernorm",
"o_proj",
"norm.weight",
"embed_tokens.weight",
Contributor commented:

This is weird -- in the single policy worker test case, model.embed_tokens.weight should fall into this case.

AssertionError: Validation failed with exception: Validation failed: [('model.embed_tokens.weight', AssertionError('current param model.embed_tokens.weight does not match expected value; previous param (torch.Size([151936, 2048]))= tensor([[-0.0127,  0.0195,  0.0117,  ...,  0.0157, -0.0469, -0.0013],\n        [ 0.0249, -0.0055, -0.0674,  ...,  0.0037,  0.0403, -0.0165],\n        [-0.0173, -0.0284, -0.0530,  ..., -0.0305, -0.0074,  0.0649],\n        ...,\n        [ 0.0060,  0.0131,  0.0190,  ...,  0.0020, -0.0014, -0.0055],\n        [ 0.0060,  0.0131,  0.0190,  ...,  0.0020, -0.0014, -0.0055],\n        [ 0.0060,  0.0131,  0.0190,  ...,  0.0020, -0.0014, -0.0055]],\n       dtype=torch.bfloat16); expected = tensor([[-0.0190,  0.0293,  0.0176,  ...,  0.0237, -0.0703, -0.0020],\n        [ 0.0374, -0.0083, -0.1011,  ...,  0.0056,  0.0605, -0.0247],\n        [-0.0260, -0.0427, -0.0796,  ..., -0.0459, -0.0111,  0.0977],\n        ...,\n        [ 0.0090,  0.0197,  0.0286,  ...,  0.0031, -0.0021, -0.0082],\n        [ 0.0090,  0.0197,  0.0286,  ...,  0.0031, -0.0021, -0.0082],\n        [ 0.0090,  0.0197,  0.0286,  ...,  0.0031, -0.0021, -0.0082]],\n       dtype=torch.bfloat16) vs got = torch.Size([151936, 2048]) tensor([[-0.0286,  0.0439,  0.0264,  ...,  0.0354, -0.1055, -0.0030],\n        [ 0.0562, -0.0125, -0.1514,  ...,  0.0083,  0.0908, -0.0371],\n        [-0.0391, -0.0640, -0.1191,  ..., -0.0688, -0.0166,  0.1465],\n        ...,\n        [ 0.0136,  0.0295,  0.0430,  ...,  0.0046, -0.0032, -0.0123],\n        [ 0.0136,  0.0295,  0.0430,  ...,  0.0046, -0.0032, -0.0123],\n        [ 0.0136,  0.0295,  0.0430,  ...,  0.0046, -0.0032, -0.0123]],\n       dtype=torch.bfloat16)'))]

Comment on lines -70 to -73
q = saved_sd[prefix + "self_attn.q_proj.weight"]
k = saved_sd[prefix + "self_attn.k_proj.weight"]
v = saved_sd[prefix + "self_attn.v_proj.weight"]
load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
Contributor commented:

These also appeared in the error log https://www.internalfb.com/intern/everpaste/?handle=GFQsCCF1lpC-jBUEAI4XK_T7_DM8bsIXAAAB&phabricator_paste_number=1959792717

'model.layers.0.self_attn.qkv_proj.weight', 'model.layers.0.mlp.gate_up_proj.weight', 'model.layers.1.self_attn.qkv_proj.weight', 'model.layers.1.mlp.gate_up_proj.weight',

@casteryh (Contributor, Author) commented Sep 23, 2025

Fixed the bug in DTensor multiplication; all tests are now passing for a single worker.
For tp_size = 2, all the weights that involve manual concatenation are still incorrect.
https://www.internalfb.com/phabricator/paste/view/P1960764150

ERROR:forge.actors.policy:Validation failed for the following params: ['model.layers.0.self_attn.qkv_proj.weight', 'model.layers.0.mlp.gate_up_proj.weight', 'model.layers.1.self_attn.qkv_proj.weight', 'model.layers.1.mlp.gate_up_proj.weight', 'model.layers.2.self_attn.qkv_proj.weight', 'model.layers.2.mlp.gate_up_proj.weight', 'model.layers.3.self_attn.qkv_proj.weight', 'model.layers.3.mlp.gate_up_proj.weight', 'model.layers.4.self_attn.qkv_proj.weight', 'model.layers.4.mlp.gate_up_proj.weight', 'model.layers.5.self_attn.qkv_proj.weight', 'model.layers.5.mlp.gate_up_proj.weight', 'model.layers.6.self_attn.qkv_proj.weight', 'model.layers.6.mlp.gate_up_proj.weight', 'model.layers.7.self_attn.qkv_proj.weight', 'model.layers.7.mlp.gate_up_proj.weight', 'model.layers.8.self_attn.qkv_proj.weight', 'model.layers.8.mlp.gate_up_proj.weight', 'model.layers.9.self_attn.qkv_proj.weight', 'model.layers.9.mlp.gate_up_proj.weight', 'model.layers.10.self_attn.qkv_proj.weight', 'model.layers.10.mlp.gate_up_proj.weight', 'model.layers.11.self_attn.qkv_proj.weight', 'model.layers.11.mlp.gate_up_proj.weight', 'model.layers.12.self_attn.qkv_proj.weight', 'model.layers.12.mlp.gate_up_proj.weight', 'model.layers.13.self_attn.qkv_proj.weight', 'model.layers.13.mlp.gate_up_proj.weight', 'model.layers.14.self_attn.qkv_proj.weight', 'model.layers.14.mlp.gate_up_proj.weight', 'model.layers.15.self_attn.qkv_proj.weight', 'model.layers.15.mlp.gate_up_proj.weight', 'model.layers.16.self_attn.qkv_proj.weight', 'model.layers.16.mlp.gate_up_proj.weight', 'model.layers.17.self_attn.qkv_proj.weight', 'model.layers.17.mlp.gate_up_proj.weight', 'model.layers.18.self_attn.qkv_proj.weight', 'model.layers.18.mlp.gate_up_proj.weight', 'model.layers.19.self_attn.qkv_proj.weight', 'model.layers.19.mlp.gate_up_proj.weight', 'model.layers.20.self_attn.qkv_proj.weight', 'model.layers.20.mlp.gate_up_proj.weight', 'model.layers.21.self_attn.qkv_proj.weight', 'model.layers.21.mlp.gate_up_proj.weight', 'model.layers.22.self_attn.qkv_proj.weight', 'model.layers.22.mlp.gate_up_proj.weight', 'model.layers.23.self_attn.qkv_proj.weight', 'model.layers.23.mlp.gate_up_proj.weight', 'model.layers.24.self_attn.qkv_proj.weight', 'model.layers.24.mlp.gate_up_proj.weight', 'model.layers.25.self_attn.qkv_proj.weight', 'model.layers.25.mlp.gate_up_proj.weight', 'model.layers.26.self_attn.qkv_proj.weight', 'model.layers.26.mlp.gate_up_proj.weight', 'model.layers.27.self_attn.qkv_proj.weight', 'model.layers.27.mlp.gate_up_proj.weight']
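
One plausible explanation for exactly these params failing under tp_size = 2 (and passing with no parallelism) is a layout mismatch: in vLLM, qkv_proj and gate_up_proj are fused column-parallel projections, so each TP rank holds its own shard of q, k, and v concatenated locally, which is not the same tensor as a contiguous slice of torch.cat([q, k, v], dim=0). A rough sketch of the difference, assuming plain column-parallel sharding with tp_size dividing every sub-projection:

```python
import torch

def fused_qkv_local_shard(q, k, v, tp_rank, tp_size):
    # What a TP rank actually holds for the fused qkv_proj: its slice of q,
    # its slice of k, and its slice of v, concatenated along dim 0.
    shard = lambda t: t.chunk(tp_size, dim=0)[tp_rank]
    return torch.cat([shard(q), shard(k), shard(v)], dim=0)

def naive_slice_of_full_concat(q, k, v, tp_rank, tp_size):
    # What a validation path gets if it concatenates the full tensors first and
    # only then slices: a contiguous block of [q; k; v], which mixes projections.
    return torch.cat([q, k, v], dim=0).chunk(tp_size, dim=0)[tp_rank]

q, k, v = (torch.randn(8, 4) for _ in range(3))
a = fused_qkv_local_shard(q, k, v, tp_rank=0, tp_size=2)
b = naive_slice_of_full_concat(q, k, v, tp_rank=0, tp_size=2)
print(torch.equal(a, b))  # False: the two layouts only agree when tp_size == 1
```

If the expected values in the test are built by concatenating the full q/k/v (and gate/up) tensors and then comparing against per-rank shards, only the fused params would diverge, which matches the failure list above.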

@casteryh marked this pull request as ready for review September 25, 2025 17:49
@casteryh merged commit e76f2a0 into meta-pytorch:main Sep 25, 2025
5 checks passed