Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/test/Examples/NVGPU/Ch5.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def producer_loop(
):
phase = const(True, ty=T.bool())

for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
for iv, phase, _ in scf.for_(0, (K // TILE_K), 1, [phase]):
stage = iv % num_stages
# Wait MMA to be done
mbar_mma[stage].try_wait(phase)
Expand Down
7 changes: 3 additions & 4 deletions mlir/test/Examples/NVGPU/tools/nvdsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def arrive(self, txcount: int = 0, predicate=None):
self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
)
else:
nvgpu.mbarrier_arrive(
ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op
)

def try_wait(self, phase: bool = False, ticks: int = 10000000):
Expand Down Expand Up @@ -144,7 +143,7 @@ def create_descriptor(self, device_ptr):
device_ptr,
)
self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
tma_descriptor_ty, device_unranked_memref, list(map(const, self.tma_box_shape))
)
return self.tma_descriptor.result

Expand All @@ -156,7 +155,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
dest,
mbarrier.mbar_group_op,
self.tma_descriptor,
coordinates=map(const, coords),
coordinates=list(map(const, coords)),
mbarId=mbarrier.id_op,
predicate=predicate,
)
Expand Down
4 changes: 3 additions & 1 deletion mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def compile(self, module: ir.Module):

def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
ee = execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs
)
ee.initialize()
return ee

def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
Expand Down
Loading