Skip to content

Commit 050e7d2

Browse files
authored
[water] Add debug locations (#440)
Signed-off-by: Tim Gymnich <[email protected]>
1 parent 2fed198 commit 050e7d2

File tree

5 files changed

+459
-129
lines changed

5 files changed

+459
-129
lines changed

lit_tests/kernel/wave/mlir_converter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect
1818
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
1919
from wave_lang.kernel.wave.utils.general_utils import run_test
20+
from wave_lang.support.location_config import (
21+
LocationCaptureConfig,
22+
LocationCaptureLevel,
23+
)
2024

2125
M = tkl.sym.M
2226
N = tkl.sym.N
@@ -76,6 +80,8 @@ def mlir_converter_matrix_add():
7680
options = WaveCompileOptions(
7781
subs=subs,
7882
compile_to_mlir=True, # Avoid IREE compilation
83+
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
84+
enforce_locations=False,
7985
)
8086
options = set_default_run_config(options)
8187

@@ -210,6 +216,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
210216
options = WaveCompileOptions(
211217
subs=subs,
212218
compile_to_mlir=True, # Avoid IREE compilation
219+
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
220+
enforce_locations=False,
213221
)
214222
options = set_default_run_config(options)
215223

@@ -228,7 +236,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
228236

229237
# CHECK-LABEL: mlir_converter_matmul
230238
# CHECK: module
231-
# CHECK-NEXT: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@M, @K] of f16, <global>>, %[[ARG1:.*]]: !wave.tensor<[@N, @K] of f16, <global>>, %[[ARG2:.*]]: !wave.tensor<[@M, @N] of f32, <global>>
239+
# CHECK-NEXT: func.func @kernel(
240+
# CHECK-SAME: %[[ARG0:.*]]: !wave.tensor<[@M, @K] of f16, <global>>
241+
# CHECK-SAME: %[[ARG1:.*]]: !wave.tensor<[@N, @K] of f16, <global>>
242+
# CHECK-SAME: %[[ARG2:.*]]: !wave.tensor<[@M, @N] of f32, <global>>
232243
# CHECK-SAME: wave.constraints =
233244
# CHECK-SAME: #wave.workgroup_constraint<dim = <"M">, tile_size = <[BLOCK_M] -> (BLOCK_M)>, workgroup_dim = <x>>
234245
# CHECK-SAME: #wave.workgroup_constraint<dim = <"N">, tile_size = <[BLOCK_N] -> (BLOCK_N)>, workgroup_dim = <y>>
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# REQUIRES: water
2+
# RUN: python %s | FileCheck %s
3+
4+
5+
import sympy
6+
from typing import Any
7+
8+
9+
from wave_lang.kernel._support.indexing import IndexSymbol
10+
import wave_lang.kernel.wave as wave
11+
import wave_lang.kernel.lang as tkl
12+
import wave_lang.kernel.wave as tkw
13+
from wave_lang.kernel.lang.global_symbols import *
14+
from wave_lang.kernel.lang.wave_types import *
15+
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
16+
from wave_lang.kernel.wave.constraints import MMAType
17+
from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect
18+
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
19+
from wave_lang.kernel.wave.utils.general_utils import run_test
20+
from wave_lang.support.location_config import (
21+
LocationCaptureConfig,
22+
LocationCaptureLevel,
23+
)
24+
25+
M = tkl.sym.M
26+
N = tkl.sym.N
27+
BLOCK_M = tkl.sym.BLOCK_M
28+
BLOCK_N = tkl.sym.BLOCK_N
29+
ADDRESS_SPACE_A = tkl.sym.ADDRESS_SPACE_A
30+
ADDRESS_SPACE_B = tkl.sym.ADDRESS_SPACE_B
31+
ADDRESS_SPACE_C = tkl.sym.ADDRESS_SPACE_C
32+
33+
# Define constraints for the kernel
34+
constraints = [
35+
# specifies how computation is tiled
36+
tkw.WorkgroupConstraint(M, BLOCK_M, 0),
37+
tkw.WorkgroupConstraint(N, BLOCK_N, 1),
38+
tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2)),
39+
tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2)),
40+
tkw.HardwareConstraint(threads_per_wave=64, vector_shapes={M: BLOCK_M, N: BLOCK_N}),
41+
]
42+
43+
44+
@wave.wave(constraints)
45+
def matrix_add(
46+
# defines matrix in memory of req dimension with specific data types
47+
a: Memory[M, N, ADDRESS_SPACE_A, tkl.f16],
48+
b: Memory[M, N, ADDRESS_SPACE_B, tkl.f16],
49+
c: Memory[M, N, ADDRESS_SPACE_C, tkl.f16],
50+
):
51+
# Initialize the accumulator register with zeroes
52+
c_reg = Register[M, N, tkl.f16](0.0)
53+
54+
# loads values from memory into registers
55+
a_reg = wave.read(a)
56+
b_reg = wave.read(b)
57+
58+
# compute the sum
59+
c_reg = a_reg + b_reg
60+
61+
# writing results back to memory
62+
wave.write(c_reg, c)
63+
64+
65+
@run_test
66+
def mlir_converter_location():
67+
"""Test MLIR converter debug location emission with a matrix addition kernel."""
68+
# Set parameters for compilation
69+
subs: dict[str | IndexSymbol, Any] = {
70+
ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE,
71+
ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE,
72+
ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE,
73+
BLOCK_M: 64,
74+
BLOCK_N: 64,
75+
M: 128,
76+
N: 128,
77+
}
78+
79+
# Set up compile options with file/line/column location capture enabled
80+
options = WaveCompileOptions(
81+
subs=subs,
82+
compile_to_mlir=True, # Avoid IREE compilation
83+
location_capture_config=LocationCaptureConfig(
84+
level=LocationCaptureLevel.FILE_LINE_COL
85+
),
86+
)
87+
options = set_default_run_config(options)
88+
89+
# Compile the kernel to get the trace
90+
compiled_kernel = wave_compile(options, matrix_add)
91+
92+
# Get the compiled graph from the compiled kernel
93+
trace = compiled_kernel.get_compiled_graph()
94+
95+
constraints = matrix_add.constraints
96+
97+
# Use the mlir_converter to emit wave MLIR dialect
98+
mlir_output, _ = emit_wave_dialect(trace, constraints, options, False)
99+
100+
# Print to stdout for FileCheck
101+
print(mlir_output)
102+
103+
# CHECK-LABEL: mlir_converter_location
104+
# CHECK: #loc = loc("{{.*}}mlir_converter_debug_locations.py":44
105+
# CHECK: module
106+
# CHECK: func.func @kernel(%arg0: !wave.tensor<[@M, @N] of f16> loc("{{.*}}mlir_converter_debug_locations.py":44{{.*}}), %arg1: !wave.tensor<[@M, @N] of f16> loc("{{.*}}mlir_converter_debug_locations.py":44{{.*}}), %arg2: !wave.tensor<[@M, @N] of f16> loc("{{.*}}mlir_converter_debug_locations.py":44
107+
108+
# CHECK: wave.read
109+
# CHECK-SAME: loc(#loc1)
110+
111+
# CHECK: wave.read
112+
# CHECK-SAME: loc(#loc2)
113+
114+
# CHECK: wave.add
115+
# CHECK-SAME: loc(#loc3)
116+
117+
# CHECK: wave.write
118+
# CHECK-SAME: loc(#loc4)
119+
120+
# CHECK: return loc(#loc)
121+
122+
# CHECK: loc(#loc)
123+
# CHECK: loc(#loc)
124+
125+
# CHECK: #loc1 = loc("{{.*}}mlir_converter_debug_locations.py":55
126+
# CHECK: #loc2 = loc("{{.*}}mlir_converter_debug_locations.py":56
127+
# CHECK: #loc3 = loc("{{.*}}mlir_converter_debug_locations.py":59
128+
# CHECK: #loc4 = loc("{{.*}}mlir_converter_debug_locations.py":62
129+
130+
131+
@run_test
132+
def mlir_converter_location_iterate():
133+
"""Test MLIR converter debug location emission with iterate."""
134+
135+
# Input sizes
136+
M = tkl.sym.M
137+
N = tkl.sym.N
138+
K = tkl.sym.K
139+
# Workgroup tile sizes
140+
BLOCK_M = tkl.sym.BLOCK_M
141+
BLOCK_N = tkl.sym.BLOCK_N
142+
BLOCK_K = tkl.sym.BLOCK_K
143+
# Address space (for GPU, shared(1) or global(0))
144+
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
145+
dtype = tkl.f16
146+
# Expose user-constraints
147+
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
148+
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
149+
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
150+
constraints += [tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2))]
151+
constraints += [tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2))]
152+
153+
constraints += [
154+
tkw.HardwareConstraint(threads_per_wave=64, mma_type=MMAType.F32_32x32x8_F16)
155+
]
156+
157+
@tkw.wave(constraints)
158+
def matmul(
159+
a: tkl.Memory[M, K, ADDRESS_SPACE, dtype],
160+
b: tkl.Memory[N, K, ADDRESS_SPACE, dtype],
161+
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
162+
):
163+
c_reg = tkl.Register[M, N, tkl.f32](0.0)
164+
165+
# This microkernel encodes the fact that if the iterate
166+
# dimension were tiled, then we would need to materialize a loop.
167+
@tkw.iterate(K, init_args=[c_reg])
168+
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
169+
# a_reg: tkw.Register[M, K, dtype]
170+
a_reg = tkw.read(a)
171+
# b_reg: tkw.Register[N, K, dtype]
172+
b_reg = tkw.read(b)
173+
# acc: tkw.Register[M, N, tkl.f32]
174+
acc = tkw.mma(a_reg, b_reg, acc)
175+
return acc
176+
177+
# repeat represents the results of the loop
178+
tkw.write(repeat, c)
179+
180+
subs = {
181+
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
182+
BLOCK_M: 64,
183+
BLOCK_N: 64,
184+
BLOCK_K: 8,
185+
M: 1024,
186+
N: 5120,
187+
K: 640,
188+
}
189+
190+
options = WaveCompileOptions(
191+
subs=subs,
192+
compile_to_mlir=True, # Avoid IREE compilation
193+
location_capture_config=LocationCaptureConfig(
194+
level=LocationCaptureLevel.FILE_LINE_COL
195+
),
196+
)
197+
options = set_default_run_config(options)
198+
199+
compiled_kernel = wave_compile(options, matmul)
200+
201+
# Get the trace from the compiled kernel
202+
trace = compiled_kernel.compiled_graph
203+
204+
constraints = matmul.constraints
205+
206+
# Use the mlir_converter to emit wave MLIR dialect
207+
mlir_output, _ = emit_wave_dialect(trace, constraints, options, False)
208+
209+
# Print to stdout for FileCheck
210+
print(mlir_output)
211+
212+
# CHECK-LABEL: mlir_converter_location_iterate
213+
# CHECK: #loc = loc("{{.*}}mlir_converter_debug_locations.py":157
214+
# CHECK: #loc5 = loc("{{.*}}mlir_converter_debug_locations.py":174
215+
# CHECK: module
216+
# CHECK: func.func @kernel
217+
218+
# CHECK: wave.iterate
219+
# CHECK: %arg3: !wave.tensor<[@M, @N] of f32, <register>> loc("{{.*}}mlir_converter_debug_locations.py":174
220+
221+
# CHECK: wave.read
222+
# CHECK-SAME: loc(#loc3)
223+
224+
# CHECK: amdgpu.lds_barrier
225+
# CHECK-SAME: loc(#loc3)
226+
227+
# CHECK: wave.write
228+
# CHECK-SAME: loc(#loc3)
229+
230+
# CHECK: wave.read
231+
# CHECK-SAME: loc(#loc1)
232+
233+
# CHECK: wave.write
234+
# CHECK-SAME: loc(#loc1)
235+
236+
# CHECK: amdgpu.lds_barrier
237+
# CHECK-SAME: loc(#loc1)
238+
239+
# CHECK: wave.read
240+
# CHECK-SAME: loc(#loc1)
241+
242+
# CHECK: wave.read
243+
# CHECK-SAME: loc(#loc3)
244+
245+
# CHECK: wave.mma
246+
# CHECK-SAME: loc(#loc5)
247+
248+
# CHECK: wave.yield
249+
# CHECK-SAME: loc(#loc4)
250+
251+
# CHECK: (!wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>> loc(#loc4)
252+
253+
# CHECK: #loc1 = loc("{{.*}}mlir_converter_debug_locations.py":172
254+
# CHECK: #loc2 = loc("{{.*}}mlir_converter_debug_locations.py":163
255+
# CHECK: #loc3 = loc("{{.*}}mlir_converter_debug_locations.py":170
256+
# CHECK: #loc4 = loc("{{.*}}mlir_converter_debug_locations.py":167
257+
# CHECK: #loc6 = loc("{{.*}}mlir_converter_debug_locations.py":178

