Skip to content

Commit aba6492

Browse files
authored
[python][gpu] barrier support in kernel api (#197)
* Add GPUBarrierOp and SPIR-V lowering
* Update kernel simulator to support barriers (using coroutines)
1 parent 04b7d6f commit aba6492

File tree

12 files changed

+275
-17
lines changed

12 files changed

+275
-17
lines changed

dpcomp_gpu_runtime/lib/kernel_api_stubs.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ _mlir_ciface_get_local_size(int64_t) {
4747
STUB();
4848
}
4949

50+
// Exported C-linkage symbol for the kernel barrier call. It takes a single
// unnamed int64_t argument, which is not used here.
// The body only invokes STUB(); that macro is defined outside this view.
extern "C" DPCOMP_GPU_RUNTIME_EXPORT void _mlir_ciface_kernel_barrier(int64_t) {
51+
STUB();
52+
}
53+
5054
#define ATOMIC_FUNC_DECL(op, suff, dt) \
5155
extern "C" DPCOMP_GPU_RUNTIME_EXPORT dt _mlir_ciface_atomic_##op##_##suff( \
5256
void *, dt) { \

mlir/include/mlir-extensions/dialect/gpu_runtime/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ set(dialect_namespace gpu_runtime)
44
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
55
mlir_tablegen(${dialect}Enums.h.inc -gen-enum-decls)
66
mlir_tablegen(${dialect}Enums.cpp.inc -gen-enum-defs)
7+
mlir_tablegen(${dialect}Attributes.h.inc -gen-attrdef-decls -attrdefs-dialect=gpu_runtime)
8+
mlir_tablegen(${dialect}Attributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=gpu_runtime)
79
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
810
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
911
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})

mlir/include/mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
2222
include "mlir/Interfaces/LoopLikeInterface.td"
2323
include "mlir/Interfaces/SideEffectInterfaces.td"
2424
include "mlir/Interfaces/ViewLikeInterface.td"
25+
include "mlir/IR/EnumAttr.td"
2526

2627
def GpuRuntime_Dialect : Dialect {
2728
let name = "gpu_runtime";
@@ -40,6 +41,17 @@ def GpuRuntime_OpaqueType
4041
"opaque_type">,
4142
BuildableType<"$_builder.getType<::gpu_runtime::OpaqueType>()"> {}
4243

44+
// 32-bit integer enum "FenceFlags" used by the barrier op.
// Cases: local = 1, global = 2 (the same numeric values as
// CLK_LOCAL_MEM_FENCE / CLK_GLOBAL_MEM_FENCE in kernel_impl.py).
def GpuRuntime_FenceFlags : I32EnumAttr<"FenceFlags",
45+
"Kernel barrier and fence flags",
46+
[
47+
I32EnumAttrCase<"local", 1>,
48+
I32EnumAttrCase<"global", 2>
49+
]>{
50+
// Specialized attribute generation is switched off here; the attribute is
// declared separately as GpuRuntime_FenceFlagsAttr via EnumAttr.
let genSpecializedAttr = 0;
51+
// Generated C++ enum is placed in the ::gpu_runtime namespace.
let cppNamespace = "::gpu_runtime";
52+
}
53+
// Dialect attribute wrapping the GpuRuntime_FenceFlags enum, registered on
// GpuRuntime_Dialect with the mnemonic "fenceFlags".
def GpuRuntime_FenceFlagsAttr : EnumAttr<GpuRuntime_Dialect, GpuRuntime_FenceFlags, "fenceFlags">;
54+
4355
def CreateGpuStreamOp : GpuRuntime_Op<"create_gpu_stream", [NoSideEffect]> {
4456
let results = (outs GpuRuntime_OpaqueType : $result);
4557

@@ -144,5 +156,13 @@ def GPUSuggestBlockSizeOp : GpuRuntime_Op<"suggest_block_size",
144156
}];
145157
}
146158

159+
// Barrier op with mnemonic "barrier". It declares no operands and no results;
// its only argument is the `flags` fence-flags attribute.
// gpu_to_gpu_runtime.cpp (ConvertBarrierOp) rewrites it to
// spirv::ControlBarrierOp.
def GPUBarrierOp : GpuRuntime_Op<"barrier"> {
160+
let summary = "Synchronizes all work items of a workgroup.";
161+
162+
// The fence flag (local or global) is carried as an attribute, not an operand.
let arguments = (ins GpuRuntime_FenceFlagsAttr:$flags);
163+
164+
// Textual form: the flags attribute followed by the attribute dictionary.
let assemblyFormat = "$flags attr-dict";
165+
}
166+
147167
#endif // GPURUNTIME_OPS
148168

