
Commit 1d74879

[ConSan] Make sure the kernel is recompiled when ConSan state changes (#8342)

Include the ConSan state in the kernel's compilation options and kwargs to force a JIT cache miss when it changes. Also prepare the knobs to unify this behavior with Proton by introducing an `instrumentation_mode` compilation knob that the Proton runtime can set to "consan" or "proton".
Parent: 6e4647e
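For context, a minimal sketch of how the new knob is driven from user or test code, based on the test changes below (a CUDA build of Triton is assumed; both paths end up in `knobs.compilation.instrumentation_mode`, which this commit threads into the compilation cache key):

import os
from triton import knobs

# Option 1: set the knob directly; the JIT picks it up on the next launch.
knobs.compilation.instrumentation_mode = "consan"

# Option 2: set the environment variable. The knob caches its value at
# import time, so an explicit refresh is needed for the change to be seen.
os.environ["TRITON_INSTRUMENTATION_MODE"] = "consan"
knobs.refresh_knobs()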

4 files changed: 75 additions, 38 deletions

python/test/gluon/test_consan.py (65 additions, 36 deletions)

@@ -86,40 +86,43 @@ def failing_kernel(input):
     ampere.async_copy.wait_group(0)


-def run_failing_kernel(device):
+def run_failing_kernel(device, enable_consan, mode):
     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):
         return torch.empty(size, device="cuda", dtype=torch.int8)

     triton.set_allocator(alloc_fn)

+    if enable_consan:
+        if mode == "env":
+            os.environ["TRITON_INSTRUMENTATION_MODE"] = "consan"
+            knobs.refresh_knobs()
+        elif mode == "knob":
+            knobs.compilation.instrumentation_mode = "consan"
+
     input = torch.randn((XBLOCK, XBLOCK), device=device, dtype=torch.float16)
     failing_kernel[(1, )](input)


 @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper")
-def test_cache_miss_knob(device, fresh_knobs, monkeypatch):
+def test_cache_miss_knob(device, monkeypatch):
     # First run without consan
-    knobs.compilation.enable_experimental_consan = False
-    run_failing_kernel(device)
+    run_in_process(run_failing_kernel, (device, False, "knob"))

     # Then run with consan and assert that it fails
-    knobs.compilation.enable_experimental_consan = True
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
-    result = run_in_process(run_failing_kernel, (device, ))
+    result = run_in_process(run_failing_kernel, (device, True, "knob"))
     assert "device-side assert" in str(result.exc)


 @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper")
 def test_cache_miss_env(device, monkeypatch):
     # First run without consan
-    knobs.compilation.enable_experimental_consan = False
-    run_failing_kernel(device)
+    run_in_process(run_failing_kernel, (device, False, "env"))

     # Then run with consan and assert that it fails
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
-    result = run_in_process(run_failing_kernel, (device, ))
+    result = run_in_process(run_failing_kernel, (device, True, "env"))
     assert "device-side assert" in str(result.exc)

@@ -133,8 +136,9 @@ def test_async_tma_kernel(FAILURE, device, run_wrapper, monkeypatch):
         assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -182,8 +186,9 @@ def test_tma_interleave_kernel(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -240,8 +245,9 @@ def test_async_copy(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -292,8 +298,9 @@ def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -355,8 +362,9 @@ def test_warpgroup_mma(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -399,8 +407,9 @@ def test_warpgroup_mma2(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -447,8 +456,9 @@ def test_tcgen5_mma_multibar(BUF_IDX, BAR_IDX, device, run_wrapper, monkeypatch)
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -505,8 +515,9 @@ def test_multibuffered_loop(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -619,8 +630,9 @@ def test_multibuffered_wgmma_loop(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -699,8 +711,9 @@ def test_ws_store_wait_load(FAILURE, device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -752,8 +765,9 @@ def test_ws_load_wait_store(FAILURE, device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -805,8 +819,9 @@ def test_ws_two_loads_two_bars(MISSING_BAR, device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -867,8 +882,9 @@ def test_ws_two_loads_one_bar(FAILURE, device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -929,8 +945,9 @@ def test_ws_two_loads_two_bars_loop(MISSING_BAR, device, run_wrapper, monkeypatc
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1009,8 +1026,9 @@ def test_ws_load_ordering(FAILURE, device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1073,8 +1091,9 @@ def test_ws_two_producers_two_consumers(MISSING_BAR, device, run_wrapper, monkey
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1160,8 +1179,9 @@ def test_ws_different_warp_sizes(MISSING_BAR, device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1229,8 +1249,9 @@ def test_ws_async_copy_commits(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     def alloc_fn(size: int, alignment: int, stream: Optional[int]):
         return torch.empty(size, device="cuda", dtype=torch.int8)

@@ -1292,8 +1313,9 @@ def test_ws_async_copy_wait_visibility(FAILURE, device, run_wrapper, monkeypatch
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     def alloc_fn(size: int, alignment: int, stream: Optional[int]):
         return torch.empty(size, device="cuda", dtype=torch.int8)

@@ -1344,8 +1366,9 @@ def test_ws_wgmma_wait_visibility(FAILURE, device, run_wrapper, monkeypatch):
         assert result.driver_stderr_output == ""
         return

-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     def alloc_fn(size: int, alignment: int, stream: Optional[int]):
         return torch.empty(size, device="cuda", dtype=torch.int8)

@@ -1392,8 +1415,9 @@ def test_deadlock_two_partitions(device, run_wrapper, monkeypatch):
         assert "device-side assert" in str(result.exc)
         assert "Deadlock detected" in result.driver_stderr_output
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1426,8 +1450,9 @@ def test_deadlock_overarrival(device, run_wrapper, monkeypatch):
         assert "device-side assert" in str(result.exc)
         assert "Deadlock detected" in result.driver_stderr_output
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1455,8 +1480,9 @@ def test_deadlock_underarrival(device, run_wrapper, monkeypatch):
         assert "device-side assert" in str(result.exc)
         assert "Deadlock detected" in result.driver_stderr_output
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1491,8 +1517,9 @@ def test_deadlock_different_phases(device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1526,8 +1553,9 @@ def test_deadlock_exempt_when_tma_signals(device, run_wrapper, monkeypatch):
         assert result.exc is None
         assert result.driver_stderr_output == ""
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):

@@ -1569,8 +1597,9 @@ def test_barrier_underflow(device, run_wrapper, monkeypatch):
         assert "device-side assert" in str(result.exc)
         assert "Barrier arrive underflow: current count would become negative" in result.driver_stderr_output
         return
-    monkeypatch.setenv("TRITON_ENABLE_EXPERIMENTAL_CONSAN", "1")
+    monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan")
     monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1")
+    knobs.refresh_knobs()

     # ConSan requires a global memory allocation
     def alloc_fn(size: int, alignment: int, stream: Optional[int]):
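The cache-miss tests above now launch each kernel in a fresh subprocess so that knob and environment state cannot leak between the clean run and the instrumented run. `run_in_process` is defined elsewhere in the test suite; the sketch below is a hypothetical reconstruction consistent with how it is called here (a callable plus an args tuple, returning an object that exposes `.exc` and `.driver_stderr_output`):

import multiprocessing
from dataclasses import dataclass
from typing import Optional


@dataclass
class ProcResult:
    exc: Optional[BaseException]  # exception raised in the child, if any
    driver_stderr_output: str     # stderr captured from the child process


def _child(queue, fn, args):
    # Runs inside the subprocess: execute the test body, report any error.
    try:
        fn(*args)
        queue.put(None)
    except BaseException as e:
        queue.put(e)


def run_in_process(fn, args) -> ProcResult:
    # "spawn" gives the child a fresh interpreter and CUDA context, so env
    # vars and knob values set inside it cannot leak back to the parent.
    ctx = multiprocessing.get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(target=_child, args=(queue, fn, args))
    proc.start()
    proc.join()
    exc = queue.get() if not queue.empty() else None
    # The real helper also captures the CUDA driver's stderr; omitted here.
    return ProcResult(exc=exc, driver_stderr_output="")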

python/triton/knobs.py (4 additions, 1 deletion)

@@ -364,7 +364,9 @@ class compilation_knobs(base_knobs):
     disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
     front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
     allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
-    enable_experimental_consan: env_bool = env_bool("TRITON_ENABLE_EXPERIMENTAL_CONSAN")
+    # Instrumentation mode is checked on every run, which is expensive.
+    # We cache the value here to avoid the expensive check on every run.
+    instrumentation_mode: str = env_str("TRITON_INSTRUMENTATION_MODE", "").get()
     listener: Union[CompilationListener, None] = None


@@ -533,3 +535,4 @@ class proton_knobs(base_knobs):

 def refresh_knobs():
     runtime.debug = env_bool("TRITON_DEBUG").get()
+    compilation.instrumentation_mode = env_str("TRITON_INSTRUMENTATION_MODE", "").get()
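Note the trade-off this introduces: unlike the lazy `env_bool` knobs above it, `instrumentation_mode` is snapshotted with `.get()` when the module is loaded, so changing the environment variable afterwards has no effect until `refresh_knobs()` re-reads it. That is exactly why the tests above call `knobs.refresh_knobs()` after `monkeypatch.setenv(...)`. A small illustration (assumes the variable was unset when Triton was imported):

import os
from triton import knobs

os.environ["TRITON_INSTRUMENTATION_MODE"] = "consan"
print(knobs.compilation.instrumentation_mode)  # still "" - stale snapshot

knobs.refresh_knobs()  # re-reads TRITON_INSTRUMENTATION_MODE
print(knobs.compilation.instrumentation_mode)  # now "consan"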

python/triton/runtime/jit.py (1 addition, 0 deletions)

@@ -650,6 +650,7 @@ def _pack_args(self, backend, kwargs, bound_args, specialization, options):

     def run(self, *args, grid, warmup, **kwargs):
         kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug
+        kwargs["instrumentation_mode"] = knobs.compilation.instrumentation_mode

         # parse options
         device = driver.active.get_current_device()
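This one-line change is what actually forces the recompile: the mode is injected into the kwargs that feed option parsing, so it becomes part of the key under which compiled kernels are cached, and flipping the mode can no longer hit a stale binary. A toy model of the mechanism (not Triton's actual cache code):

# Toy options-keyed compile cache showing why an extra kwarg forces a miss.
_cache = {}


def compile_kernel(src, **opts):
    key = (src, tuple(sorted(opts.items())))  # options are part of the key
    if key not in _cache:
        _cache[key] = object()  # stand-in for a freshly compiled binary
    return _cache[key]


plain = compile_kernel("kernel", instrumentation_mode="")
consan = compile_kernel("kernel", instrumentation_mode="consan")
assert plain is not consan  # different mode -> different key -> recompile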

third_party/nvidia/backend/compiler.py (5 additions, 1 deletion)

@@ -170,6 +170,10 @@ def __init__(self, target: GPUTarget) -> None:
         self.binary_ext = "cubin"

     def parse_options(self, opts) -> Any:
+        # Enable debug mode for ConSan, so device-side assertions are not optimized out
+        if "instrumentation_mode" in opts and opts["instrumentation_mode"] == "consan":
+            opts["debug"] = True
+
         args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
         args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
         capability = int(self._parse_arch(args["arch"]))

@@ -353,7 +357,7 @@ def make_llir(self, src, metadata, options, capability):
         passes.gluon.add_inliner(pm)
         nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
         nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
-        if knobs.compilation.enable_experimental_consan:
+        if knobs.compilation.instrumentation_mode == "consan":
             # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
             passes.ttgpuir.add_concurrency_sanitizer(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
