Skip to content

Commit c7aa49d

Browse files
authored
[python][gpu] Add mem_fence function (#199)
1 parent 1f35e17 commit c7aa49d

File tree

7 files changed

+112
-11
lines changed

7 files changed

+112
-11
lines changed

dpcomp_gpu_runtime/lib/kernel_api_stubs.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ extern "C" DPCOMP_GPU_RUNTIME_EXPORT void _mlir_ciface_kernel_barrier(int64_t) {
5151
STUB();
5252
}
5353

54+
// CPU-side stub for the kernel mem_fence API: the body only expands the
// STUB() macro (defined earlier in this file); the flags argument is unused.
extern "C" DPCOMP_GPU_RUNTIME_EXPORT void
_mlir_ciface_kernel_mem_fence(int64_t /*flags*/) {
  STUB();
}
58+
5459
#define ATOMIC_FUNC_DECL(op, suff, dt) \
5560
extern "C" DPCOMP_GPU_RUNTIME_EXPORT dt _mlir_ciface_atomic_##op##_##suff( \
5661
void *, dt) { \

mlir/include/mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,13 @@ def GPUBarrierOp : GpuRuntime_Op<"barrier"> {
164164
let assemblyFormat = "$flags attr-dict";
165165
}
166166

167+
// gpu_runtime.mem_fence: memory fence for a single work-item.
def GPUMemFenceOp : GpuRuntime_Op<"mem_fence"> {
  let summary = "Orders loads and stores of a work-item executing a kernel.";

  // Fence flags (global/local); matched against FenceFlags in the
  // SPIR-V lowering (ConvertMemFenceOp).
  let arguments = (ins GpuRuntime_FenceFlagsAttr:$flags);

  let assemblyFormat = "$flags attr-dict";
}
174+
167175
#endif // GPURUNTIME_OPS
168176

mlir/lib/Conversion/gpu_to_gpu_runtime.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,32 @@ class ConvertBarrierOp
927927
}
928928
};
929929

930+
// Lowers gpu_runtime.mem_fence to spirv.MemoryBarrier with Workgroup
// execution scope and memory semantics derived from the fence flags.
class ConvertMemFenceOp
    : public mlir::OpConversionPattern<gpu_runtime::GPUMemFenceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(gpu_runtime::GPUMemFenceOp op,
                  gpu_runtime::GPUMemFenceOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto scope = mlir::spirv::Scope::Workgroup;
    mlir::spirv::MemorySemantics semantics;
    if (adaptor.flags() == gpu_runtime::FenceFlags::global) {
      // Global fence: order accesses to cross-workgroup (global) memory.
      semantics = mlir::spirv::MemorySemantics::SequentiallyConsistent |
                  mlir::spirv::MemorySemantics::CrossWorkgroupMemory;
    } else if (adaptor.flags() == gpu_runtime::FenceFlags::local) {
      // Local fence: order accesses to workgroup-local memory.
      semantics = mlir::spirv::MemorySemantics::SequentiallyConsistent |
                  mlir::spirv::MemorySemantics::WorkgroupMemory;
    } else {
      // Unrecognized flags: reject so the full conversion fails visibly.
      return mlir::failure();
    }
    rewriter.replaceOpWithNewOp<mlir::spirv::MemoryBarrierOp>(op, scope,
                                                              semantics);
    return mlir::success();
  }
};
955+
930956
// TODO: something better
931957
class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
932958
public:
@@ -1006,11 +1032,11 @@ struct GPUToSpirvPass
10061032
mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
10071033
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
10081034

1009-
patterns.insert<ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
1010-
ConvertCastOp<mlir::memref::ReinterpretCastOp>,
1011-
ConvertLoadOp, ConvertStoreOp, ConvertAtomicOps,
1012-
ConvertFunc, ConvertAssert, ConvertBarrierOp>(typeConverter,
1013-
context);
1035+
patterns
1036+
.insert<ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
1037+
ConvertCastOp<mlir::memref::ReinterpretCastOp>, ConvertLoadOp,
1038+
ConvertStoreOp, ConvertAtomicOps, ConvertFunc, ConvertAssert,
1039+
ConvertBarrierOp, ConvertMemFenceOp>(typeConverter, context);
10141040

