Commit eae480b

[https://nvbugs/5820874][fix] Adjust deepgemm tuning buckets to cover larger num_tokens's scope (NVIDIA#11259)
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
1 parent 719e82c commit eae480b

File tree

5 files changed: +118 −45 lines

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
tests/integration/defs/perf/test_perf_sanity.py
tests/integration/test_lists/test-db/l0_dgx_b200_perf_sanity.yml
tests/scripts/perf-sanity/deepseek_r1_fp8_blackwell.yaml
tests/test_common/http_utils.py

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 8 additions & 13 deletions

@@ -1450,17 +1450,17 @@ def _(
 
 def deep_gemm_gen_tuning_buckets(x: int):
     buckets = tuple(range(8, 128, 8))
+    # Clamp x to be between 4096 and 8192.
     if x >= 128:
+        x = min(x, 8192)
+        x = max(x, 4096)
         buckets += tuple(range(128, x, 128))
     return buckets
 
 
 class fp8SwapABGemmRunner(TunableRunner):
-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(DynamicTensorSpec(
-            0, 0, deep_gemm_gen_tuning_buckets), ),
-        tune_max_num_tokens=4096,
-    )
+    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
+        0, 0, deep_gemm_gen_tuning_buckets), ), )
 
     def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
         self.output_dtype = output_dtype
@@ -1477,9 +1477,7 @@ def get_valid_tactics(
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
-        # Encode swap_ab as False (0) and True (1). Currently enabled when GEMM m <= 128.
-        input, _, _ = inputs
-        return [0, 1] if input.shape[0] <= 128 else [0]
+        return [0]
 
     def forward(
         self,
@@ -1494,8 +1492,7 @@ def forward(
             dtype=self.output_dtype,
         )
 
-        forward_func = deep_gemm.fp8_gemm_ntt if tactic == 1 else deep_gemm.fp8_gemm_nt
-        forward_func(
+        deep_gemm.fp8_gemm_nt(
             (a, a_sf),
             (weight, weight_scale),
             output,
@@ -1511,14 +1508,13 @@ def fp8_swap_ab_gemm(
     weight_scale: torch.Tensor,
     output_dtype: torch.dtype = torch.bfloat16,
     disable_ue8m0_cast: bool = False,
-    tune_max_num_tokens: int = 4096,
 ) -> torch.Tensor:
     tuner = AutoTuner.get()
     fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
         output_dtype,
         disable_ue8m0_cast,
    )
-    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
 
    _, best_tactic = tuner.choose_one(
        "trtllm::fp8_swap_ab_gemm",
        [fp8_swap_ab_gemm_runner],
@@ -1538,7 +1534,6 @@ def _(
     weight_scale: torch.Tensor,
     output_dtype: torch.dtype = torch.bfloat16,
     disable_ue8m0_cast: bool = False,
-    tune_max_num_tokens: int = 4096,
 ) -> torch.Tensor:
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)

tests/integration/defs/perf/test_perf_sanity.py

Lines changed: 80 additions & 31 deletions

@@ -549,20 +549,26 @@ def run_cmd(self, server_idx: int) -> List[str]:
         server_cmd_with_port = add_host_port_to_cmd(server_cmd, server_hostname, server_port)
 
         server_file_path = os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.log")
+        server_error_file_path = os.path.join(
+            self.output_dir, f"trtllm-serve.{server_idx}.error.log"
+        )
 
         print_info(f"Starting server. cmd is {server_cmd_with_port}")
-        with open(server_file_path, "w") as server_ctx:
+        with (
+            open(server_file_path, "w") as server_ctx,
+            open(server_error_file_path, "w") as server_err_ctx,
+        ):
             server_proc = subprocess.Popen(
                 server_cmd_with_port,
                 stdout=server_ctx,
-                stderr=subprocess.STDOUT,
+                stderr=server_err_ctx,
                 env=copy.deepcopy(os.environ),
             )
 
         wait_for_endpoint_ready(
             f"http://{server_hostname}:{server_port}/health",
             timeout=self.timeout,
-            check_files=[server_file_path],
+            check_files=[server_file_path, server_error_file_path],
             server_proc=server_proc,
         )
@@ -571,20 +577,27 @@ def run_cmd(self, server_idx: int) -> List[str]:
             client_file_path = os.path.join(
                 self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
             )
+            client_error_file_path = os.path.join(
+                self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.error.log"
+            )
 
             client_cmd_with_port = add_host_port_to_cmd(
                 client_cmd, server_hostname, server_port
             )
             print_info(f"Starting client. cmd is {client_cmd_with_port}")
 
-            output = subprocess.check_output(
+            result = subprocess.run(
                 client_cmd_with_port,
-                stderr=subprocess.STDOUT,
+                capture_output=True,
                 env=copy.deepcopy(os.environ),
-            ).decode()
+                check=True,
+            )
+            output = result.stdout.decode()
 
             with open(client_file_path, "w") as client_ctx:
                 client_ctx.write(output)
+            with open(client_error_file_path, "w") as client_err_ctx:
+                client_err_ctx.write(result.stderr.decode())
 
             outputs.append(output)
 
@@ -723,7 +736,10 @@ def run_cmd(self, server_idx: int) -> List[str]:
         if "CTX" in self.disagg_serving_type or "GEN" in self.disagg_serving_type:
             self._generate_hostname_file(server_idx, port)
             server_file_path = os.path.join(
-                self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.log"
+            )
+            server_error_file_path = os.path.join(
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.error.log"
             )
             is_ctx = "CTX" in self.disagg_serving_type
             server_cmd = ctx_cmd if is_ctx else gen_cmd
@@ -732,11 +748,14 @@ def run_cmd(self, server_idx: int) -> List[str]:
             print_info(
                 f"Starting server. disagg_serving_type: {self.disagg_serving_type} cmd is {server_cmd}"
             )
-            with open(server_file_path, "w") as server_ctx:
+            with (
+                open(server_file_path, "w") as server_ctx,
+                open(server_error_file_path, "w") as server_err_ctx,
+            ):
                 server_proc = subprocess.Popen(
                     server_cmd,
                     stdout=server_ctx,
-                    stderr=subprocess.STDOUT,
+                    stderr=server_err_ctx,
                     env=copy.deepcopy(os.environ),
                 )
             self.wait_for_benchmark_ready(benchmark_status_file)
@@ -747,16 +766,22 @@ def run_cmd(self, server_idx: int) -> List[str]:
 
         elif self.disagg_serving_type == "DISAGG_SERVER":
             disagg_server_file_path = os.path.join(
-                self.output_dir, f"trtllm-serve.{server_idx}.{self.disagg_serving_type}.log"
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.log"
+            )
+            disagg_server_error_file_path = os.path.join(
+                self.output_dir, f"trtllm-serve.{self.disagg_serving_type}.{server_idx}.error.log"
             )
             try:
                 self._generate_disagg_server_config(server_idx, port)
                 print_info(f"Starting disagg server. cmd is {disagg_cmd}")
-                with open(disagg_server_file_path, "w") as disagg_server_ctx:
+                with (
+                    open(disagg_server_file_path, "w") as disagg_server_ctx,
+                    open(disagg_server_error_file_path, "w") as disagg_server_err_ctx,
+                ):
                     disagg_server_proc = subprocess.Popen(
                         disagg_cmd,
                         stdout=disagg_server_ctx,
-                        stderr=subprocess.STDOUT,
+                        stderr=disagg_server_err_ctx,
                         env=copy.deepcopy(os.environ),
                     )
                 self.wait_for_benchmark_ready(benchmark_status_file)
@@ -770,21 +795,28 @@ def run_cmd(self, server_idx: int) -> List[str]:
                 disagg_server_hostname, disagg_server_port = (
                     self._get_disagg_server_hostname_and_port(server_idx)
                 )
-                server_files = [
-                    os.path.join(self.output_dir, f"trtllm-serve.{server_idx}.DISAGG_SERVER.log"),
-                ]
-                for ctx_idx in range(self.num_ctx_servers):
-                    server_files.append(
-                        os.path.join(
-                            self.output_dir, f"trtllm-serve.{server_idx}.CTX_{ctx_idx}.log"
-                        )
-                    )
-                for gen_idx in range(self.num_gen_servers):
-                    server_files.append(
-                        os.path.join(
-                            self.output_dir, f"trtllm-serve.{server_idx}.GEN_{gen_idx}.log"
-                        )
-                    )
+                server_files = (
+                    [
+                        os.path.join(
+                            self.output_dir, f"trtllm-serve.DISAGG_SERVER.{server_idx}.log"
+                        ),
+                        os.path.join(
+                            self.output_dir, f"trtllm-serve.DISAGG_SERVER.{server_idx}.error.log"
+                        ),
+                    ]
+                    + [
+                        os.path.join(
+                            self.output_dir, f"trtllm-serve.CTX_{ctx_idx}.{server_idx}.log"
+                        )
+                        for ctx_idx in range(self.num_ctx_servers)
+                    ]
+                    + [
+                        os.path.join(
+                            self.output_dir, f"trtllm-serve.GEN_{gen_idx}.{server_idx}.log"
+                        )
+                        for gen_idx in range(self.num_gen_servers)
+                    ]
+                )
                 wait_for_endpoint_ready(
                     f"http://{disagg_server_hostname}:{disagg_server_port}/health",
                     timeout=self.timeout,
@@ -796,20 +828,27 @@ def run_cmd(self, server_idx: int) -> List[str]:
                 benchmark_file_path = os.path.join(
                     self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.log"
                 )
+                benchmark_error_file_path = os.path.join(
+                    self.output_dir, f"trtllm-benchmark.{server_idx}.{client_idx}.error.log"
+                )
 
                 client_cmd_with_port = add_host_port_to_cmd(
                     client_cmd, disagg_server_hostname, disagg_server_port
                 )
                 print_info(f"Starting benchmark. cmd is {client_cmd_with_port}")
 
-                output = subprocess.check_output(
+                result = subprocess.run(
                     client_cmd_with_port,
+                    capture_output=True,
                     env=copy.deepcopy(os.environ),
-                    stderr=subprocess.STDOUT,
-                ).decode()
+                    check=True,
+                )
+                output = result.stdout.decode()
 
                 with open(benchmark_file_path, "w") as benchmark_ctx:
                     benchmark_ctx.write(output)
+                with open(benchmark_error_file_path, "w") as benchmark_err_ctx:
+                    benchmark_err_ctx.write(result.stderr.decode())
                 outputs.append(output)
 
             finally:
@@ -1197,11 +1236,21 @@ def run_ex(self, commands) -> Dict[int, List[str]]:
 
             except Exception as e:
                 print_error(f"Test command failed for server {server_idx}. Error: {e}")
-                if isinstance(e, subprocess.CalledProcessError):
-                    print_error("--- stdout ---")
-                    if e.stdout:
-                        print_error(e.stdout.decode() if isinstance(e.stdout, bytes) else e.stdout)
-                    print_error("--------------")
+                # Print content of trtllm-serve error log files
+                error_log_pattern = os.path.join(
+                    commands.output_dir, f"trtllm-serve*{server_idx}.error.log"
+                )
+                error_log_files = glob.glob(error_log_pattern)
+                for error_log_file in error_log_files:
+                    if os.path.exists(error_log_file):
+                        print_error(f"--- {error_log_file} ---")
+                        with open(error_log_file, "r") as f:
+                            content = f.read()
+                        if content.strip():
+                            print_error(content)
+                        else:
+                            print_error("(empty)")
+                        print_error("-" * len(f"--- {error_log_file} ---"))
                 outputs[server_idx] = []
 
         return outputs
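
The recurring change in this file is worth calling out: subprocess.check_output(..., stderr=subprocess.STDOUT) merged stderr into the main log, whereas subprocess.run(..., capture_output=True, check=True) keeps the streams separate so errors land in a dedicated *.error.log. A minimal self-contained sketch of the new pattern (the command and file names here are placeholders, not the test harness's actual paths):

    import subprocess

    result = subprocess.run(
        ["python", "-c", "import sys; print('out'); print('err', file=sys.stderr)"],
        capture_output=True,  # stdout and stderr are captured separately
        check=True,           # raises CalledProcessError on a non-zero exit code
    )
    with open("client.log", "w") as out_f, open("client.error.log", "w") as err_f:
        out_f.write(result.stdout.decode())
        err_f.write(result.stderr.decode())

Keeping stderr separate is what lets the new run_ex exception handler glob for *.error.log files and print only the error stream instead of the full interleaved output.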

tests/integration/test_lists/test-db/l0_dgx_b200_perf_sanity.yml

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ l0_dgx_b200_perf_sanity:
   # - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_dep8_mtp1_1k1k] TIMEOUT (90) # failed
   - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_tp8_mtp3_8k1k]
   # - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_dep8_mtp1_8k1k] TIMEOUT (90) # failed
+  - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp8_blackwell-r1_fp8_tp8_6k1k] TIMEOUT (90)
   # deepseek_r1_fp4_v2_blackwell
   - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_blackwell-r1_fp4_v2_tp4_mtp3_1k1k]
   - perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_blackwell-r1_fp4_v2_tp4_mtp3_8k1k]

tests/scripts/perf-sanity/deepseek_r1_fp8_blackwell.yaml

Lines changed: 28 additions & 0 deletions

@@ -134,3 +134,31 @@ server_configs:
       osl: 1024
       backend: "openai"
       dataset_file: datasets/perf-ci/deepseek_r1-8k1k-20480-ratio-1_for_serve.json
+
+  # 6k1k configs - TP8 with TRTLLM, MTP1
+  - name: "r1_fp8_tp8_6k1k"
+    model_name: "deepseek_r1_0528_fp8"
+    tensor_parallel_size: 8
+    moe_expert_parallel_size: 1
+    pipeline_parallel_size: 1
+    max_batch_size: 512
+    max_num_tokens: 8192
+    attn_backend: "TRTLLM"
+    enable_attention_dp: false
+    moe_config:
+      backend: 'TRTLLM'
+    cuda_graph_config:
+      enable_padding: true
+      max_batch_size: 64
+    kv_cache_config:
+      dtype: 'fp8'
+      enable_block_reuse: false
+      free_gpu_memory_fraction: 0.8
+    client_configs:
+      - name: "con64_iter10_6k1k"
+        concurrency: 64
+        iterations: 10
+        isl: 6144
+        osl: 1024
+        backend: "openai"
+        random_range_ratio: 0.2
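
As a worked cross-check (an inference from the commit title, not something stated in the diff), the new 6k1k workload is exactly the kind of case the old 4096-token tuning cap missed, and it fits under the raised 8192 ceiling and the config's max_num_tokens:

    # Inferred sanity check, not part of the commit: the 6k1k workload
    # (isl=6144, osl=1024) stays within the config's max_num_tokens and
    # the new 8192 ceiling of deep_gemm_gen_tuning_buckets.
    isl, osl = 6144, 1024
    max_num_tokens = 8192
    assert isl + osl <= max_num_tokens  # 7168 <= 8192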

tests/test_common/http_utils.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 
 import requests
 
-ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError"]
+ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError", "FileNotFoundError"]
 
 
 def wait_for_endpoint_ready(
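
For context, a minimal sketch of how a keyword list like this is typically used to scan a log file for fatal errors; wait_for_endpoint_ready's real implementation is not shown in this diff, so the helper below is illustrative only:

    ERROR_KEYWORDS = ["RuntimeError", "out of memory", "ValueError", "FileNotFoundError"]

    def log_has_error(path: str) -> bool:
        # Return True if any known fatal-error keyword appears in the log file.
        with open(path, "r", errors="replace") as f:
            text = f.read()
        return any(keyword in text for keyword in ERROR_KEYWORDS)

Adding FileNotFoundError presumably lets the readiness check fail fast when a server aborts on a missing file instead of waiting out the full timeout.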
