Skip to content

Commit f6bff8f

Browse files
authored
add loc everywhere (#142)
1 parent: db4a843 · commit: f6bff8f

File tree

6 files changed

+86
-33
lines changed

6 files changed

+86
-33
lines changed

mlir/extras/dialects/ext/gpu.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,43 +49,43 @@ def __get__(self, owner_self, owner_cls):
4949
class block_idx:
5050
@classproperty
5151
def x(cls):
52-
return _block_id("x")
52+
return _block_id("x", loc=get_user_code_loc())
5353

5454
@classproperty
5555
def y(cls):
56-
return _block_id("y")
56+
return _block_id("y", loc=get_user_code_loc())
5757

5858
@classproperty
5959
def z(cls):
60-
return _block_id("z")
60+
return _block_id("z", loc=get_user_code_loc())
6161

6262

6363
class block_dim:
6464
@classproperty
6565
def x(cls):
66-
return _block_dim("x")
66+
return _block_dim("x", loc=get_user_code_loc())
6767

6868
@classproperty
6969
def y(cls):
70-
return _block_dim("y")
70+
return _block_dim("y", loc=get_user_code_loc())
7171

7272
@classproperty
7373
def z(cls):
74-
return _block_dim("z")
74+
return _block_dim("z", loc=get_user_code_loc())
7575

7676

7777
class thread_idx:
7878
@classproperty
7979
def x(cls):
80-
return _thread_id("x")
80+
return _thread_id("x", loc=get_user_code_loc())
8181

8282
@classproperty
8383
def y(cls):
84-
return _thread_id("y")
84+
return _thread_id("y", loc=get_user_code_loc())
8585

8686
@classproperty
8787
def z(cls):
88-
return _thread_id("z")
88+
return _thread_id("z", loc=get_user_code_loc())
8989

9090

9191
def thread_id():
@@ -222,6 +222,8 @@ def __init__(
222222
loc=None,
223223
ip=None,
224224
):
225+
if loc is None:
226+
loc = get_user_code_loc()
225227
super().__init__(
226228
function_type=function_type,
227229
arg_attrs=arg_attrs,
@@ -301,10 +303,10 @@ def launch_(
301303
):
302304
if loc is None:
303305
loc = get_user_code_loc()
304-
for size in [grid_size, block_size]:
305-
for i, s in enumerate(size):
306-
if isinstance(s, int):
307-
size[i] = constant(s, index=True)
306+
for size in [grid_size, block_size]:
307+
for i, s in enumerate(size):
308+
if isinstance(s, int):
309+
size[i] = constant(s, index=True)
308310
launch_op = LaunchOp(
309311
grid_size,
310312
block_size,
@@ -371,13 +373,16 @@ def __call__(
371373
async_dependencies=None,
372374
dynamic_shared_memory_size: Optional[Value] = None,
373375
stream=None,
376+
loc=None,
377+
ip=None,
374378
):
375379
for size in [grid_size, block_size]:
376380
for i, s in enumerate(size):
377381
if isinstance(s, int):
378382
size[i] = constant(s, index=True)
379383

380-
loc = get_user_code_loc()
384+
if loc is None:
385+
loc = get_user_code_loc()
381386
return get_op_result_or_op_results(
382387
LaunchFuncOp(
383388
(
@@ -469,6 +474,8 @@ def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):
469474

470475

471476
def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
477+
if loc is None:
478+
loc = get_user_code_loc()
472479
return get_op_result_or_op_results(
473480
all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
474481
)
@@ -577,15 +584,18 @@ def get_compile_object_bytes(compiled_module):
577584
_printf = printf
578585

579586

580-
def printf(format, *args):
581-
loc = get_user_code_loc()
582-
return _printf(format=format, args=args, loc=loc)
587+
def printf(format, *args, loc=None, ip=None):
588+
if loc is None:
589+
loc = get_user_code_loc()
590+
return _printf(format=format, args=args, loc=loc, ip=ip)
583591

584592

585593
_dynamic_shared_memory = dynamic_shared_memory
586594

587595

588596
def dynamic_shared_memory(*, int=False, loc=None, ip=None):
597+
if loc is None:
598+
loc = get_user_code_loc()
589599
return _dynamic_shared_memory(
590600
T.memref(
591601
ShapedType.get_dynamic_size(),
@@ -611,3 +621,10 @@ def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
611621
if isinstance(value, (int, float, bool)):
612622
value = constant(value, type=dst.type.element_type)
613623
return _memset(async_token, async_dependencies, dst, value, loc=loc, ip=ip)
624+
625+
626+
def barrier(*, loc=None, ip=None):
627+
if loc is None:
628+
loc = get_user_code_loc()
629+
630+
return BarrierOp(loc=loc, ip=ip)

mlir/extras/dialects/ext/memref.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,8 @@ def _canonicalize_start_stop(start, stop, step):
281281
elif isinstance(start, int) and isinstance(stop, int):
282282
return stop - start
283283

284+
raise NotImplementedError
285+
284286

285287
def _subview(
286288
mem: MemRef,
@@ -362,6 +364,8 @@ def _copy_to_subview(
362364

363365

364366
def dim(source, index, *, loc=None, ip=None):
367+
if loc is None:
368+
loc = get_user_code_loc()
365369
if isinstance(index, int):
366370
index = constant(index, index=True)
367371
return _dim(source=source, index=index, loc=loc, ip=ip)
@@ -412,7 +416,9 @@ def global_(
412416
).opview
413417

414418

415-
def view(source, shape, dtype=None, shift=0, memory_space=None):
419+
def view(source, shape, dtype=None, shift=0, memory_space=None, loc=None, ip=None):
420+
if loc is None:
421+
loc = get_user_code_loc()
416422
if dtype is None:
417423
dtype = source.type.element_type
418424
byte_width_dtype = dtype.width // 8
@@ -425,6 +431,8 @@ def view(source, shape, dtype=None, shift=0, memory_space=None):
425431
source,
426432
byte_shift,
427433
[],
434+
loc=loc,
435+
ip=ip,
428436
)
429437

430438

@@ -434,6 +442,8 @@ def view(source, shape, dtype=None, shift=0, memory_space=None):
434442
def get_global(
435443
name_or_global, *, name=None, global_=None, result=None, loc=None, ip=None
436444
):
445+
if loc is None:
446+
loc = get_user_code_loc()
437447
if isinstance(name_or_global, GlobalOp):
438448
global_ = name_or_global
439449
elif isinstance(name_or_global, str):

mlir/extras/dialects/ext/rocdl.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class WMMA_F16_16X16X16_F16(ir.OpView):
2424
_ODS_REGIONS = (0, True)
2525

2626
def __init__(self, res, args, *, loc=None, ip=None):
27+
if loc is None:
28+
loc = get_user_code_loc()
2729
operands = []
2830
results = []
2931
attributes = {}
@@ -56,5 +58,11 @@ def res(self):
5658
return self.operation.results[0]
5759

5860

59-
def wmma_f16_16x16x16_f16(res, args, *, loc=None, ip=None) -> ir.Value:
60-
return WMMA_F16_16X16X16_F16(res=res, args=args, loc=loc, ip=ip).result
61+
def wmma_f16_16x16x16_f16(A, B, C, *, OPSEL=False, loc=None, ip=None) -> ir.Value:
62+
if loc is None:
63+
loc = get_user_code_loc()
64+
65+
opsel = arith.constant(OPSEL, ir.IntegerType.get_signless(1))
66+
args = [A, B, C, opsel]
67+
v16 = ir.VectorType.get((16,), ir.F16Type.get())
68+
return WMMA_F16_16X16X16_F16(res=v16, args=args, loc=loc, ip=ip).result

mlir/extras/dialects/ext/vector.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ def extract_strided_slice(vector, offsets, sizes, strides, *, loc=None, ip=None)
251251

252252

253253
def outerproduct(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
254+
if loc is None:
255+
loc = get_user_code_loc()
254256
if kind is None:
255257
kind = CombiningKind.ADD
256258
result_shape = [lhs.shape[0], rhs.shape[0]]
@@ -262,6 +264,8 @@ def outerproduct(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
262264

263265
@Infix
264266
def outer(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
267+
if loc is None:
268+
loc = get_user_code_loc()
265269
return outerproduct(lhs, rhs, acc, kind=kind, loc=loc, ip=ip)
266270

267271

@@ -270,14 +274,20 @@ def outer(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
270274

271275
@Infix
272276
def shuffle(v1, v2, mask, *, loc=None, ip=None):
277+
if loc is None:
278+
loc = get_user_code_loc()
273279
return ShuffleOp(v1=v1, v2=v2, mask=mask, loc=loc, ip=ip).result
274280

275281

276282
_load = load
277283

278284

279-
@Infix
280-
def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
285+
def load_(base, indices, result, *, nontemporal=None, loc=None, ip=None):
286+
if loc is None:
287+
loc = get_user_code_loc()
288+
for j, i in enumerate(indices):
289+
if isinstance(i, int):
290+
indices[j] = constant(i, index=True)
281291
return LoadOp(
282292
result=result,
283293
base=base,
@@ -286,3 +296,6 @@ def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
286296
loc=loc,
287297
ip=ip,
288298
).result
299+
300+
301+
load = Infix(load_)

mlir/extras/runtime/passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def run_pipeline(
3131
print_pipeline=False,
3232
verify=True,
3333
):
34-
module = Module.parse(str(module))
34+
module = Module.parse(module.operation.get_asm(enable_debug_info=True))
3535

3636
if isinstance(pipeline, Pipeline):
3737
pipeline = str(pipeline)

tests/test_gpu.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from mlir.dialects.memref import cast
1616

1717
from mlir.extras.ast.canonicalize import canonicalize
18-
from mlir.extras.dialects.ext import arith, scf, memref, rocdl
18+
from mlir.extras.dialects.ext import arith, scf, memref, rocdl, gpu
1919
from mlir.extras.dialects.ext.func import func
2020

2121
# noinspection PyUnresolvedReferences
@@ -758,7 +758,7 @@ def mat_product_kernel(
758758

759759
props = hip.hipDeviceProp_t()
760760
hip_check(hip.hipGetDeviceProperties(props, 0))
761-
arch = props.gcnArchName.decode()
761+
arch = props.gcnArchName.decode().split(":")[0]
762762

763763
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
764764
def gpu_module():
@@ -869,7 +869,7 @@ def mat_product_kernel(
869869

870870
props = hip.hipDeviceProp_t()
871871
hip_check(hip.hipGetDeviceProperties(props, 0))
872-
arch = props.gcnArchName.decode()
872+
arch = props.gcnArchName.decode().split(":")[0]
873873

874874
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
875875
def gpu_module():
@@ -996,7 +996,7 @@ def smol_matmul(
996996

997997
props = hip.hipDeviceProp_t()
998998
hip_check(hip.hipGetDeviceProperties(props, 0))
999-
arch = props.gcnArchName.decode()
999+
arch = props.gcnArchName.decode().split(":")[0]
10001000

10011001
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
10021002
def gpu_module():
@@ -1104,7 +1104,7 @@ def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
11041104

11051105
props = hip.hipDeviceProp_t()
11061106
hip_check(hip.hipGetDeviceProperties(props, 0))
1107-
arch = props.gcnArchName.decode()
1107+
arch = props.gcnArchName.decode().split(":")[0]
11081108

11091109
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
11101110
def gpu_module():
@@ -1228,9 +1228,10 @@ def smol_matmul(
12281228
a_frag[ele] = a[lane, ele]
12291229
a_frag, b_frag = yield a_frag, b_frag
12301230

1231-
# call the WMMA intrinsic
1232-
false = arith.constant(False, T.bool())
1233-
c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
1231+
c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag)
1232+
1233+
for i in scf.range_(v_len):
1234+
gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])
12341235

12351236
for ele in scf.range_(v_len // 2):
12361237
r = ele * 2 + (lIdx // v_len)
@@ -1239,7 +1240,7 @@ def smol_matmul(
12391240

12401241
props = hip.hipDeviceProp_t()
12411242
hip_check(hip.hipGetDeviceProperties(props, 0))
1242-
arch = props.gcnArchName.decode()
1243+
arch = props.gcnArchName.decode().split(":")[0]
12431244

12441245
@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
12451246
def gpu_module():
@@ -1250,7 +1251,11 @@ def gpu_module():
12501251
lowered_module = run_pipeline(
12511252
gpu_module,
12521253
Pipeline()
1253-
.Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
1254+
.Gpu(
1255+
Pipeline().convert_gpu_to_rocdl(
1256+
use_bare_ptr_memref_call_conv=True, runtime="HIP"
1257+
)
1258+
)
12541259
.rocdl_attach_target(chip=arch, abi="500")
12551260
.gpu_to_llvm()
12561261
.lower_to_llvm()

0 commit comments

Comments (0)