10151041
if (failed(
10161042
applyFullConversion(kernelModules, *target, std::move(patterns))))

numba_dpcomp/numba_dpcomp/mlir/kernel_impl.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,21 @@ def _barrier_impl(builder, flags=None):
320320
@infer_global(barrier)
321321
class _BarrierId(ConcreteTemplate):
322322
cases = [signature(types.void, types.int64), signature(types.void)]
323+
324+
325+
def mem_fence(flags=None):
    """Kernel-API placeholder for mem_fence.

    Not meant to be called directly from Python; the body delegates to
    _stub_error(). Compiled kernels lower it via the builder
    implementation registered for "mem_fence".
    """
    _stub_error()
327+
328+
329+
@registry.register_func("mem_fence", mem_fence)
def _mem_fence_impl(builder, flags=None):
    """Builder implementation for mem_fence: emits the runtime call.

    flags defaults to CLK_GLOBAL_MEM_FENCE when not supplied, mirroring
    the flag-less typing signature registered for mem_fence.
    """
    # Fix: function was named "_memf_fence_impl" (typo); the name is private
    # and only registered by reference, so renaming is safe.
    if flags is None:
        flags = CLK_GLOBAL_MEM_FENCE

    res = 0  # TODO: remove
    return builder.external_call("kernel_mem_fence", inputs=flags, outputs=res)
336+
337+
338+
@infer_global(mem_fence)
class _MemFenceId(ConcreteTemplate):
    # Typing template: mem_fence(int64 flags) -> void and mem_fence() -> void.
    cases = [signature(types.void, types.int64), signature(types.void)]

numba_dpcomp/numba_dpcomp/mlir/kernel_sim.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
atomic_add,
3535
atomic_sub,
3636
barrier,
37+
mem_fence,
3738
)
3839

3940
_ExecutionState = namedtuple(
@@ -82,13 +83,11 @@ def sub(arr, ind, val):
8283

8384

8485
def barrier_proxy(flags):
85-
global _greenlet_found
86-
assert _greenlet_found, "greenlet package not installed"
8786
state = get_exec_state()
88-
assert len(state.tasks) > 0
8987
wg_size = state.wg_size[0]
9088
assert wg_size > 0
9189
if wg_size > 1:
90+
assert len(state.tasks) > 0
9291
indices = copy.deepcopy(state.indices)
9392
next_task = state.current_task[0] + 1
9493
if next_task >= wg_size:
@@ -98,6 +97,10 @@ def barrier_proxy(flags):
9897
state.indices[:] = indices
9998

10099

100+
def mem_fence_proxy(flags):
    """No-op stand-in for mem_fence in the sequential kernel simulator."""
    return None
102+
103+
101104
def _setup_execution_state(global_size, local_size):
102105
import numba_dpcomp.mlir.kernel_impl
103106

@@ -129,6 +132,7 @@ def _destroy_execution_state():
129132
("atomic_add", atomic_add, atomic_proxy.add),
130133
("atomic_sub", atomic_sub, atomic_proxy.sub),
131134
("barrier", barrier, barrier_proxy),
135+
("mem_fence", mem_fence, mem_fence_proxy),
132136
]
133137

134138

@@ -179,6 +183,11 @@ def wrapper():
179183
_barrier_ops = ["barrier"]
180184

181185

186+
def _have_barrier_ops(func):
    """Return True if any barrier-like op name from _barrier_ops appears
    in ``func``'s module globals."""
    func_globals = func.__globals__
    for op_name in _barrier_ops:
        if op_name in func_globals:
            return True
    return False
