Commit 69574ad

[TRTLLM-5966][feat] Helix: extend mapping to support different CP types (NVIDIA#6816)
Signed-off-by: Matthias Jouanneaux <mjoux@nvidia.com>
1 parent 96339c6, commit 69574ad

File tree: 10 files changed, +134 -37 lines
examples/llm-api/star_attention.py
Lines changed: 2 additions & 1 deletion

@@ -7,6 +7,7 @@
 import torch

 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm.mapping import CpType
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig


@@ -59,7 +60,7 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
                              kv_cache_quant_algo=QuantAlgo.FP8 if fp8_kv_cache
                              else None) if fp8 else QuantConfig()
    cp_config = {
-        "cp_type": "star_attention",
+        "cp_type": CpType.STAR,
        "cp_anchor_size": args.sa_anchor_size,
        "block_size": args.sa_block_size
    }
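
For reference, a minimal sketch of the example's updated cp_config with the enum member in place of the raw string. The anchor and block sizes below are illustrative; the example itself reads them from CLI arguments (args.sa_anchor_size, args.sa_block_size).

```python
from tensorrt_llm.mapping import CpType

# Illustrative values; star_attention.py takes these from its CLI arguments.
cp_config = {
    "cp_type": CpType.STAR,
    "cp_anchor_size": 128,
    "block_size": 256,
}

# CpType is an IntEnum, so members can also be recovered by name or value if needed:
assert CpType["STAR"] is CpType.STAR
assert CpType(1) is CpType.STAR
```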

tensorrt_llm/_torch/pyexecutor/_util.py
Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 from tensorrt_llm.lora_helper import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules)
 from tensorrt_llm.lora_manager import load_torch_lora
-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import CpType, Mapping

 from ..model_config import ModelConfig
 from ..speculative import get_num_extra_kv_tokens, get_spec_decoder

@@ -589,7 +589,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
        mapping,
        max_seq_len=engine.max_seq_len,
        enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler)
-    if mapping.cp_config.get('cp_type') == 'star_attention':
+    if mapping.cp_config.get('cp_type') == CpType.STAR:
        assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
        return TorchSampler(sampler_args)
    if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
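
One reason every call site switches in the same commit: an IntEnum member compares equal to its integer value but never to the old string literal, so a check left comparing against 'star_attention' would silently stop matching. A self-contained sketch using a local stand-in for the enum defined in tensorrt_llm/mapping.py:

```python
from enum import IntEnum


class CpType(IntEnum):  # local stand-in mirroring tensorrt_llm.mapping.CpType
    ULYSSES = 0
    STAR = 1
    RING = 2
    HELIX = 3


assert CpType.STAR == 1                  # IntEnum members equal their integer value...
assert CpType.STAR != "star_attention"   # ...but never the old string spelling
assert CpType.STAR.name == "STAR"

# A config still carrying the string would no longer satisfy the new checks:
legacy_config = {"cp_type": "star_attention"}
assert legacy_config.get("cp_type") != CpType.STAR
```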

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
Lines changed: 3 additions & 2 deletions

@@ -11,6 +11,7 @@
 import torch

 from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.mapping import CpType

 from ..distributed import Distributed
 from .llm_request import (ExecutorRequest, LlmRequest,

@@ -569,9 +570,9 @@ def _merge_requests(
        cp_config = self.dist.cp_config
        if 'cp_type' in cp_config:
            cp_type = cp_config['cp_type']
-            if cp_type == 'star_attention':
+            if cp_type == CpType.STAR:
                return self._merge_star_attention_requests(new_requests)
-            elif cp_type == 'ring_attention':
+            elif cp_type == CpType.RING:
                raise NotImplementedError("ring attention not implemented yet")
            else:
                raise NotImplementedError(f'unsupport cp type {cp_type}')
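
The merge path now dispatches on the enum. A simplified, self-contained sketch of that shape; the STAR branch below just returns its input, standing in for the executor's _merge_star_attention_requests call:

```python
from tensorrt_llm.mapping import CpType


def merge_requests(cp_config: dict, new_requests: list) -> list:
    """Sketch of the post-change dispatch in _merge_requests (simplified)."""
    cp_type = cp_config.get("cp_type")
    if cp_type is None:
        return new_requests          # no context parallelism configured
    if cp_type == CpType.STAR:
        return new_requests          # stand-in for _merge_star_attention_requests(...)
    if cp_type == CpType.RING:
        raise NotImplementedError("ring attention not implemented yet")
    raise NotImplementedError(f"unsupported cp type {cp_type}")
```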

tensorrt_llm/_torch/pyexecutor/model_engine.py
Lines changed: 3 additions & 3 deletions

@@ -29,7 +29,7 @@
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import CpType, Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2

@@ -666,7 +666,7 @@ def release_batch(result: ScheduledRequests | None):

        # TODO: current warmup_request is not suitable for star attention
        cp_type = self.mapping.cp_config.get('cp_type', None)
-        if cp_type == 'star_attention':
+        if cp_type == CpType.STAR:
            return

        with contextlib.ExitStack() as stack:

@@ -2110,7 +2110,7 @@ def _prepare_inputs(
                        cache_indirection_buffer: Optional[torch.Tensor] = None):
        if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
            cp_type = self.mapping.cp_config['cp_type']
-            if 'star_attention' == cp_type:
+            if CpType.STAR == cp_type:
                return self._prepare_star_attention_inputs(
                    scheduled_requests, kv_cache_manager, attn_metadata)
            else:

tensorrt_llm/_torch/pyexecutor/py_executor.py
Lines changed: 2 additions & 1 deletion

@@ -31,6 +31,7 @@
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                           ReqIdsSet)
 from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT

 from ..distributed import Distributed

@@ -1460,7 +1461,7 @@ def _update_request_states(self, scheduled_requests: ScheduledRequests):
        cp_config = self.dist.cp_config
        if 'cp_type' in cp_config:
            cp_type = cp_config['cp_type']
-            if cp_type == 'star_attention':
+            if cp_type == CpType.STAR:
                self._update_request_states_star_attention(scheduled_requests)
            else:
                assert False, f'Unsupport cp_type {cp_type}'

tensorrt_llm/_torch/pyexecutor/resource_manager.py
Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@

 from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
 from ...logger import logger
-from ...mapping import Mapping
+from ...mapping import CpType, Mapping
 from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
                           get_draft_token_length)
 from .scheduler import ScheduledRequests

@@ -402,7 +402,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
        # allocate KV Cache
        for req in context_batch:
            req_beam_width = req.sampling_config.beam_width
-            if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
+            if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[
                    'cp_type']:
                if req.ctx_iters == 0:
                    seq_len = sum(

tensorrt_llm/mapping.py
Lines changed: 78 additions & 22 deletions

@@ -12,11 +12,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from enum import IntEnum
 from typing import List

 import torch


+class CpType(IntEnum):
+    # CP type for ulysses parallelism
+    ULYSSES = 0
+    # CP type for star attention
+    STAR = 1
+    # CP type for ring attention
+    RING = 2
+    # CP type for helix parallelism
+    HELIX = 3
+
+
 class Mapping(object):
     '''
     A node with 8 GPUs, tp_size = 4, cp_size = 1, pp_size = 2

@@ -135,58 +147,70 @@ def __init__(
        if moe_cluster_size == -1:
            moe_cluster_size = 1

+        cp_type = CpType.ULYSSES if cp_config is None else cp_config.get(
+            "cp_type", CpType.ULYSSES)
+        moe_world_size = tp_size if cp_type == CpType.ULYSSES else tp_size * cp_size
+
        if moe_tp_size == -1 and moe_ep_size == -1:
-            moe_tp_size = tp_size // moe_cluster_size
+            moe_tp_size = moe_world_size // moe_cluster_size
            moe_ep_size = 1

        elif moe_tp_size == -1:
-            moe_tp_size = tp_size // (moe_ep_size * moe_cluster_size)
+            moe_tp_size = moe_world_size // (moe_ep_size * moe_cluster_size)

        elif moe_ep_size == -1:
-            moe_ep_size = tp_size // (moe_tp_size * moe_cluster_size)
+            moe_ep_size = moe_world_size // (moe_tp_size * moe_cluster_size)

        if attn_tp_size == -1 and attn_cp_size == -1:
-            # fallback to ulysses
-            attn_tp_size = tp_size * cp_size
-            attn_cp_size = 1
+            if cp_type == CpType.ULYSSES:
+                # fallback to ulysses
+                attn_tp_size = tp_size * cp_size
+                attn_cp_size = 1
+            else:
+                # fallback to helix
+                attn_tp_size = tp_size
+                attn_cp_size = cp_size

        elif attn_tp_size == -1:
-            attn_tp_size = cp_size * tp_size // attn_cp_size
+            attn_tp_size = (tp_size * cp_size) // attn_cp_size

        elif attn_cp_size == -1:
-            attn_cp_size = cp_size * tp_size // attn_tp_size
+            attn_cp_size = (tp_size * cp_size) // attn_tp_size

-        if attn_cp_size != 1:
+        if attn_cp_size != 1 and cp_type == CpType.ULYSSES:
            raise ValueError(
-                f"attn_cp_size must be 1 for now, but got {attn_tp_size}, {attn_cp_size}."
+                f"attn_cp_size must be 1 for now for ulysses, but got {attn_tp_size}, {attn_cp_size}."
            )

        if auto_parallel:
-            if tp_size != 1 or pp_size != 1 or tp_size != 1:
+            if tp_size != 1 or pp_size != 1 or cp_size != 1:
                raise ValueError(
-                    f"When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, but got {tp_size}, {pp_size}, {cp_size}."
-                )
+                    "When auto parallel is enabled, tp_size, pp_size, cp_size must be 1, "
+                    f"but got {tp_size}, {pp_size}, {cp_size}.")
        else:
            if tp_size * pp_size * cp_size != world_size:
                raise ValueError(
-                    f"world_size must equal to tp_size * pp_size * cp_size, but got {world_size} != {tp_size} * {pp_size} * {cp_size}."
+                    "world_size must equal to tp_size * pp_size * cp_size, "
+                    f"but got {world_size} != {tp_size} * {pp_size} * {cp_size}."
                )

        moe_tp_ep_size = moe_tp_size * moe_ep_size
        moe_tp_cluster_ep_size = moe_tp_ep_size * moe_cluster_size
-        if moe_tp_cluster_ep_size != tp_size:
+        if moe_tp_cluster_ep_size != moe_world_size:
            raise ValueError(
-                f"tp_size must equal to moe_tp_size * moe_ep_size * moe_cluster_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size} * {moe_cluster_size}"
-            )
+                "moe_tp_size * moe_ep_size * moe_cluster_size must equal to moe_world_size, "
+                f"but got {moe_tp_cluster_ep_size} != {moe_world_size}")

        attn_tp_cp_size = attn_tp_size * attn_cp_size
        if attn_tp_cp_size != tp_size * cp_size:
            raise ValueError(
-                f"tp_size * cp_size must equal to attn_tp_size * attn_cp_size, but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}"
+                "tp_size * cp_size must equal to attn_tp_size * attn_cp_size, "
+                f"but got {tp_size} * {cp_size} != {attn_tp_size} * {attn_cp_size}"
            )

-        if moe_ep_size != 1 and cp_size > 1:
-            raise NotImplementedError("CP don't support MoE tp/ep yet")
+        if moe_ep_size != 1 and cp_size > 1 and cp_type != CpType.HELIX:
+            raise NotImplementedError(
+                f"CP {cp_type} doesn't support MoE tp/ep yet")

        self.tp_size = tp_size
        self.cp_size = cp_size

@@ -275,6 +299,7 @@ def __eq__(self, other):
                and self.moe_ep_size == other.moe_ep_size
                and self.attn_tp_size == other.attn_tp_size
                and self.attn_cp_size == other.attn_cp_size
+                and self.cp_config == other.cp_config
                and self.auto_parallel == other.auto_parallel)

    def __hash__(self):

@@ -290,6 +315,8 @@ def __hash__(self):
            self.moe_ep_size,
            self.attn_tp_size,
            self.attn_cp_size,
+            # note: we do not allow updating cp_config after initialization
+            tuple(sorted(self.cp_config.items())),
            self.auto_parallel,
        ))

@@ -376,8 +403,13 @@ def local_rank(self):
    def dp_size(self):
        return self.tp_size if self.enable_attention_dp else 1

-    def has_cp(self):
-        return self.cp_size > 1
+    def has_cp_ulysses(self):
+        return self.cp_size > 1 and self.cp_config.get(
+            "cp_type") == CpType.ULYSSES
+
+    def has_cp_helix(self):
+        return self.cp_size > 1 and self.cp_config.get(
+            "cp_type") == CpType.HELIX

    def get_node_rank(self, rank: int):
        return rank // self.gpus_per_node

@@ -415,6 +447,29 @@ def next_pp_rank(self):
            p = p - self.world_size
        return p

+    def is_last_cp_rank(self):
+        return self.cp_rank == self.cp_size - 1
+
+    def is_first_cp_rank(self):
+        return self.cp_rank == 0
+
+    def has_cp(self):
+        return self.cp_size > 1
+
+    def prev_cp_rank(self):
+        p = self.rank - self.tp_size
+        if p // (self.tp_size * self.cp_size) < self.rank // (self.tp_size *
+                                                              self.cp_size):
+            return p + self.tp_size * self.cp_size
+        return p
+
+    def next_cp_rank(self):
+        p = self.rank + self.tp_size
+        if p // (self.tp_size * self.cp_size) > self.rank // (self.tp_size *
+                                                              self.cp_size):
+            return p - self.tp_size * self.cp_size
+        return p
+
    def has_moe_cluster(self):
        return self.moe_cluster_size > 1

@@ -453,5 +508,6 @@ def to_dict(self):
            'moe_ep_size': self.moe_ep_size,
            'attn_tp_size': self.attn_tp_size,
            'attn_cp_size': self.attn_cp_size,
+            'cp_config': self.cp_config,
            'auto_parallel': self.auto_parallel,
        }
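
To make the new behaviour concrete, here is a small usage sketch built only from what this diff and the tests below show. It assumes a hypothetical 8-GPU helix layout (tp_size=2, pp_size=1, cp_size=4) passed via cp_config; the attention fallback then keeps the tp/cp split, and CP neighbours sit tp_size apart with wrap-around inside each tp_size * cp_size block.

```python
from tensorrt_llm.mapping import CpType, Mapping

# Hypothetical layout: 8 ranks, tp_size=2, pp_size=1, cp_size=4, helix-style CP.
m = Mapping(world_size=8, rank=3, tp_size=2, pp_size=1, cp_size=4,
            cp_config={"cp_type": CpType.HELIX})

# Helix fallback keeps attention split across tp and cp instead of folding CP into TP.
assert (m.attn_tp_size, m.attn_cp_size) == (2, 4)
assert m.has_cp_helix() and not m.has_cp_ulysses()

# Rank 3 sits at cp_rank 1 of cp_group [1, 3, 5, 7]; neighbours are tp_size apart.
assert m.cp_group == [1, 3, 5, 7]
assert (m.prev_cp_rank(), m.next_cp_rank()) == (1, 5)
assert not m.is_first_cp_rank() and not m.is_last_cp_rank()
```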

tests/unittest/_torch/multi_gpu/test_star_attention.py
Lines changed: 2 additions & 1 deletion

@@ -8,6 +8,7 @@
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
+from tensorrt_llm.mapping import CpType
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

 MAX_SEQ_LEN = 4096 + 1024

@@ -54,7 +55,7 @@ def test_model(backend, model_name, quant, sp_size, sa_block_size,

    model_dir = str(llm_models_root() / model_name)
    cp_config = {
-        "cp_type": "star_attention",
+        "cp_type": CpType.STAR,
        "cp_anchor_size": sa_anchor_size,
        "block_size": sa_block_size
    }

tests/unittest/_torch/test_flashinfer_star_attn.py
Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@
 from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
-from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.mapping import CpType, Mapping


 class TestingStarAttentionMetadata(StarAttentionMetadata):

@@ -144,7 +144,7 @@ def test_flashinfer_star_attention(self, scenario: Scenario):
        tokens_per_block = 64
        max_seq_len = tokens_per_block * num_blocks
        cp_config = {
-            "cp_type": "star_attention",
+            "cp_type": CpType.STAR,
            "cp_anchor_size": scenario.anchor_size,
            "block_size": scenario.block_size
        }

@@ -579,7 +579,7 @@ def test_attention_with_cuda_graphs(
        max_seq_len = tokens_per_block * num_blocks
        num_layers = 1 if isinstance(num_kv_heads, int) else len(num_kv_heads)
        cp_config = {
-            "cp_type": "star_attention",
+            "cp_type": CpType.STAR,
            "cp_anchor_size": test_scenario.anchor_size,
            "block_size": test_scenario.block_size
        }

tests/unittest/others/test_mapping.py
Lines changed: 37 additions & 0 deletions

@@ -44,3 +44,40 @@ def test_mapping(self):
        self.assertTrue(m.is_last_pp_rank())
        self.assertEqual(m.prev_pp_rank(), 4)
        self.assertEqual(m.next_pp_rank(), 0)
+
+        m = Mapping(world_size=2, rank=0, cp_size=2)
+        self.assertEqual(len(m.tp_groups), 2)
+        self.assertEqual(len(m.pp_groups), 2)
+        self.assertEqual(len(m.cp_groups), 1)
+        self.assertEqual(m.tp_group, [0])
+        self.assertEqual(m.pp_group, [0])
+        self.assertEqual(m.cp_group, [0, 1])
+
+        m = Mapping(world_size=8, rank=3, tp_size=2, pp_size=2, cp_size=2)
+        self.assertEqual(len(m.tp_groups), 4)
+        self.assertEqual(len(m.pp_groups), 4)
+        self.assertEqual(len(m.cp_groups), 4)
+        self.assertEqual(m.tp_group, [2, 3])
+        self.assertEqual(m.pp_group, [3, 7])
+        self.assertEqual(m.cp_group, [1, 3])
+        self.assertTrue(m.is_first_pp_rank())
+        self.assertFalse(m.is_last_pp_rank())
+        self.assertFalse(m.is_first_cp_rank())
+        self.assertTrue(m.is_last_cp_rank())
+        self.assertEqual(m.prev_pp_rank(), 7)
+        self.assertEqual(m.next_pp_rank(), 7)
+        self.assertEqual(m.prev_cp_rank(), 1)
+        self.assertEqual(m.next_cp_rank(), 1)
+
+        m = Mapping(world_size=16, rank=9, tp_size=2, pp_size=2, cp_size=4)
+        self.assertEqual(m.tp_group, [8, 9])
+        self.assertEqual(m.pp_group, [1, 9])
+        self.assertEqual(m.cp_group, [9, 11, 13, 15])
+        self.assertFalse(m.is_first_pp_rank())
+        self.assertTrue(m.is_last_pp_rank())
+        self.assertTrue(m.is_first_cp_rank())
+        self.assertFalse(m.is_last_cp_rank())
+        self.assertEqual(m.prev_pp_rank(), 1)
+        self.assertEqual(m.next_pp_rank(), 1)
+        self.assertEqual(m.prev_cp_rank(), 15)
+        self.assertEqual(m.next_cp_rank(), 11)
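
The world_size=16 case can be checked by hand with the formula from prev_cp_rank/next_cp_rank: ranks are laid out tp-innermost, so CP neighbours sit tp_size apart and wrap within the tp_size * cp_size block owned by one pp rank. A short sketch of that arithmetic (not library code):

```python
tp, cp, rank = 2, 4, 9
block = tp * cp                       # ranks 8..15 all belong to pp_rank 1

prev = rank - tp                      # 7 falls into the previous block...
if prev // block < rank // block:
    prev += block                     # ...so wrap around to 15
assert prev == 15                     # matches m.prev_cp_rank() above

nxt = rank + tp                       # 11 stays inside the same block
if nxt // block > rank // block:
    nxt -= block
assert nxt == 11                      # matches m.next_cp_rank() above
```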
