Skip to content

Commit d0ae04d

Browse files
committed
fix nvdsl
1 parent 031fb74 commit d0ae04d

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

mlir/test/Examples/NVGPU/Ch5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def producer_loop(
156156
):
157157
phase = const(True, ty=T.bool())
158158

159-
for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
159+
for iv, phase, _ in scf.for_(0, (K // TILE_K), 1, [phase]):
160160
stage = iv % num_stages
161161
# Wait MMA to be done
162162
mbar_mma[stage].try_wait(phase)

mlir/test/Examples/NVGPU/tools/nvdsl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ def arrive(self, txcount: int = 0, predicate=None):
8484
self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
8585
)
8686
else:
87-
nvgpu.mbarrier_arrive(
88-
ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
87+
nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op
8988
)
9089

9190
def try_wait(self, phase: bool = False, ticks: int = 10000000):
@@ -144,7 +143,7 @@ def create_descriptor(self, device_ptr):
144143
device_ptr,
145144
)
146145
self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
147-
tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
146+
tma_descriptor_ty, device_unranked_memref, list(map(const, self.tma_box_shape))
148147
)
149148
return self.tma_descriptor.result
150149

@@ -156,7 +155,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
156155
dest,
157156
mbarrier.mbar_group_op,
158157
self.tma_descriptor,
159-
coordinates=map(const, coords),
158+
coordinates=list(map(const, coords)),
160159
mbarId=mbarrier.id_op,
161160
predicate=predicate,
162161
)

mlir/test/Examples/NVGPU/tools/nvgpucompiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ def compile(self, module: ir.Module):
3535

3636
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
3737
"""Wraps the module in a JIT execution engine."""
38-
return execution_engine.ExecutionEngine(
38+
ee = execution_engine.ExecutionEngine(
3939
module, opt_level=self.opt_level, shared_libs=self.shared_libs
4040
)
41+
ee.initialize()
42+
return ee
4143

4244
def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
4345
"""Compiles and jits the module."""

0 commit comments

Comments
 (0)