
Commit 7e4cef9

[None][fix] Cherry-pick conflict changes for PR 7999 PR 8515 (#9446)
Signed-off-by: Jin Li <[email protected]>
1 parent d8b5aeb · commit 7e4cef9

5 files changed: +211 −46 lines changed


cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp

Lines changed: 2 additions & 2 deletions
@@ -411,7 +411,7 @@ TRTLLM_NAMESPACE_END

 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
-    m.def("fp8_block_scaling_gemm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
+    m.def("fp8_block_scaling_gemm_impl(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale) -> Tensor");
     m.def(
         "fp8_block_scaling_bmm(Tensor mat1, Tensor mat2, Tensor mat1Scale, Tensor mat2Scale, ScalarType? "
         "out_dtype=None) -> Tensor");
@@ -425,7 +425,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)

 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 {
-    m.impl("fp8_block_scaling_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
+    m.impl("fp8_block_scaling_gemm_impl", &tensorrt_llm::torch_ext::fp8_block_scaling_gemm);
     m.impl("fp8_block_scaling_bmm", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm);
     m.impl("fp8_block_scaling_bmm_out", &tensorrt_llm::torch_ext::fp8_block_scaling_bmm_out);
     m.impl("fp8_block_scaling_moe_gemm", &tensorrt_llm::torch_ext::fp8_block_scaling_moe_gemm);

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def _(logits, seq_lens, indices, next_n, index_topk):
    def _(input, force_applying_finalize):
        return torch.empty_like(input)

-   @torch.library.register_fake("trtllm::fp8_block_scaling_gemm")
+   @torch.library.register_fake("trtllm::fp8_block_scaling_gemm_impl")
    def _(a, b, a_scale, b_scale):
        m = a.shape[0]
        n = b.shape[0]
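This one-line change keeps the fake (meta) kernel registration in sync with the renamed C++ schema; if the string passed to torch.library.register_fake did not track the rename, the fake kernel would no longer be picked up for the op that the C++ side actually registers. The pattern, shown here as a standalone toy in a hypothetical demo namespace rather than the trtllm one:

import torch

@torch.library.custom_op("demo::scaled_matmul", mutates_args=())
def scaled_matmul(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    # real kernel: (m, k) x (n, k)^T -> (m, n)
    return (a @ b.t()) * scale

# The name here must match the registered schema exactly, which is the same
# coupling that forces this file to follow the rename to fp8_block_scaling_gemm_impl.
@torch.library.register_fake("demo::scaled_matmul")
def _(a, b, scale):
    # mirror the real kernel's output shape and dtype without touching any data
    return a.new_empty((a.shape[0], b.shape[0]))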

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 74 additions & 2 deletions
@@ -1441,7 +1441,7 @@ def _(
     return input.new_empty((M, N), dtype=output_dtype)


-def fp8_swap_ab_gen_tuning_buckets(x: int):
+def deep_gemm_gen_tuning_buckets(x: int):
     buckets = tuple(range(8, 128, 8))
     if x >= 128:
         buckets += tuple(range(128, x, 128))
@@ -1451,7 +1451,7 @@ def fp8_swap_ab_gen_tuning_buckets(x: int):
 class fp8SwapABGemmRunner(TunableRunner):
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
+            0, 0, deep_gemm_gen_tuning_buckets), ),
         tune_max_num_tokens=4096,
     )

@@ -1536,6 +1536,78 @@ def _(
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)


+# The runner is used to trigger deepgemm jit during autotune.
+class Fp8BlockScalingGemmRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, deep_gemm_gen_tuning_buckets), ),
+        tune_max_num_tokens=4096,
+    )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+        return [0]
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        a, b, a_scale, b_scale = inputs
+        return torch.ops.trtllm.fp8_block_scaling_gemm_impl(
+            a, b, a_scale, b_scale)
+
+
+def get_fp8_block_scaling_gemm_constraint_spec():
+    # The implementation aligns with the fp8_quantize_1x128 custom op.
+    def fp8_quantize_1x128_sm90_constrant(inputs: List[List[int]]):
+        pad_m = fp4_utils.pad_up(inputs[0][0], 4)
+        blocked_n = (inputs[0][1] + 127) // 128
+        return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4
+
+    if get_sm_version() >= 100:
+        return (ConstraintSpec(2, 1, lambda inputs: inputs[0][0]), )
+    else:
+        return (ConstraintSpec(2, 0, fp8_quantize_1x128_sm90_constrant), )
+
+
+@torch.library.custom_op("trtllm::fp8_block_scaling_gemm", mutates_args=())
+def fp8_block_scaling_gemm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    tune_max_num_tokens: int = 4096,
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+    fp8_block_scaling_gemm_runner = Fp8BlockScalingGemmRunner()
+    Fp8BlockScalingGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+
+    Fp8BlockScalingGemmRunner.tuning_config.constraint_specs = get_fp8_block_scaling_gemm_constraint_spec(
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_block_scaling_gemm",
+        [fp8_block_scaling_gemm_runner],
+        Fp8BlockScalingGemmRunner.tuning_config,
+        [a, b, a_scale, b_scale],
+    )
+    return fp8_block_scaling_gemm_runner(
+        inputs=[a, b, a_scale, b_scale],
+        tactic=best_tactic,
+    )
+
+
+@fp8_block_scaling_gemm.register_fake
+def _(a, b, a_scale, b_scale, tune_max_num_tokens=4096):
+    m = a.shape[0]
+    n = b.shape[0]
+    return a.new_empty((m, n), dtype=torch.bfloat16)
+
+
 @torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
 def silu_and_mul(x: torch.Tensor,
                  scale: Optional[torch.Tensor] = None,
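Two things happen in this file. The M-dimension bucket generator is renamed from fp8_swap_ab_gen_tuning_buckets to deep_gemm_gen_tuning_buckets, since it is now shared by fp8SwapABGemmRunner and the new Fp8BlockScalingGemmRunner. And a Python-level trtllm::fp8_block_scaling_gemm custom op is added whose single tactic simply forwards to the renamed _impl kernel, so that autotune warm-up walks the bucketed token counts and triggers DeepGEMM JIT compilation ahead of time. A self-contained sketch of the bucket generator (the trailing return buckets is an assumption, since the hunk cuts off before it):

def gen_tuning_buckets(x: int):
    # token counts are bucketed every 8 up to 128, then every 128 up to x
    buckets = tuple(range(8, 128, 8))
    if x >= 128:
        buckets += tuple(range(128, x, 128))
    return buckets

print(gen_tuning_buckets(4096)[:5])   # (8, 16, 24, 32, 40)
print(gen_tuning_buckets(4096)[-3:])  # (3712, 3840, 3968)

With tune_max_num_tokens=4096 this yields 46 candidate token counts for the tuner to sweep during warm-up.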

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 37 additions & 6 deletions
@@ -155,17 +155,26 @@ def _create_dummy_mm_context_request(
         dummy_mm_prompt = input_processor.get_dummy_prompt(input_seq_len)

         if dummy_mm_prompt is not None:
-            prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor(
+            prompt_token_ids, extra_processed_inputs = self._model_engine.input_processor_with_hash(
                 dummy_mm_prompt, sampling_params=None)
+
+            multimodal_input = extra_processed_inputs.get(
+                'multimodal_input')
             multimodal_data = extra_processed_inputs.get('multimodal_data')
+            req_mm_input = trtllm.MultimodalInput(
+                multimodal_hashes=multimodal_input.multimodal_hashes,
+                multimodal_positions=multimodal_input.multimodal_positions,
+                multimodal_lengths=multimodal_input.multimodal_lengths
+            ) if multimodal_input else None

             request = trtllm.Request(prompt_token_ids,
                                      max_tokens=1,
                                      streaming=False,
                                      sampling_config=trtllm.SamplingConfig(
                                          beam_width=max_beam_width, ),
                                      output_config=trtllm.OutputConfig(),
-                                     end_id=-1)
+                                     end_id=-1,
+                                     multimodal_input=req_mm_input)
             request.py_multimodal_data = multimodal_data
         else:
             # Fall back to text-only prompt when we could not find the small image size.
@@ -266,9 +275,29 @@ def _get_token_num_for_estimation(self) -> int:
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
             num_cache_blocks += (num_req_tokens + self._tokens_per_block -
                                  1) // self._tokens_per_block
+
+        # Max cuda graph warmup required tokens
+        max_cuda_graph_bs = min(self._model_engine.batch_size,
+                                self._model_engine._max_cuda_graph_batch_size)
+        cuda_graph_warmup_block = (
+            self._model_engine.max_seq_len +
+            1) // self._tokens_per_block + max_cuda_graph_bs - 1
+        num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)
+
+        # This is the minimal blocks required to run with max bs
+        # If not able to allocate self._model_engine.batch_size blocks, the max batch size should be adjusted.
+        num_cache_blocks = max(num_cache_blocks, self._model_engine.batch_size)
+
+        free_mem, total_mem = torch.cuda.mem_get_info()
+        max_memory = self._kv_cache_config.free_gpu_memory_fraction * free_mem
+        max_num_tokens_in_memory = max_memory // self._get_kv_size_per_token(
+        ) // self._tokens_per_block * self._tokens_per_block
+
         # Multiply by beam width, to prevent rescaling of the max_seq_len caused by the influence of beam width during the preparation for kv_cache_estimation
-        return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
-            0].sampling_config.beam_width
+        return min(
+            num_cache_blocks * self._tokens_per_block *
+            self._dummy_reqs[0].sampling_config.beam_width,
+            max_num_tokens_in_memory)

     def try_prepare_estimation(self) -> bool:
         """Prepare for possible KV cache capacity estimation.
@@ -279,8 +308,10 @@ def try_prepare_estimation(self) -> bool:
         estimating_kv_cache = False
         if 'cp_type' not in self._mapping.cp_config:
             estimating_kv_cache = True
-            self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
-            )
+            estimate_max_tokens = self._get_token_num_for_estimation()
+            self._kv_cache_config.max_tokens = min(
+                estimate_max_tokens, self._kv_cache_config.max_tokens
+            ) if self._kv_cache_config.max_tokens is not None else estimate_max_tokens
             model_config = self._model_engine.model.model_config
             if model_config.attn_backend == "VANILLA":
                 logger.info(
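The estimation changes add two lower bounds and one upper bound to the token budget used for KV-cache capacity estimation: enough blocks to warm up CUDA graphs at the largest captured batch size, at least one block per request at the maximum batch size, and a cap derived from the currently free GPU memory via torch.cuda.mem_get_info(), so the dummy estimation pass cannot ask for more KV cache than could ever be allocated. try_prepare_estimation additionally respects a user-provided kv_cache_config.max_tokens by taking the minimum. A numeric sketch with illustrative values only (none of these numbers come from the commit):

tokens_per_block = 32
max_seq_len = 4096
batch_size = 64                  # model engine's max batch size
max_cuda_graph_batch_size = 32
beam_width = 1
num_cache_blocks = 120           # blocks already counted for the dummy requests

# Floor 1: blocks needed to warm up CUDA graphs at the largest graph batch size.
max_cuda_graph_bs = min(batch_size, max_cuda_graph_batch_size)
cuda_graph_warmup_block = (max_seq_len + 1) // tokens_per_block + max_cuda_graph_bs - 1
num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)   # 128 + 31 = 159

# Floor 2: at least one block per request at the maximum batch size.
num_cache_blocks = max(num_cache_blocks, batch_size)                # still 159

# Cap: tokens that fit in the free-memory budget, rounded down to whole blocks.
free_mem = 40 * 2**30                    # pretend 40 GiB are currently free
free_gpu_memory_fraction = 0.9
kv_size_per_token = 70 * 2**10           # pretend 70 KiB of KV cache per token
max_memory = free_gpu_memory_fraction * free_mem
max_num_tokens_in_memory = (max_memory // kv_size_per_token
                            // tokens_per_block * tokens_per_block)

estimate = min(num_cache_blocks * tokens_per_block * beam_width,
               max_num_tokens_in_memory)
print(estimate)                          # 5088 tokens for this made-up setup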
