integration test for weight sync that actually tests behavior #217
Conversation
Changed to draft status: either weight loading is buggy or the test is buggy.
For tp_size=2, the following is failing.

For no parallelism at all, only
2. Initializes RLTrainer and makes the weights available in torchstore.
3. Initializes Policy, and calls update_weights() to load weights from torchstore.
4. Validates the policy weights against the source of truth.
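The steps above can be sketched end to end with a plain dict standing in for torchstore; the class and method names below are illustrative only, not the PR's actual API.

```python
# Toy version of the test flow: publish trainer weights to a store,
# have the policy pull them, then validate against the source of truth.
class Trainer:
    """Stands in for RLTrainer (step 2: publish weights to the store)."""
    def __init__(self):
        self.weights = {"layer.weight": [0.1, 0.2]}

    def push(self, store):
        store.update(self.weights)

class Policy:
    """Stands in for Policy (step 3: load weights from the store)."""
    def __init__(self):
        self.weights = {"layer.weight": [0.0, 0.0]}

    def update_weights(self, store):
        self.weights = {name: list(vals) for name, vals in store.items()}

store = {}
trainer = Trainer()
trainer.push(store)            # step 2: weights available in the store
policy = Policy()
policy.update_weights(store)   # step 3: update_weights() pulls from the store
assert policy.weights == trainer.weights  # step 4: validate against source of truth
```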
Looks like because the new policy version has the same weights, we won't know for sure whether the "update" logic is correct or not.
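One way to address this concern (a hypothetical sketch, not the PR's code) is to perturb the trainer-side weights by a known constant before publishing, so that a correct update produces values distinguishable from the policy's current weights and a silently skipped load fails validation.

```python
# Hypothetical helper: scale every value in a (flattened) state dict by a
# known constant, so "updated" weights cannot equal the originals.
def perturb_state_dict(state_dict, scale=1.5):
    """Return a copy with every value scaled by `scale`."""
    return {name: [v * scale for v in vals] for name, vals in state_dict.items()}
```

Validation can then check for the scaled values rather than the originals, catching a no-op "update".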
```python
async def _test_validate_model_params(self, validate_fn):
    """Validate updated model params using validate_fn."""
    logger.info("[Policy] start validating model parameters post update")
    return await self.policy_worker._test_validate_model_params.call(validate_fn)
```
This is really smart! Having the validate_fn as an argument greatly simplified the logic of validating the weights on each worker.
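The pattern being praised can be sketched as follows (names hypothetical): each worker runs a caller-supplied validate_fn over its own params and collects failures instead of stopping at the first mismatch.

```python
# Sketch: aggregate validate_fn failures across all named params.
def validate_model_params(worker_params, validate_fn):
    failures = []
    for name, param in worker_params.items():
        try:
            validate_fn(name, param)
        except AssertionError as exc:
            failures.append((name, exc))
    return failures

def non_empty(name, param):
    # Example validate_fn: every named param must carry some values.
    assert len(param) > 0, f"{name} is empty"

failures = validate_model_params(
    {"norm.weight": [1.0], "embed_tokens.weight": []}, non_empty
)
```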
```python
    "post_attention_layernorm",
    "o_proj",
    "norm.weight",
    "embed_tokens.weight",
```
This is weird -- in the single policy worker test case, model.embed_tokens.weight should fall into this case.
```
AssertionError: Validation failed with exception: Validation failed: [('model.embed_tokens.weight', AssertionError('current param model.embed_tokens.weight does not match expected value;
previous param (torch.Size([151936, 2048])) = tensor([[-0.0127,  0.0195,  0.0117,  ...,  0.0157, -0.0469, -0.0013],
        [ 0.0249, -0.0055, -0.0674,  ...,  0.0037,  0.0403, -0.0165],
        [-0.0173, -0.0284, -0.0530,  ..., -0.0305, -0.0074,  0.0649],
        ...,
        [ 0.0060,  0.0131,  0.0190,  ...,  0.0020, -0.0014, -0.0055],
        [ 0.0060,  0.0131,  0.0190,  ...,  0.0020, -0.0014, -0.0055],
        [ 0.0060,  0.0131,  0.0190,  ...,  0.0020, -0.0014, -0.0055]],
       dtype=torch.bfloat16);
expected = tensor([[-0.0190,  0.0293,  0.0176,  ...,  0.0237, -0.0703, -0.0020],
        [ 0.0374, -0.0083, -0.1011,  ...,  0.0056,  0.0605, -0.0247],
        [-0.0260, -0.0427, -0.0796,  ..., -0.0459, -0.0111,  0.0977],
        ...,
        [ 0.0090,  0.0197,  0.0286,  ...,  0.0031, -0.0021, -0.0082],
        [ 0.0090,  0.0197,  0.0286,  ...,  0.0031, -0.0021, -0.0082],
        [ 0.0090,  0.0197,  0.0286,  ...,  0.0031, -0.0021, -0.0082]],
       dtype=torch.bfloat16)
vs got = torch.Size([151936, 2048]) tensor([[-0.0286,  0.0439,  0.0264,  ...,  0.0354, -0.1055, -0.0030],
        [ 0.0562, -0.0125, -0.1514,  ...,  0.0083,  0.0908, -0.0371],
        [-0.0391, -0.0640, -0.1191,  ..., -0.0688, -0.0166,  0.1465],
        ...,
        [ 0.0136,  0.0295,  0.0430,  ...,  0.0046, -0.0032, -0.0123],
        [ 0.0136,  0.0295,  0.0430,  ...,  0.0046, -0.0032, -0.0123],
        [ 0.0136,  0.0295,  0.0430,  ...,  0.0046, -0.0032, -0.0123]],
       dtype=torch.bfloat16)'))]
```
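A quick arithmetic check on the first-row values copied from that log suggests the three tensors differ by a constant factor of 1.5 (got ≈ 1.5 × expected ≈ 1.5² × previous), which would be consistent with a multiplicative update being applied one more time than the validation accounts for.

```python
# Values copied from the first row of the assertion message above.
previous = [-0.0127, 0.0195, 0.0117]
expected = [-0.0190, 0.0293, 0.0176]
got = [-0.0286, 0.0439, 0.0264]

# Each tensor looks like the previous one scaled by 1.5 (within bf16 rounding).
for p, e, g in zip(previous, expected, got):
    assert abs(e - 1.5 * p) < 1e-3
    assert abs(g - 1.5 * e) < 1e-3
```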
```python
q = saved_sd[prefix + "self_attn.q_proj.weight"]
k = saved_sd[prefix + "self_attn.k_proj.weight"]
v = saved_sd[prefix + "self_attn.v_proj.weight"]
load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
```
These also appeared in the error log https://www.internalfb.com/intern/everpaste/?handle=GFQsCCF1lpC-jBUEAI4XK_T7_DM8bsIXAAAB&phabricator_paste_number=1959792717
```
'model.layers.0.self_attn.qkv_proj.weight', 'model.layers.0.mlp.gate_up_proj.weight', 'model.layers.1.self_attn.qkv_proj.weight', 'model.layers.1.mlp.gate_up_proj.weight',
```
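For context, these fused keys (qkv_proj, gate_up_proj) come from the vLLM-side model keeping single concatenated projection weights, which is why the diff rebuilds qkv_proj from the three separate HF projections with torch.cat. A plain-Python mirror of that mapping (list concatenation standing in for row-wise torch.cat; key names taken from the diff):

```python
# Rebuild the fused qkv_proj weight from separate q/k/v projection weights,
# as torch.cat([q, k, v], dim=0) does in the diff above.
def fuse_qkv(saved_sd, prefix):
    q = saved_sd[prefix + "self_attn.q_proj.weight"]
    k = saved_sd[prefix + "self_attn.k_proj.weight"]
    v = saved_sd[prefix + "self_attn.v_proj.weight"]
    return q + k + v  # rows of q, then k, then v

saved = {
    "model.layers.0.self_attn.q_proj.weight": [[1.0]],
    "model.layers.0.self_attn.k_proj.weight": [[2.0]],
    "model.layers.0.self_attn.v_proj.weight": [[3.0]],
}
fused = fuse_qkv(saved, "model.layers.0.")
```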
Fixed the bug in DTensor multiplication. All tests passing now for single worker.
Previously, the test was bogus because