forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
365 lines (275 loc) · 10.9 KB
/
utils.py
File metadata and controls
365 lines (275 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import contextlib
import functools
import threading
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Dict, List

import torch

from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils
# Process-wide flag: True while torch.compile is compiling (see set_torch_compiling).
is_torch_compiling_flag = False
# Process-wide flag: True while running in piecewise CUDA-graph mode
# (see set_piecewise_running).
is_piecewise_running_flag = False
# Names of the auxiliary CUDA streams used alongside the main stream.
aux_stream_name_list = [
    'Attention',
    'MoeShared',
    'MoeChunkingOverlap',
    'MoeBalancer',
]
# One enum member per auxiliary stream (functional Enum API; values start at 1).
AuxStreamType = Enum(
    'AuxStreamType',
    aux_stream_name_list,
)
# Event types: 'Main' plus one per auxiliary stream; start=0 so Main == 0.
EventType = Enum(
    'EventType',
    ['Main', *aux_stream_name_list],
    start=0,
)
# IMPORTANT: Keep the same order of activation functions in this enum and the enum in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
class ActivationType(IntEnum):
    """Activation kinds for the fused kernels.

    The integer values are part of the Python/C++ ABI: they must match the
    C++ enum referenced in the note above, so members must never be
    reordered or renumbered.
    """
    InvalidType = 0
    Identity = 1
    Gelu = 2
    Relu = 3
    Silu = 4
    Swiglu = 5
    Geglu = 6
    SwigluBias = 7
    Relu2 = 8
def set_torch_compiling(enable: bool):
    """Record whether torch.compile is currently active for this process."""
    global is_torch_compiling_flag
    is_torch_compiling_flag = enable
def is_torch_compiling() -> bool:
    """Return True while torch.compile is active (see set_torch_compiling)."""
    # Reading a module global needs no `global` declaration.
    return is_torch_compiling_flag
def set_piecewise_running(enable: bool):
    """Record whether piecewise CUDA-graph execution is currently active."""
    global is_piecewise_running_flag
    is_piecewise_running_flag = enable
def is_piecewise_running() -> bool:
    """Return True while piecewise running is active (see set_piecewise_running)."""
    # Reading a module global needs no `global` declaration.
    return is_piecewise_running_flag
# Thread-local attribute store shared across this module's helpers.
_global_attrs = threading.local()


def get_global_attrs():
    """Return the module-wide thread-local attribute store."""
    return _global_attrs
# Thread-local holder for the currently-installed model extra attrs.
_model_extra_attrs = threading.local()


def get_model_extra_attrs():
    """Return the current thread's model extra-attrs dict, or None if unset."""
    return getattr(_model_extra_attrs, 'attrs', None)
@contextlib.contextmanager
def model_extra_attrs(attrs: Dict):
    """Temporarily install `attrs` as this thread's model extra attributes.

    The previous value (possibly None) is restored on exit, even if the
    body raises.
    """
    saved = getattr(_model_extra_attrs, 'attrs', None)
    _model_extra_attrs.attrs = attrs
    try:
        yield
    finally:
        _model_extra_attrs.attrs = saved
def with_model_extra_attrs(get_attrs):
    """Decorator factory: run the wrapped method inside model_extra_attrs().

    Args:
        get_attrs: Callable mapping `self` to the attrs dict to install for
            the duration of the call.

    Returns:
        A decorator that wraps a method so its body executes with the model
        extra attrs installed for the current thread.
    """

    def decorator(func):

        # Fix: preserve the wrapped function's name/doc/signature metadata,
        # which the original wrapper silently discarded.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            with model_extra_attrs(get_attrs(self)):
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
def make_weak_ref(x):
    """Build a non-owning (weak) reference to `x`.

    CUDA tensors are rewrapped around their raw data pointer so the result
    does not keep the underlying storage alive. Containers are converted
    recursively; CPU tensors and plain scalars pass through unchanged.

    Raises:
        TypeError: if `x` is of an unsupported type.
    """
    if isinstance(x, torch.Tensor):
        if not x.is_cuda:
            return x
        # Non-owning view over the same device memory.
        return convert_to_torch_tensor(
            TensorWrapper(x.data_ptr(), x.dtype, x.shape, x.stride()))
    if isinstance(x, tuple):
        return tuple(make_weak_ref(item) for item in x)
    if isinstance(x, list):
        return [make_weak_ref(item) for item in x]
    if isinstance(x, dict):
        return {key: make_weak_ref(value) for key, value in x.items()}
    if isinstance(x, (int, float, bool)):
        return x
    raise TypeError(f"Invalid type {type(x)} to make weak ref")
@dataclass
class Fp4QuantizedTensor:
    """Pairs an FP4-quantized tensor with its block scaling factors."""
    # Packed FP4 payload.
    fp4_tensor: torch.Tensor
    # Scaling factors for fp4_tensor.
    scaling_factor: torch.Tensor
    # Whether scaling_factor is already in the swizzled (kernel) layout.
    is_sf_swizzled: bool = True

    @property
    def shape(self):
        # Mirror the payload's shape so callers can treat this like a tensor.
        return self.fp4_tensor.shape
def compute_swizzled_sf_shape(row: int, col: int):
    """Pad (row, col) up to the swizzled scaling-factor layout granularity.

    Rows are padded to a multiple of 128 and columns to a multiple of 4.
    """
    return pad_up(row, 128), pad_up(col, 4)
def swizzle_sf(sf: torch.Tensor,
               rows: int,
               cols: int,
               scaling_vector_size: int = 16):
    """Swizzle FP4 scaling factors using C++ torch op implementation

    Args:
        sf: [b, rows, cols_sf] or [rows, cols_sf]. The original unswizzled scaling factors.
        rows: rows of the original unquantized tensor
        cols: number of columns of the original unquantized tensor. The
            scaling-factor column count is derived as
            cols_sf = ceil_div(cols, scaling_vector_size).
        scaling_vector_size: the size of the scaling vector

    Returns:
        [b * pad_up(rows, 128) * pad_up(cols_sf, 4), ] 1D swizzled scaling factors, possibly with rows and cols padded.
    """
    # Normalize to (batch, rows, cols_sf) before calling the interleave op.
    sf_cols = ceil_div(cols, scaling_vector_size)
    sf = sf.view(-1, rows, sf_cols)
    return torch.ops.trtllm.block_scale_interleave(sf)
def unswizzle_sf(sf: torch.Tensor,
                 rows: int,
                 cols: int,
                 scaling_vector_size: int = 16):
    """Unswizzle FP4 scaling factors using C++ torch op implementation

    Args:
        sf: The (padded and) swizzled scaling factors.
        rows: rows of the original unquantized tensor
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector

    Returns:
        2D unswizzled scaling factors of shape [-1, ceil_div(cols, scaling_vector_size)]
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    # Normalize to (batch, rows, cols_sf) before calling the reverse op.
    sf = sf.view(-1, rows, sf_cols)
    return torch.ops.trtllm.block_scale_interleave_reverse(sf).view(-1, sf_cols)
@torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=())
def reswizzle_sf(sf: torch.Tensor,
                 rows: int,
                 cols: int,
                 scaling_vector_size: int = 16) -> torch.Tensor:
    """Reswizzle FP4 scaling factors using C++ torch op implementation.

    It unswizzles the scaling factors in each partition first, then
    concatenates them together, and finally swizzles them back.

    Args:
        sf: The (padded and) swizzled scaling factors.
        rows: rows of the original unquantized tensor (per partition)
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector

    Returns:
        1D reswizzled scaling factors
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
    padded_cols = padded_sf_cols * scaling_vector_size
    # Each partition occupies padded_rows * padded_sf_cols elements; the
    # input must be an exact multiple of that size.
    assert sf.numel() % (padded_rows * padded_sf_cols) == 0
    num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
    sf_reshaped = sf.view(num_partitions, padded_rows, padded_sf_cols)
    # Unswizzle each partition
    sf_unswizzled = unswizzle_sf(sf_reshaped, padded_rows, padded_cols,
                                 scaling_vector_size)
    # Brings the unswizzled scaling factors in each partition together
    total_rows = num_partitions * rows
    sf_unswizzled = sf_unswizzled.view(num_partitions, padded_rows,
                                       padded_sf_cols)
    # Drop the per-partition row/col padding before concatenating.
    sf_concatenated = sf_unswizzled[:, :rows, :sf_cols].contiguous(
    )  # TODO: This will incur a elementwise kernel
    sf_concatenated = sf_concatenated.view(total_rows, sf_cols)
    # Finally swizzle the concatenated scaling factors
    return swizzle_sf(sf_concatenated, total_rows, cols, scaling_vector_size)
@torch.library.register_fake("trtllm::reswizzle_sf")
def _(sf, rows, cols, scaling_vector_size=16):
    """Fake (meta) implementation of trtllm::reswizzle_sf.

    Computes only the 1D output size so fake-tensor tracing (torch.compile)
    knows the result shape without running the real kernel.
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
    num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
    total_rows = num_partitions * rows
    # Fix: the real op returns swizzle_sf(..., total_rows, cols, ...), whose
    # documented size is pad_up(total_rows, 128) * pad_up(sf_cols, 4).
    # The previous pad_up(cols, 4) used the unquantized column count instead
    # of the scaling-factor column count, overestimating the size by roughly
    # a factor of scaling_vector_size.
    sz = pad_up(total_rows, 128) * pad_up(sf_cols, 4)
    return sf.new_empty(sz)
def next_positive_power_of_2(x: int) -> int:
    """Return the smallest power of two that is >= x (1 for x < 1).

    Implemented as a shift/or ladder rather than `(x - 1).bit_length()` so
    the function remains traceable by torch.compile. The ladder covers
    64-bit values, which is sufficient here.
    """
    if x < 1:
        return 1
    # Smear the highest set bit of (x - 1) into all lower positions, then
    # add one to land on the next power of two.
    v = x - 1
    for shift in (1, 2, 4, 8, 16, 32):
        v |= v >> shift
    return v + 1
def last_positive_power_of_2(x: int) -> int:
    """Return the largest power of two that is <= x (assumes x >= 1)."""
    # `ceiling` avoids shadowing the builtin `next` used by the original.
    ceiling = next_positive_power_of_2(x)
    return ceiling if ceiling == x else ceiling // 2
def nearest_in_buckets(x: int, buckets: List[int]) -> int:
    """Round x up to a power of two, clamped to [buckets[0], buckets[-1]].

    NOTE(review): only the first and last entries of `buckets` are consulted;
    intermediate buckets are ignored.
    """
    rounded = next_positive_power_of_2(x)
    lo, hi = buckets[0], buckets[-1]
    return min(max(rounded, lo), hi)
def get_power_of_2_num_tokens_buckets(max_num_tokens) -> tuple[int, ...]:
    """Return all powers of two from 1 up to the power-of-two ceiling of
    max_num_tokens, in ascending order.

    Args:
        max_num_tokens: upper bound on the token count (rounded up to the
            next power of two before bucketing).

    Returns:
        Ascending tuple of power-of-two bucket sizes, e.g. (1, 2, 4, 8) for
        max_num_tokens in (5..8].
    """
    # Fix: annotation previously claimed List[int], but the function has
    # always returned a tuple; the return value itself is unchanged.
    max_num_tokens = next_positive_power_of_2(max_num_tokens)
    num_token_buckets = []
    m = max_num_tokens
    while m >= 1:
        num_token_buckets.append(m)
        m //= 2
    return tuple(num_token_buckets[::-1])
def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> tuple[int, ...]:
    """Return all powers of two from 1 up to the power-of-two floor of
    max_num_tokens, in ascending order.

    Args:
        max_num_tokens: upper bound on the token count (rounded down to the
            last power of two before bucketing).

    Returns:
        Ascending tuple of power-of-two bucket sizes, e.g. (1, 2, 4) for
        max_num_tokens in [4..7].
    """
    # Fix: annotation previously claimed List[int], but the function has
    # always returned a tuple; the return value itself is unchanged.
    max_num_tokens = last_positive_power_of_2(max_num_tokens)
    num_token_buckets = []
    m = max_num_tokens
    while m >= 1:
        num_token_buckets.append(m)
        m //= 2
    return tuple(num_token_buckets[::-1])
def fp4_scale_infer_shape(input_shapes: List[List[int]]):
    """Calculate the dimensions of the fp4 scale tensor.

    Args:
        input_shapes: shapes of the op's inputs; only input_shapes[0] is used.
    """
    # out_shape is intentionally unused; only the scale shape is needed here.
    out_shape, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0],
                                                     sf_vec_size=16)
    # NOTE(review): the `* 2` doubles get_fp4_shape's scale size — presumably
    # to account for the fp4 packing ratio; confirm against
    # fp4_utils.get_fp4_shape before relying on this.
    return scale_shape * 2
# Process-wide switch that enables/disables piecewise CUDA-graph capture.
_enable_piecewise_cuda_graph = True


def set_piecewise_cuda_graph_flag(enable: bool):
    """Set the global piecewise CUDA-graph enable flag."""
    global _enable_piecewise_cuda_graph
    _enable_piecewise_cuda_graph = enable


def get_piecewise_cuda_graph_flag() -> bool:
    """Return the current piecewise CUDA-graph enable flag (default True)."""
    return _enable_piecewise_cuda_graph
@contextlib.contextmanager
def piecewise_cuda_graph(enable: bool):
    """Context manager: temporarily set the piecewise CUDA-graph flag.

    The previous flag value is restored on exit, even if the body raises.
    """
    saved = get_piecewise_cuda_graph_flag()
    set_piecewise_cuda_graph_flag(enable)
    try:
        yield
    finally:
        set_piecewise_cuda_graph_flag(saved)
def set_per_request_piecewise_cuda_graph_flag(enable: bool):
    """Store the per-request piecewise CUDA-graph flag on the thread-local
    global attribute store."""
    setattr(_global_attrs, 'per_request_piecewise_cuda_graph_flag', enable)
def get_per_request_piecewise_cuda_graph_flag() -> bool:
    """Return the per-request piecewise CUDA-graph flag (True if never set)."""
    try:
        return _global_attrs.per_request_piecewise_cuda_graph_flag
    except AttributeError:
        return True
def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
    """Build a Mapping whose TP size is shrunk for the lm_head GEMM.

    Args:
        mapping: the model's current parallel mapping.
        token_count: number of tokens in the current batch; must be >= 1
            (token_count == 0 would raise ZeroDivisionError below).

    Returns:
        A new Mapping with lm_head-specific tp/pp sizes; tp_size * pp_size
        is preserved from the input mapping, so the world size is unchanged.
    """
    # We use heuristic to determine the lm_head_tp_size
    # Since token_count=256 will hit the boundary of math-bound problem
    # We use 256 // token_count to determine the lm_head_tp_size
    lm_head_tp_size_raw = 256 // token_count
    # Clamp to [1, gpus_per_node] and round to a power of two; only the two
    # endpoints of the bucket list are consulted by nearest_in_buckets.
    lm_head_tp_size = nearest_in_buckets(lm_head_tp_size_raw,
                                         [1, mapping.gpus_per_node])
    assert mapping.tp_size % lm_head_tp_size == 0
    # Fold the freed TP ranks into PP so the total rank count stays constant.
    lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size

    return Mapping(
        world_size=lm_head_tp_size * lm_head_pp_size,
        rank=mapping.rank,
        gpus_per_node=mapping.gpus_per_node,
        tp_size=lm_head_tp_size,
        pp_size=lm_head_pp_size,
        enable_attention_dp=mapping.enable_attention_dp,
        enable_lm_head_tp_in_adp=mapping.enable_lm_head_tp_in_adp,
    )
def get_device_uuid(device_idx: int) -> str:
    """Get the UUID of a CUDA device using torch cuda api"""
    # `props` avoids shadowing the builtin `property`.
    props = torch.cuda.get_device_properties(device_idx)
    return "GPU-" + str(props.uuid)
def maybe_compile(func=None, **compile_kwargs):
    """
    Conditionally compile a function with torch.compile.

    If is_piecewise_running() is True, the function will not be compiled to
    avoid host overhead in attention op.

    Args:
        func: The function to decorate (optional, for direct decoration).
        **compile_kwargs: Keyword arguments for torch.compile.

    Returns:
        The conditionally compiled function.
    """

    def decorator(f):
        # torch.compile is lazy: the real compilation happens on first call.
        compiled_func = torch.compile(f, **compile_kwargs)

        # Fix: preserve the wrapped function's name/doc/signature metadata,
        # which the original wrapper silently discarded.
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Under piecewise running, call the eager function to avoid
            # compiled-dispatch host overhead.
            if is_piecewise_running():
                return f(*args, **kwargs)
            return compiled_func(*args, **kwargs)

        return wrapper

    # Support both @maybe_compile and @maybe_compile(dynamic=True, ...).
    return decorator(func) if func else decorator