Skip to content

Commit e84d1fa

Browse files
Add kernel-run-level tensor blob save controls
Summary: Add env vars `TRITONPARSE_TENSOR_SAVE_SKIP_RUNS` and `TRITONPARSE_TENSOR_SAVE_MAX_RUNS` to skip/limit tensor blob saving at the kernel-run granularity.

A "kernel run" counts all autotune benchmark launches plus the winner launch for a single kernel invocation as one run (autotune benchmarks are detected via stack frame inspection and excluded from the run counter). While disk writes are already deduped via content-addressed BLAKE2b hashing, the serialization and hashing overhead still occurs per launch. These controls let users skip that overhead for runs they don't need blobs from.

- `TRITONPARSE_TENSOR_SAVE_SKIP_RUNS=N`: skip blob saving for the first N kernel runs
- `TRITONPARSE_TENSOR_SAVE_MAX_RUNS=N`: save blobs for at most N runs after skipping (0 = unlimited)

Also exposes these as `tensor_save_skip_runs` / `tensor_save_max_runs` params on `init()` and `TritonParseManager`.

Reviewed By: FindHao

Differential Revision: D96661726

fbshipit-source-id: c423799da37d21ead475740d3b33326444bd4fac
1 parent 04ea4d3 commit e84d1fa

File tree

3 files changed

+134
-2
lines changed

3 files changed

+134
-2
lines changed

tests/gpu/test_tensor_blob.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,74 @@ def count_all_blobs(manager_dir_path):
265265
)
266266
print("✓ Storage correctly disabled when enable_tensor_blob_storage=False")
267267

268+
# === Test 5: Skip/Max Runs ===
269+
print("\n=== Test 5: Skip/Max Runs ===")
270+
271+
# Test 5a: tensor_save_max_runs=1 — only first kernel run saves blobs
272+
print("\n--- Test 5a: max_runs=1 ---")
273+
temp_output_dir_5a = tempfile.mkdtemp()
274+
275+
with tritonparse.context_manager.TritonParseManager(
276+
enable_trace_launch=True,
277+
enable_tensor_blob_storage=True,
278+
tensor_save_max_runs=1,
279+
out=temp_output_dir_5a,
280+
) as manager:
281+
# Run 3 kernels with different inputs so blobs are not deduped
282+
for i in range(3):
283+
x = torch.randn(
284+
(512,), device=self.cuda_device, dtype=torch.float32
285+
) * (i + 1)
286+
y = run_kernel(x)
287+
y.sum()
288+
torch.cuda.synchronize()
289+
290+
blobs_5a = count_all_blobs(manager.dir_path)
291+
print(f" Blobs with max_runs=1 over 3 launches: {blobs_5a}")
292+
# With max_runs=1, only the first kernel run saves blobs (input + output = 2)
293+
# Runs 2 and 3 should not save any blobs
294+
self.assertGreater(
295+
blobs_5a, 0, "Should have at least 1 blob from first run"
296+
)
297+
self.assertLessEqual(
298+
blobs_5a, 2, "Should have at most 2 blobs (input+output of first run)"
299+
)
300+
print(f"✓ max_runs=1: {blobs_5a} blob(s) saved (only first kernel run)")
301+
302+
# Test 5b: skip_runs=1, max_runs=2 — only second and third kernel runs save blobs
303+
print("\n--- Test 5b: skip_runs=1, max_runs=2 ---")
304+
temp_output_dir_5b = tempfile.mkdtemp()
305+
306+
with tritonparse.context_manager.TritonParseManager(
307+
enable_trace_launch=True,
308+
enable_tensor_blob_storage=True,
309+
tensor_save_skip_runs=1,
310+
tensor_save_max_runs=2,
311+
out=temp_output_dir_5b,
312+
) as manager:
313+
for i in range(4):
314+
x = torch.randn(
315+
(512,), device=self.cuda_device, dtype=torch.float32
316+
) * (i + 1)
317+
y = run_kernel(x)
318+
y.sum()
319+
torch.cuda.synchronize()
320+
321+
blobs_5b = count_all_blobs(manager.dir_path)
322+
print(f" Blobs with skip=1, max=2 over 4 launches: {blobs_5b}")
323+
# skip=1 skips first run, max=2 saves runs 2 and 3, skips run 4
324+
# Each saved run has input + output = 2 blobs, so expect 3-4 blobs
325+
# (dedup may reduce if outputs happen to match)
326+
self.assertGreater(
327+
blobs_5b, 0, "Should have at least 1 blob from saved runs"
328+
)
329+
self.assertLessEqual(
330+
blobs_5b,
331+
4,
332+
"Should have at most 4 blobs (input+output of runs 2 and 3)",
333+
)
334+
print(f"✓ skip=1, max=2: {blobs_5b} blob(s) saved (runs 2 and 3 only)")
335+
268336
# Clean up all test outputs
269337
try:
270338
if TEST_KEEP_OUTPUT:
@@ -273,14 +341,18 @@ def count_all_blobs(manager_dir_path):
273341
f" Test 1: {temp_output_dir_1}\n"
274342
f" Test 2: {temp_output_dir_2}\n"
275343
f" Test 3: {temp_output_dir_3}\n"
276-
f" Test 4: {temp_output_dir_4}"
344+
f" Test 4: {temp_output_dir_4}\n"
345+
f" Test 5a: {temp_output_dir_5a}\n"
346+
f" Test 5b: {temp_output_dir_5b}"
277347
)
278348
else:
279349
for temp_dir in [
280350
temp_output_dir_1,
281351
temp_output_dir_2,
282352
temp_output_dir_3,
283353
temp_output_dir_4,
354+
temp_output_dir_5a,
355+
temp_output_dir_5b,
284356
]:
285357
if os.path.exists(temp_dir):
286358
shutil.rmtree(temp_dir)