mlir/include/mlir-extensions/dialect/gpu_runtime/IR/gpu_runtime_ops.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
#include <mlir/Dialect/GPU/GPUDialect.h>
2828

2929
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOpsDialect.h.inc"
30+
31+
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOpsEnums.h.inc"
32+
33+
#define GET_ATTRDEF_CLASSES
34+
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOpsAttributes.h.inc"
35+
3036
#define GET_OP_CLASSES
3137
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOps.h.inc"
3238

mlir/include/mlir-extensions/dialect/plier_util/PlierUtilOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#define PLIER_UTIL_OPS
1717

1818
include "mlir/IR/OpBase.td"
19-
include "mlir/Dialect/GPU/GPUBase.td"
2019
include "mlir/Interfaces/ControlFlowInterfaces.td"
2120
include "mlir/Interfaces/InferTypeOpInterface.td"
2221
include "mlir/Interfaces/LoopLikeInterface.td"

mlir/include/mlir-extensions/dialect/plier_util/dialect.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
#include <mlir/Interfaces/SideEffectInterfaces.h>
2525
#include <mlir/Interfaces/ViewLikeInterface.h>
2626

27-
#include <mlir/Dialect/GPU/GPUDialect.h>
28-
2927
#include "mlir-extensions/dialect/plier_util/PlierUtilOpsDialect.h.inc"
3028
#include "mlir-extensions/dialect/plier_util/PlierUtilOpsEnums.h.inc"
3129