lit_tests/kernel/wave/mlir_converter_diagnostics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,4 @@ def mlir_converter_diagnostics_emission():
9191
print(diagnostics[0])
9292

9393
# CHECK-LABEL: mlir_converter_diagnostics_emission
94-
# CHECK: loc(unknown): test error
94+
# CHECK: loc("{{.*}}mlir_converter_diagnostics.py":37{{.*}}): test error

wave_lang/kernel/_support/location.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import inspect
99
import sys
1010
from dataclasses import dataclass
11-
from typing import List, Optional, Union, Callable
11+
from typing import List, Optional, TypeVar, Union, Callable
1212

1313
from ...support.location_config import LocationCaptureConfig, LocationCaptureLevel
1414

@@ -23,10 +23,12 @@ class FileLineColInfo:
2323
line: Union[int, tuple[int, int]]
2424
col: Union[int, tuple[int, int]]
2525

26-
def to_mlir(self):
27-
# Lazy import to avoid IREE dependency on module import
28-
from iree.compiler.ir import Context, Location
26+
ContextImpl = TypeVar("ContextImpl")
27+
LocationImpl = TypeVar("LocationImpl")
2928

29+
def _to_mlir_impl(
30+
self, Context: ContextImpl, Location: LocationImpl
31+
) -> LocationImpl:
3032
assert Context.current is not None, "Must be called under MLIR context manager."
3133

