-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathdisaggregated_params.py
More file actions
124 lines (110 loc) · 5.68 KB
/
disaggregated_params.py
File metadata and controls
124 lines (110 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional
import numpy as np
# isort: off
# needed before trying to import bindings to load tensorrt_libs
import tensorrt as trt # noqa
# isort: on
from tensorrt_llm.bindings import executor as tllme
class DisaggScheduleStyle(IntEnum):
    """Scheduling preference for a disaggregated request: which phase goes first."""

    # Integer values are part of the serialized contract between instances; keep stable.
    CONTEXT_FIRST = 0
    GENERATION_FIRST = 1
@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
    """Disaggregated serving parameters.

    Args:
        request_type (str): The type of request ("context_only" | "generation_only" | "context_and_generation")
        first_gen_tokens (List[int]): The first tokens of the generation request
        ctx_request_id (int): The context request id
        opaque_state(bytes): Any additional state needing to be exchanged between context and gen instances
        draft_tokens (List[int]): The draft tokens of the generation request
        disagg_request_id (int): The disaggregated request id, if set, both context and generation requests will use it
            as underlying request id.
        first_gen_log_probs (List): The logprobs for first_gen_tokens, produced during prefill.
            Each entry is a list (one per beam) of TokenLogprobs (list of dict[int, Logprob]).
        first_gen_logits (List): The generation logits for first_gen_tokens, produced during prefill.
            Each entry is a torch.Tensor of shape [num_tokens, vocab_size] (one per beam/sequence).
        multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT.
        multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.
    """

    # Single source of truth for valid request types; shared by __post_init__
    # validation and get_request_type (previously the list was duplicated).
    _VALID_REQUEST_TYPES = ("context_only", "generation_only", "context_and_generation")

    request_type: Optional[str] = None
    # P-D Disaggregated Params
    first_gen_tokens: Optional[List[int]] = None
    first_gen_log_probs: Optional[List] = None
    first_gen_logits: Optional[List] = None
    ctx_request_id: Optional[int] = None
    opaque_state: Optional[bytes] = None
    draft_tokens: Optional[List[int]] = None
    # If disagg_request_id is set, both context and generation requests will use it as underlying request id.
    disagg_request_id: Optional[int] = None
    ctx_dp_rank: Optional[int] = None
    ctx_info_endpoint: Optional[str] = None
    schedule_style: Optional[DisaggScheduleStyle] = None
    # E-P Disaggregated Params
    multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
        None  # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding
    )
    multimodal_hashes: Optional[List[List[int]]] = (
        None  # user provided mm hashes should be a list of 8 integers
    )
    mrope_position_ids_handle: Optional[Dict[str, Any]] = None
    mrope_position_deltas_handle: Optional[Dict[str, Any]] = None

    def get_context_phase_params(self) -> tllme.ContextPhaseParams:
        """Build the executor ContextPhaseParams handed from the context to the gen instance."""
        # Prefer disagg_request_id over ctx_request_id
        request_id = (
            self.disagg_request_id if self.disagg_request_id is not None else self.ctx_request_id
        )
        # `first_gen_tokens` is now required by bindings and cannot be None.
        first_gen_tokens = self.first_gen_tokens if self.first_gen_tokens is not None else []
        return tllme.ContextPhaseParams(
            first_gen_tokens,
            request_id,
            self.opaque_state,
            self.draft_tokens,
            self.ctx_dp_rank,
            self.ctx_info_endpoint,
        )

    def get_request_type(self) -> tllme.RequestType:
        """Map the string `request_type` to the executor RequestType enum.

        Raises:
            ValueError: if `request_type` is not one of the known values.
        """
        mapping = {
            "context_only": tllme.RequestType.REQUEST_TYPE_CONTEXT_ONLY,
            "generation_only": tllme.RequestType.REQUEST_TYPE_GENERATION_ONLY,
            "context_and_generation": tllme.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION,
        }
        try:
            return mapping[self.request_type]
        except KeyError:
            raise ValueError(
                f"Unknown request type: {self.request_type}. Must be context_only, generation_only or "
                "context_and_generation"
            ) from None

    def __post_init__(self):
        """Normalize `request_type` and validate the multimodal fields.

        Raises:
            ValueError: on an unknown request type or malformed multimodal inputs.
        """
        if self.request_type is not None:
            self.request_type = self.request_type.lower()
            if self.request_type not in self._VALID_REQUEST_TYPES:
                raise ValueError(
                    f"Unknown request type: {self.request_type}. Must be context_only, generation_only or "
                    "context_and_generation"
                )
        # Validation below uses explicit raises (not `assert`) so the checks
        # survive `python -O`, which strips assert statements.
        if self.multimodal_embedding_handles is not None:
            if self.multimodal_hashes is not None:
                # if mm hashes are provided, kvcache reuse can be enabled;
                # hashes must pair one-to-one with embedding handles
                if len(self.multimodal_embedding_handles) != len(self.multimodal_hashes):
                    raise ValueError(
                        "multimodal_embedding_handles and multimodal_hashes must have the same length"
                    )
                for mm_hash in self.multimodal_hashes:
                    if not isinstance(mm_hash, list):
                        raise ValueError("mm_hash must be a list")
                    if len(mm_hash) != 8:
                        raise ValueError("mm_hash must be a list of 8 integers")
                    if not all(isinstance(x, int) for x in mm_hash):
                        raise ValueError("mm_hash must contain integers")
            else:
                # if user did not provide mm hashes, synthesize random ones so every
                # handle still has a hash; kvcache reuse will be disabled
                if len(self.multimodal_embedding_handles) == 0:
                    raise ValueError("multimodal_embedding_handles must be provided")
                vals = np.random.randint(
                    np.iinfo(np.int32).min, np.iinfo(np.int32).max, size=8, dtype=np.int32
                ).tolist()
                # NOTE(review): every handle shares the same random hash list here —
                # presumably fine because reuse is disabled; confirm if hashes must be distinct.
                self.multimodal_hashes = [vals] * len(self.multimodal_embedding_handles)