mlir/lib/Conversion/gpu_to_gpu_runtime.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ static mlir::Value lowerFloatSubAtomic(mlir::OpBuilder &builder,
833833

834834
class ConvertAtomicOps : public mlir::OpConversionPattern<mlir::func::CallOp> {
835835
public:
836-
using mlir::OpConversionPattern<mlir::func::CallOp>::OpConversionPattern;
836+
using OpConversionPattern::OpConversionPattern;
837837

838838
mlir::LogicalResult
839839
matchAndRewrite(mlir::func::CallOp op, mlir::func::CallOp::Adaptor adaptor,
@@ -901,6 +901,32 @@ class ConvertAtomicOps : public mlir::OpConversionPattern<mlir::func::CallOp> {
901901
}
902902
};
903903

904+
// Conversion pattern that rewrites gpu_runtime::GPUBarrierOp into
// mlir::spirv::ControlBarrierOp.
//
// The op's `flags` attribute selects the memory semantics:
//   FenceFlags::global -> SequentiallyConsistent | CrossWorkgroupMemory
//   FenceFlags::local  -> SequentiallyConsistent | WorkgroupMemory
// Any other flag value makes the pattern fail, leaving the op unconverted.
class ConvertBarrierOp
905+
: public mlir::OpConversionPattern<gpu_runtime::GPUBarrierOp> {
906+
public:
907+
// Inherit the base-class constructors.
using OpConversionPattern::OpConversionPattern;
908+
909+
// Returns success() after replacing `op`, or failure() when the flag value is
// neither `global` nor `local`.
mlir::LogicalResult
910+
matchAndRewrite(gpu_runtime::GPUBarrierOp op,
911+
gpu_runtime::GPUBarrierOp::Adaptor adaptor,
912+
mlir::ConversionPatternRewriter &rewriter) const override {
913+
// Workgroup scope is passed for both scope arguments of the barrier below.
auto scope = mlir::spirv::Scope::Workgroup;
914+
// Assigned in every branch that reaches the rewrite; the remaining branch
// returns early.
mlir::spirv::MemorySemantics semantics;
915+
if (adaptor.flags() == gpu_runtime::FenceFlags::global) {
916+
semantics = mlir::spirv::MemorySemantics::SequentiallyConsistent |
917+
mlir::spirv::MemorySemantics::CrossWorkgroupMemory;
918+
} else if (adaptor.flags() == gpu_runtime::FenceFlags::local) {
919+
semantics = mlir::spirv::MemorySemantics::SequentiallyConsistent |
920+
mlir::spirv::MemorySemantics::WorkgroupMemory;
921+
} else {
922+
// Unknown flag value: do not rewrite.
return mlir::failure();
923+
}
924+
rewriter.replaceOpWithNewOp<mlir::spirv::ControlBarrierOp>(op, scope, scope,
925+
semantics);
926+
return mlir::success();
927+
}
928+
};
929+
904930
// TODO: something better
905931
class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
906932
public:
@@ -980,11 +1006,11 @@ struct GPUToSpirvPass
9801006
mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
9811007
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
9821008

983-
patterns
984-
.insert<ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
985-
ConvertCastOp<mlir::memref::ReinterpretCastOp>, ConvertLoadOp,
986-
ConvertStoreOp, ConvertAtomicOps, ConvertFunc, ConvertAssert>(
987-
typeConverter, context);
1009+
patterns.insert<ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
1010+
ConvertCastOp<mlir::memref::ReinterpretCastOp>,
1011+
ConvertLoadOp, ConvertStoreOp, ConvertAtomicOps,
1012+
ConvertFunc, ConvertAssert, ConvertBarrierOp>(typeConverter,
1013+
context);
9881014

9891015
if (failed(
9901016
applyFullConversion(kernelModules, *target, std::move(patterns))))

mlir/lib/dialect/gpu_runtime/IR/gpu_runtime_ops.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ void GpuRuntimeDialect::initialize() {
5555
#define GET_OP_LIST
5656
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOps.cpp.inc"
5757
>();
58+
addAttributes<
59+
#define GET_ATTRDEF_LIST
60+
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOpsAttributes.cpp.inc"
61+
>();
5862
addTypes<OpaqueType>();
5963
addInterfaces<GpuRuntimeInlinerInterface>();
6064
}
@@ -208,5 +212,10 @@ mlir::StringAttr GPUSuggestBlockSizeOp::getKernelName() {
208212

209213
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOpsDialect.cpp.inc"
210214

215+
#define GET_ATTRDEF_CLASSES
216+
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOpsAttributes.cpp.inc"
217+
211218
#define GET_OP_CLASSES
212219
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOps.cpp.inc"
220+
221+
#include "mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOpsEnums.cpp.inc"

numba_dpcomp/numba_dpcomp/mlir/kernel_impl.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,26 @@ def generic(self, args, kws):
297297

298298
_define_atomic_funcs()
299299
del _define_atomic_funcs
300+
301+
302+
# Memory fence flags for barrier(). The values equal the `local` (1) and
# `global` (2) cases of the gpu_runtime FenceFlags enum.
303+
CLK_LOCAL_MEM_FENCE = 0x1
304+
CLK_GLOBAL_MEM_FENCE = 0x2
305+
306+
307+
# Kernel-API placeholder for a work-group barrier. Calling it directly only
# invokes _stub_error() (defined outside this view). It is registered under the
# name "barrier" for compilation, and kernel_sim.py maps it to barrier_proxy.
# `flags` is optional and is not used by this stub.
def barrier(flags=None):
308+
_stub_error()
309+
310+
311+
# Compile-time implementation registered for barrier(): emits an external call
# named "kernel_barrier" with the flags as input.
# When no flags are given, CLK_GLOBAL_MEM_FENCE is used.
@registry.register_func("barrier", barrier)
312+
def _barrier_impl(builder, flags=None):
313+
if flags is None:
314+
flags = CLK_GLOBAL_MEM_FENCE
315+
316+
# Placeholder passed as `outputs`; the original TODO marks it for removal.
res = 0 # TODO: remove
317+
return builder.external_call("kernel_barrier", inputs=flags, outputs=res)
318+
319+
320+
# Typing template for barrier(): two accepted signatures, both returning void,
# one taking a single int64 argument and one taking no arguments.
@infer_global(barrier)
321+
class _BarrierId(ConcreteTemplate):
322+
cases = [signature(types.void, types.int64), signature(types.void)]

numba_dpcomp/numba_dpcomp/mlir/kernel_sim.py

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414

1515
from collections import namedtuple
1616
from itertools import product
17+
from functools import reduce
18+
import copy
19+
20+
try:
21+
from greenlet import greenlet
22+
23+
_greenlet_found = True
24+
except ImportError:
25+
_greenlet_found = False
1726

1827
from .kernel_base import KernelBase
1928
from .kernel_impl import (
@@ -24,10 +33,12 @@
2433
atomic,
2534
atomic_add,
2635
atomic_sub,
36+
barrier,
2737
)
2838

2939
_ExecutionState = namedtuple(
30-
"_ExecutionState", ["global_size", "local_size", "indices",]
40+
"_ExecutionState",
41+
["global_size", "local_size", "indices", "wg_size", "tasks", "current_task"],
3142
)
3243

3344
_execution_state = None
@@ -70,16 +81,35 @@ def sub(arr, ind, val):
7081
return new_val
7182

7283

84+
# Simulator stand-in for barrier(). `flags` is accepted but not used.
# Asserts that greenlet was imported and that the current work-group size is
# positive. With more than one work item in the group, it hands control to the
# next task in round-robin order and restores this task's indices once control
# comes back. With a single work item it does nothing further.
def barrier_proxy(flags):
85+
global _greenlet_found
86+
assert _greenlet_found, "greenlet package not installed"
87+
state = get_exec_state()
88+
wg_size = state.wg_size[0]
89+
assert wg_size > 0
90+
if wg_size > 1:
91+
# Save this task's indices: state.indices is shared, and other tasks
# overwrite it while they run.
indices = copy.deepcopy(state.indices)
92+
# Pick the next task, wrapping to 0 after the last one.
next_task = state.current_task[0] + 1
93+
if next_task >= wg_size:
94+
next_task = 0
95+
state.current_task[0] = next_task
96+
# Transfer control; execution continues on the next line only when another
# task switches back to this one.
state.tasks[next_task].switch()
97+
state.indices[:] = indices
98+
99+
73100
def _setup_execution_state(global_size, local_size):
74101
import numba_dpcomp.mlir.kernel_impl
75102

76103
global _execution_state
77104
assert _execution_state is None
78-
if len(local_size) == 0:
79-
local_size = (1,) * len(global_size)
80105

81106
_execution_state = _ExecutionState(
82-
global_size=global_size, local_size=local_size, indices=[0] * len(global_size)
107+
global_size=global_size,
108+
local_size=local_size,
109+
indices=[0] * len(global_size),
110+
wg_size=[None],
111+
tasks=[],
112+
current_task=[None],
83113
)
84114
return _execution_state
85115

@@ -97,6 +127,7 @@ def _destroy_execution_state():
97127
("atomic", atomic, atomic_proxy),
98128
("atomic_add", atomic_add, atomic_proxy.add),
99129
("atomic_sub", atomic_sub, atomic_proxy.sub),
130+
("barrier", barrier, barrier_proxy),
100131
]
101132

102133

@@ -136,14 +167,50 @@ def _restore_closure(src, old_closure):
136167
src[i].cell_contents = old_closure[i]
137168

138169

170+
# Builds a zero-argument callable that first writes `indices` into the shared
# execution state and then calls func(*args). _execute_kernel wraps each
# returned callable in a greenlet, one per work item.
def _capture_func(func, indices, args):
171+
def wrapper():
172+
get_exec_state().indices[:] = indices
173+
func(*args)
174+
175+
return wrapper
176+
177+
139178
def _execute_kernel(global_size, local_size, func, *args):
179+
if len(local_size) == 0:
180+
local_size = (1,) * len(global_size)
181+
140182
saved_globals = _replace_globals(func.__globals__)
141183
saved_closure = _replace_closure(func.__closure__)
142184
state = _setup_execution_state(global_size, local_size)
143185
try:
144-
for indices in product(*(range(d) for d in global_size)):
145-
state.indices[:] = indices
146-
func(*args)
186+
groups = tuple((g + l - 1) // l for g, l in zip(global_size, local_size))
187+
for gid in product(*(range(g) for g in groups)):
188+
offset = tuple(g * l for g, l in zip(gid, local_size))
189+
size = tuple(
190+
min(g - o, l) for o, g, l in zip(offset, global_size, local_size)
191+
)
192+
count = reduce(lambda a, b: a * b, size)
193+
state.wg_size[0] = count
194+
state.current_task[0] = 0
195+
196+
indices_range = (range(o, o + s) for o, s in zip(offset, size))
197+
198+
global _greenlet_found
199+
if _greenlet_found:
200+
tasks = state.tasks
201+
assert len(tasks) == 0
202+
for indices in product(*indices_range):
203+
tasks.append(greenlet(_capture_func(func, indices, args)))
204+
205+
for t in tasks:
206+
t.switch()
207+
208+
tasks.clear()
209+
else:
210+
for indices in product(*indices_range):
211+
state.indices[:] = indices
212+
func(*args)
213+
147214
finally:
148215
_restore_closure(func.__closure__, saved_closure)
149216
_restore_globals(func.__globals__, saved_globals)

0 commit comments

Comments
 (0)