189+
190+
182191
def _execute_kernel(global_size, local_size, func, *args):
183192
if len(local_size) == 0:
184193
local_size = (1,) * len(global_size)
@@ -188,7 +197,7 @@ def _execute_kernel(global_size, local_size, func, *args):
188197
state = _setup_execution_state(global_size, local_size)
189198
try:
190199
groups = tuple((g + l - 1) // l for g, l in zip(global_size, local_size))
191-
need_barrier = any(n in func.__globals__ for n in _barrier_ops)
200+
need_barrier = max(local_size) > 1 and _have_barrier_ops(func)
192201
for gid in product(*(range(g) for g in groups)):
193202
offset = tuple(g * l for g, l in zip(gid, local_size))
194203
size = tuple(
@@ -202,7 +211,7 @@ def _execute_kernel(global_size, local_size, func, *args):
202211

203212
if need_barrier:
204213
global _greenlet_found
205-
assert _greenlet_found
214+
assert _greenlet_found, "greenlet package not installed"
206215
tasks = state.tasks
207216
assert len(tasks) == 0
208217
for indices in product(*indices_range):

numba_dpcomp/numba_dpcomp/mlir/tests/test_gpu.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
kernel_func,
3030
DEFAULT_LOCAL_SIZE,
3131
barrier,
32+
mem_fence,
3233
CLK_LOCAL_MEM_FENCE,
3334
CLK_GLOBAL_MEM_FENCE,
3435
)
@@ -687,6 +688,38 @@ def func(c):
687688
assert_equal(gpu_res, sim_res)
688689

689690

691+
@require_gpu
@pytest.mark.parametrize("op", [barrier, mem_fence])
@pytest.mark.parametrize("flags", [CLK_LOCAL_MEM_FENCE, CLK_GLOBAL_MEM_FENCE])
@pytest.mark.parametrize("global_size", [1, 2, 27])
@pytest.mark.parametrize("local_size", [1, 2, 7])
def test_barrier_ops(op, flags, global_size, local_size):
    """Run a kernel containing barrier/mem_fence on GPU and in the simulator
    and check both produce identical results and a single gpu.launch."""
    # Fix: removed dead local `atomic_add = atomic.add` — never used here,
    # apparently copied from a neighboring atomic test.

    def func(a, b):
        i = get_global_id(0)
        v = a[i]  # load before the fence op; presumably intentional — TODO confirm
        op(flags)
        b[i] = a[i]

    sim_func = kernel_sim(func)
    gpu_func = kernel_cached(func)

    a = np.arange(global_size, dtype=np.int64)

    sim_res = np.zeros(global_size, a.dtype)
    sim_func[global_size, local_size](a, sim_res)

    gpu_res = np.zeros(global_size, a.dtype)

    with print_pass_ir([], ["ConvertParallelLoopToGpu"]):
        gpu_func[global_size, local_size](a, gpu_res)
        ir = get_print_buffer()
        # Exactly one kernel launch is expected in the lowered IR.
        assert ir.count("gpu.launch blocks") == 1, ir

    assert_equal(gpu_res, sim_res)
721+
722+
690723
@require_gpu
691724
@pytest.mark.parametrize("global_size", [1, 2, 27])
692725
@pytest.mark.parametrize("local_size", [1, 2, 7])

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,9 @@ class ConvertBarrierOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
11861186
using funcptr_t = void (*)(mlir::Operation *, mlir::PatternRewriter &,
11871187
gpu_runtime::FenceFlags);
11881188
const std::pair<llvm::StringRef, funcptr_t> handlers[] = {
1189-
{"kernel_barrier", &genBarrierOp<gpu_runtime::GPUBarrierOp>}};
1189+
{"kernel_barrier", &genBarrierOp<gpu_runtime::GPUBarrierOp>},
1190+
{"kernel_mem_fence", &genBarrierOp<gpu_runtime::GPUMemFenceOp>},
1191+
};
11901192

11911193
auto funcName = op.getCallee();
11921194
for (auto &h : handlers) {

0 commit comments

Comments
 (0)