Skip to content

Commit 0a902a9

Browse files
authored
[megatron] fix: patch supports newer mcore versions (verl-project#5372)
### What does this PR do? Patch support newer mcore version Tested on NVIDIA/Megatron-LM@bbbedbb ### Checklist Before Starting - [X] Search for similar PRs. Paste at least one query link here: ... - [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). 
- [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) - [X] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`. Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent 5ab49ec commit 0a902a9

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

verl/models/mcore/patch.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,19 @@ def patch_forward(
258258
# Get the query, key and value tensors based on the type of attention -
259259
# self or cross attn.
260260
# query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
261-
query, key, value = self.get_query_key_value_tensors(
261+
qkv = self.get_query_key_value_tensors(
262262
hidden_states,
263263
key_value_states,
264264
position_ids,
265265
packed_seq_params,
266266
inference_context=inference_context,
267267
)
268+
query, key, value = qkv[:3]
269+
q_compressed = None
270+
# kv_compressed = None
271+
if len(qkv) > 4:
272+
q_compressed = qkv[3]
273+
# kv_compressed = qkv[4]
268274

269275
# ===================================================
270276
# Adjust key, value for inference
@@ -288,25 +294,36 @@ def patch_forward(
288294
# core attention computation
289295
# ==================================
290296
# Need corresponding TE change
291-
thd_qkv_format = packed_seq_params and packed_seq_params.qkv_format == "thd"
297+
non_dsa_thd_qkv_format = (
298+
packed_seq_params
299+
and packed_seq_params.qkv_format == "thd"
300+
and getattr(self.config, "experimental_attention_variant", None) is None
301+
)
292302
v_dim = value.shape[-1]
293-
if thd_qkv_format and query.shape[-1] != v_dim:
303+
if non_dsa_thd_qkv_format and query.shape[-1] != v_dim:
294304
value = F.pad(value, [0, query.shape[-1] - v_dim])
295305
self.core_attention.hidden_size_per_attention_head_v = value.shape[-1]
296306
if self.checkpoint_core_attention and self.training:
297307
core_attn_out = self._checkpointed_attention_forward(
298308
query, key, value, attention_mask, packed_seq_params=packed_seq_params
299309
)
300310
else:
311+
extra_kwargs = {}
312+
if getattr(self.config, "experimental_attention_variant", None) == "dsa":
313+
# For dsa we need to pass in the original hidden states and the compressed
314+
# query representation.
315+
extra_kwargs["x"] = hidden_states
316+
extra_kwargs["qr"] = q_compressed
301317
core_attn_out = self.core_attention(
302318
query,
303319
key,
304320
value,
305321
attention_mask,
306322
packed_seq_params=packed_seq_params,
307323
attn_mask_type=attn_mask_type,
324+
**extra_kwargs,
308325
)
309-
if thd_qkv_format:
326+
if non_dsa_thd_qkv_format:
310327
if core_attn_out.ndim == 2:
311328
core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-1], -1, value.shape[-1])
312329
if query.shape[-1] != v_dim:
@@ -329,7 +346,11 @@ def patch_forward(
329346

330347
return output, bias
331348

332-
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
349+
# This patch targets mcore 0.12 MLA behavior only.
350+
# For newer mcore, upstream MLA already has packed-seq + CP handling and
351+
# overriding it with the legacy implementation can break RoPE shapes.
352+
if not mcore_ge_013:
353+
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors
333354

334355
MultiLatentAttention.forward = patch_forward
335356

0 commit comments

Comments (0)