tritonparse/context_manager.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def __init__(
2020
split_inductor_compilations=True,
2121
enable_tensor_blob_storage=False,
2222
tensor_storage_quota=None,
23+
tensor_save_skip_runs=None,
24+
tensor_save_max_runs=None,
2325
log_dir=None,
2426
keep_logs=False,
2527
**parse_kwargs,
@@ -32,6 +34,8 @@ def __init__(
3234
split_inductor_compilations: Whether to split inductor compilations in the output
3335
enable_tensor_blob_storage: Whether to enable tensor blob storage
3436
tensor_storage_quota: Storage quota in bytes for tensor blobs (default: 100GB)
37+
tensor_save_skip_runs: Skip tensor blob saving for the first N kernel runs
38+
tensor_save_max_runs: Save tensor blobs for at most N kernel runs after skipping
3539
log_dir: Optional directory path to store raw trace logs. If not provided,
3640
a temporary directory will be created and cleaned up after parsing.
3741
If provided, the directory will be created if it doesn't exist and
@@ -45,6 +49,8 @@ def __init__(
4549
self.split_inductor_compilations = split_inductor_compilations
4650
self.enable_tensor_blob_storage = enable_tensor_blob_storage
4751
self.tensor_storage_quota = tensor_storage_quota
52+
self.tensor_save_skip_runs = tensor_save_skip_runs
53+
self.tensor_save_max_runs = tensor_save_max_runs
4854
self.user_log_dir = log_dir
4955
self.keep_logs = keep_logs
5056
self.parse_kwargs = parse_kwargs
@@ -69,6 +75,10 @@ def __enter__(self):
6975
}
7076
if self.tensor_storage_quota is not None:
7177
init_kwargs["tensor_storage_quota"] = self.tensor_storage_quota
78+
if self.tensor_save_skip_runs is not None:
79+
init_kwargs["tensor_save_skip_runs"] = self.tensor_save_skip_runs
80+
if self.tensor_save_max_runs is not None:
81+
init_kwargs["tensor_save_max_runs"] = self.tensor_save_max_runs
7282

7383
init(self.dir_path, **init_kwargs)
7484
return self

tritonparse/structured_logging.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@
8989

9090
# The flag to mark if launch is traced. It is used to avoid initilizing the launch hook twice.
9191
_trace_launch_enabled = False
92+
# Kernel run counter and per-launch blob save flag for skip/max runs gating
93+
_kernel_run_count = 0
94+
_save_blobs_for_current_launch = True
9295
# Enable tensor blob storage
9396
TRITONPARSE_SAVE_TENSOR_BLOBS = os.getenv("TRITONPARSE_SAVE_TENSOR_BLOBS", "0") in [
9497
"1",
@@ -109,6 +112,14 @@
109112
TRITONPARSE_COMPRESSION_LEVEL = 4
110113
# Log statistics every N saved blobs
111114
TRITONPARSE_STATS_LOG_FREQUENCY = 100
115+
# Skip tensor blob saving for the first N kernel runs (0 = no skip)
116+
TRITONPARSE_TENSOR_SAVE_SKIP_RUNS = int(
117+
os.getenv("TRITONPARSE_TENSOR_SAVE_SKIP_RUNS", "0")
118+
)
119+
# Save tensor blobs for at most N kernel runs after skipping (0 = unlimited)
120+
TRITONPARSE_TENSOR_SAVE_MAX_RUNS = int(
121+
os.getenv("TRITONPARSE_TENSOR_SAVE_MAX_RUNS", "0")
122+
)
112123

113124
TRITON_TRACE_HANDLER = None
114125
# Global tensor blob manager instance
@@ -607,7 +618,11 @@ def _log_torch_tensor_info(tensor_value):
607618
arg_info["tensor_capture_error"] = str(e)
608619

609620
# Add tensor blob storage if enabled
610-
if TRITONPARSE_SAVE_TENSOR_BLOBS and TENSOR_BLOB_MANAGER is not None:
621+
if (
622+
TRITONPARSE_SAVE_TENSOR_BLOBS
623+
and TENSOR_BLOB_MANAGER is not None
624+
and _save_blobs_for_current_launch
625+
):
611626
blob_info = TENSOR_BLOB_MANAGER.save_tensor_blob(tensor_value)
612627
arg_info.update(blob_info)
613628
return arg_info
@@ -1428,6 +1443,8 @@ def extract_arg_info(arg_dict):
14281443

14291444

14301445
def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
1446+
global _kernel_run_count, _save_blobs_for_current_launch
1447+
14311448
# Check if we're in CUDA graph capture mode - if so, skip detailed argument extraction
14321449
# to avoid CUDA errors (cudaErrorStreamCaptureUnsupported)
14331450
is_capturing = False
@@ -1448,6 +1465,25 @@ def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
14481465
)
14491466
}
14501467

1468+
# Gate tensor blob saving based on skip/max runs
1469+
if TRITONPARSE_SAVE_TENSOR_BLOBS:
1470+
skip = TRITONPARSE_TENSOR_SAVE_SKIP_RUNS
1471+
max_runs = TRITONPARSE_TENSOR_SAVE_MAX_RUNS
1472+
if skip > 0 or max_runs > 0:
1473+
# Only capture the stack when we actually need kernel run counting
1474+
from .parse.sourcemap_utils import _is_autotune_benchmark_launch
1475+
1476+
if not _is_autotune_benchmark_launch(get_stack_trace()):
1477+
_kernel_run_count += 1
1478+
_save_blobs_for_current_launch = not (
1479+
_kernel_run_count <= skip
1480+
or (max_runs > 0 and _kernel_run_count > skip + max_runs)
1481+
)
1482+
else:
1483+
_save_blobs_for_current_launch = True
1484+
else:
1485+
_save_blobs_for_current_launch = False
1486+
14511487
# Extract detailed argument information (only when NOT capturing)
14521488
extracted_args = extract_arg_info(arg_dict)
14531489
extracted_inductor_args = extract_arg_info(inductor_args) if inductor_args else {}
@@ -1691,6 +1727,8 @@ def init(
16911727
enable_tensor_blob_storage: bool = False,
16921728
tensor_storage_quota: Optional[int] = None,
16931729
compression: Optional[str] = None,
1730+
tensor_save_skip_runs: Optional[int] = None,
1731+
tensor_save_max_runs: Optional[int] = None,
16941732
):
16951733
"""
16961734
This function is a wrapper around init_basic() that also sets up the compilation listener. Its arguments have higher priority than the environment variables for same settings.
@@ -1712,12 +1750,15 @@ def init(
17121750
tensor_storage_quota (Optional[int]): Storage quota in bytes for tensor blobs (default: 100GB).
17131751
compression (Optional[str]): Compression format for trace files ("none", "gzip", or "clp").
17141752
If not specified, respects TRITON_TRACE_COMPRESSION env var, or defaults to "none".
1753+
tensor_save_skip_runs (Optional[int]): Skip tensor blob saving for the first N kernel runs.
1754+
tensor_save_max_runs (Optional[int]): Save tensor blobs for at most N kernel runs after skipping.
17151755
"""
17161756
global TRITON_TRACE_LAUNCH, TRITON_TRACE_LAUNCH_WITHIN_PROFILING
17171757
global TRITONPARSE_MORE_TENSOR_INFORMATION
17181758
global TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK, TRITONPARSE_DUMP_SASS
17191759
global TRITONPARSE_SAVE_TENSOR_BLOBS, TRITONPARSE_TENSOR_STORAGE_QUOTA
17201760
global TRITON_TRACE_COMPRESSION
1761+
global TRITONPARSE_TENSOR_SAVE_SKIP_RUNS, TRITONPARSE_TENSOR_SAVE_MAX_RUNS
17211762

17221763
# Set global flags BEFORE calling init_basic, so init_logs() can see them
17231764
# TRITON_TRACE_LAUNCH and TRITON_TRACE_LAUNCH_WITHIN_PROFILING are mutually exclusive.
@@ -1750,6 +1791,12 @@ def init(
17501791
if os.getenv("TRITON_TRACE_COMPRESSION") is None:
17511792
TRITON_TRACE_COMPRESSION = compression
17521793

1794+
# Set tensor save skip/max runs (Python API overrides env var)
1795+
if tensor_save_skip_runs is not None:
1796+
TRITONPARSE_TENSOR_SAVE_SKIP_RUNS = tensor_save_skip_runs
1797+
if tensor_save_max_runs is not None:
1798+
TRITONPARSE_TENSOR_SAVE_MAX_RUNS = tensor_save_max_runs
1799+
17531800
init_basic(trace_folder)
17541801
from triton import knobs
17551802

@@ -1779,6 +1826,7 @@ def clear_logging_config():
17791826
global TRITON_TRACE_HANDLER, triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
17801827
global _trace_launch_enabled
17811828
global TENSOR_BLOB_MANAGER
1829+
global _kernel_run_count, _save_blobs_for_current_launch
17821830
# 1. Clean up the log handler
17831831
if TRITON_TRACE_HANDLER is not None:
17841832
if TRITON_TRACE_HANDLER in triton_trace_log.handlers:
@@ -1793,6 +1841,8 @@ def clear_logging_config():
17931841

17941842
# 3. Reset tensor blob manager and related flags
17951843
TENSOR_BLOB_MANAGER = None
1844+
_kernel_run_count = 0
1845+
_save_blobs_for_current_launch = True
17961846

17971847
# 4. Reset Triton knobs
17981848
# Check if triton was actually imported and used

0 commit comments

Comments (0)