3234
line_is_range = isinstance(self.line, tuple)
@@ -39,6 +41,18 @@ def to_mlir(self):
3941
col_end = self.col[1] if col_is_range else self.col
4042
return Location.file(self.filename, line_start, col_start, line_end, col_end)
4143

44+
def to_mlir(self) -> Location:
45+
# Lazy import to avoid IREE dependency on module import
46+
from iree.compiler.ir import Context, Location
47+
48+
return self._to_mlir_impl(Context, Location)
49+
50+
def to_water(self) -> Location:
51+
# Lazy import to avoid water_mlir dependency on module import
52+
from water_mlir.water_mlir.ir import Context, Location
53+
54+
return self._to_mlir_impl(Context, Location)
55+
4256
@staticmethod
4357
def capture_current_location():
4458
# Need to find a part of the call stack that doesn't belong to us.
@@ -72,10 +86,12 @@ class StackTraceInfo:
7286

7387
frames: List[FileLineColInfo]
7488

75-
def to_mlir(self) -> Location:
76-
# Lazy import to avoid IREE dependency on module import
77-
from iree.compiler.ir import Context, Location
89+
ContextImpl = TypeVar("ContextImpl")
90+
LocationImpl = TypeVar("LocationImpl")
7891

92+
def _to_mlir_impl(
93+
self, Context: ContextImpl, Location: LocationImpl
94+
) -> LocationImpl:
7995
assert Context.current is not None, "Must be called under MLIR context manager."
8096
if not self.frames:
8197
return Location.unknown()
@@ -85,6 +101,18 @@ def to_mlir(self) -> Location:
85101
self.frames[0].to_mlir(), [f.to_mlir() for f in self.frames[1:]]
86102
)
87103

104+
def to_mlir(self) -> Location:
105+
# Lazy import to avoid IREE dependency on module import
106+
from iree.compiler.ir import Context, Location
107+
108+
return self._to_mlir_impl(Context, Location)
109+
110+
def to_water(self) -> Location:
111+
# Lazy import to avoid water_mlir dependency on module import
112+
from water_mlir.water_mlir.ir import Context, Location
113+
114+
return self._to_mlir_impl(Context, Location)
115+
88116
@staticmethod
89117
def capture_current_location(*, preserve_system_frames=False) -> "StackTraceInfo":
90118
# TODO: we may want to cache location info so we don't keep copying the

0 commit comments

Comments
 (0)