-
Notifications
You must be signed in to change notification settings - Fork 2.2k
[TRTLLM-9493][feat] Custom AllToAll for helix parallelism #9986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| /* | ||
| * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include "tensorrt_llm/common/config.h" | ||
|
|
||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <cstddef> | ||
| #include <cstdint> | ||
|
|
||
| TRTLLM_NAMESPACE_BEGIN | ||
|
|
||
| namespace kernels | ||
| { | ||
|
|
||
| struct HelixFieldInfo | ||
| { | ||
| uint8_t* dataPtr; | ||
| int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for | ||
| // field 1) | ||
| int elementSize; // Size of each element in bytes (2 for half, 8 for float2) | ||
| int stride; // Stride between rows in bytes | ||
| }; | ||
|
|
||
| struct HelixAllToAllParams | ||
| { | ||
| HelixFieldInfo sendFields[2]; | ||
| HelixFieldInfo recvFields[2]; | ||
| int entryCount; // Number of entries per peer rank to process | ||
| uint64_t* workspace; | ||
| int workspaceStrideInU64; | ||
| int cpRank; | ||
| int cpSize; | ||
| int channelCount; // use 0 to auto-compute | ||
| int maxChannelCount; | ||
| }; | ||
|
|
||
| // ============================================================================ | ||
| // Workspace Management Functions | ||
| // ============================================================================ | ||
|
|
||
| /** | ||
| * Compute number of channels for communication based on cpSize. | ||
| * | ||
| * @param cpSize Number of context parallel ranks | ||
| * @param smCount Number of SMs available (0 = auto-detect) | ||
| * @return Number of channels to use | ||
| */ | ||
| int computeHelixMaxChannelCount(int cpSize, int smCount = 0); | ||
|
|
||
| /** | ||
| * Compute the workspace size required per rank for the all-to-all operation. | ||
| * | ||
| * @param cpSize Number of context parallel ranks | ||
| * @return Size in bytes | ||
| */ | ||
| size_t computeHelixWorkspaceSizePerRank(int cpSize); | ||
|
|
||
| /** | ||
| * Initialize workspace memory for a given rank. | ||
| * Should be called once during setup. | ||
| * | ||
| * @param workspace Pointer to workspace memory (per-rank view) | ||
| * @param cpSize Number of context parallel ranks | ||
| * @param stream CUDA stream for asynchronous operations | ||
| */ | ||
| void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream); | ||
|
|
||
| /** | ||
| * Launch the helix all-to-all kernel. | ||
| * | ||
| * @param params Kernel parameters including field info and workspace | ||
| * @param allowVariableField1 Whether to allow variable field 1 | ||
| * @param stream CUDA stream for kernel launch | ||
| */ | ||
| void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream); | ||
|
|
||
| } // namespace kernels | ||
|
|
||
| TRTLLM_NAMESPACE_END |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.