Skip to content

[Feature] adapt step3 model with AFD#18

Merged
jiangkuaixue123 merged 13 commits intoafd-p2p-dbo-rebase2from
afd-step3-merge
Dec 22, 2025
Merged

[Feature] adapt step3 model with AFD#18
jiangkuaixue123 merged 13 commits intoafd-p2p-dbo-rebase2from
afd-step3-merge

Conversation

@InhabitancyCocoon
Copy link
Collaborator

@InhabitancyCocoon InhabitancyCocoon commented Dec 18, 2025

Purpose

adapt afd feature to step3 model.

Warning

This PR changes the parameters' order of compute_ffn_output method, which may be a breaking change.

Test Plan

  • Requires a mini step3 which fits in single H800 / H100. --load_format dummy can be helpful.

  • Make sure your CUDA_VISIBLE_DEVICES is set properly.

  • Commands are as follow. Notice the afd_size param.

attn dp 2, dbo enabled

vllm serve /path/to/your/step3 --dtype bfloat16 --data_parallel_size=2 --enable_expert_parallel --enforce_eager --enable-dbo --dbo-prefill-token-threshold 12 --dbo-decode-token-threshold 2 --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"2","afd_extra_config":{"afd_size":"2A2F"}}'

attn tp / ep 2, dbo enabled

vllm serve /path/to/your/step3 --dtype bfloat16 --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --enable-dbo --dbo-prefill-token-threshold 12 --dbo-decode-token-threshold 2 --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"2","afd_extra_config":{"afd_size":"2A2F"}}'

ffn dp 2

vllm serve /path/to/your/step3 --dtype bfloat16  --data_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"2", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}'

ffn tp / ep 2

vllm serve /path/to/your/step3  --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"2", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}'

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

):
logger.info(f"input_ids: {input_ids.shape}")
if inputs_embeds:
if input_ids is not None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里有bug,inputs_embeds是张量,不能直接用做if判断,顺便加了inputs_ids的判断
@jiangkuaixue123

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯 这个本来就是冗余代码~

@@ -229,7 +231,7 @@ def _execute_eager_mode(
else:
# Single TP case
rank_ffn_output = self.model.compute_ffn_output(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同修改参数顺序,统一是hidden_states在前,layer_idx在后,之前有的顺序是反着的。
@jiangkuaixue123

@InhabitancyCocoon InhabitancyCocoon marked this pull request as draft December 18, 2025 09:10
@InhabitancyCocoon InhabitancyCocoon changed the title Afd step3 merge [Feature] adapt step3 model with AFD Dec 18, 2025

return hidden_states, residual

def compute_attn_output(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jiangkuaixue123

dsv2里也有compute_attn_output,不过这个method看起来根本没用到,我们要删了吗?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没用的话就删了吧

torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size,
device="cpu", dtype=torch.int32),
)
logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个为啥不需要了

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好像是误删了,跑起来的时候似乎没影响,我改回来。再验证一下。

)
self._current_afd_connector_metadata.recv_handle_list = work_list
self._current_afd_connector_metadata.layer_idx = layer_idx
self._current_afd_connector_metadata.stage_idx = stage_idx
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还有这个

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同误删。


return hidden_states, residual

def compute_attn_output(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没用的话就删了吧

positions: torch.Tensor,
afd_metadata: AFDMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
recv_handle = None
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上次那个改动合入了 这个forward可能要改成上次视频通话的那种形式

@InhabitancyCocoon InhabitancyCocoon marked this pull request as ready for review December 22, 2025 07:59
@InhabitancyCocoon InhabitancyCocoon marked this pull request as draft December 22, 2025 08:44
@InhabitancyCocoon InhabitancyCocoon marked this pull request as ready for review December 22, 2025 09:50
@jiangkuaixue123 jiangkuaixue123 merged commit 93c656e into afd-p2p-dbo-rebase2 Dec 22, 2025
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants