Skip to content

Commit dd25045

Browse files
authored
[python][gpy] Kernel api: local mem allocation support (#202)
1 parent 38a7bcf commit dd25045

File tree

11 files changed

+453
-26
lines changed

11 files changed

+453
-26
lines changed

dpcomp_gpu_runtime/lib/kernel_api_stubs.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,24 @@ ATOMIC_FUNC_DECL2(sub)
7373

7474
#undef ATOMIC_FUNC_DECL2
7575
#undef ATOMIC_FUNC_DECL
76+
77+
// Generates one exported C stub named `_mlir_ciface_local_array_<elem>_<rank>`.
// The stub takes no arguments and its body only invokes STUB().
#define LOCAL_ARRAY_FUNC_DECL(elem, rank)                                      \
  extern "C" DPCOMP_GPU_RUNTIME_EXPORT void                                    \
      _mlir_ciface_local_array_##elem##_##rank() {                             \
    STUB();                                                                    \
  }

// Generates the stubs for every supported element type at the given rank.
#define LOCAL_ARRAY_FUNC_DECL2(rank)                                           \
  LOCAL_ARRAY_FUNC_DECL(int32, rank)                                           \
  LOCAL_ARRAY_FUNC_DECL(int64, rank)                                           \
  LOCAL_ARRAY_FUNC_DECL(float32, rank)                                         \
  LOCAL_ARRAY_FUNC_DECL(float64, rank)

// Ranks 1 through 5 are instantiated.
LOCAL_ARRAY_FUNC_DECL2(1)
LOCAL_ARRAY_FUNC_DECL2(2)
LOCAL_ARRAY_FUNC_DECL2(3)
LOCAL_ARRAY_FUNC_DECL2(4)
LOCAL_ARRAY_FUNC_DECL2(5)

// The helper macros are only needed for the instantiations above.
#undef LOCAL_ARRAY_FUNC_DECL2
#undef LOCAL_ARRAY_FUNC_DECL

mlir/include/mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOps.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,17 @@ def GpuRuntime_FenceFlags : I32EnumAttr<"FenceFlags",
5050
let genSpecializedAttr = 0;
5151
let cppNamespace = "::gpu_runtime";
5252
}
53-
def GpuRuntime_FenceFlagsAttr : EnumAttr<GpuRuntime_Dialect, GpuRuntime_FenceFlags, "fenceFlags">;
53+
def GpuRuntime_FenceFlagsAttr : EnumAttr<GpuRuntime_Dialect, GpuRuntime_FenceFlags, "fence_flags">;
54+
55+
// Storage classes that can be attached to kernel memory. Only `local` is
// defined so far.
def GpuRuntime_StorageClass : I32EnumAttr<"StorageClass",
    // The summary previously repeated the fence-flags text, which does not
    // describe this enum.
    "Kernel memory storage class",
    [
      I32EnumAttrCase<"local", 1>,
    ]>{
  let genSpecializedAttr = 0;
  let cppNamespace = "::gpu_runtime";
}
// Dialect attribute wrapping the enum above, with mnemonic `storage_class`.
def GpuRuntime_StorageClassAttr : EnumAttr<GpuRuntime_Dialect, GpuRuntime_StorageClass, "storage_class">;
5464

5565
def CreateGpuStreamOp : GpuRuntime_Op<"create_gpu_stream", [NoSideEffect]> {
5666
let results = (outs GpuRuntime_OpaqueType : $result);

mlir/lib/Conversion/gpu_to_gpu_runtime.cpp

Lines changed: 116 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,11 @@ struct InsertGPUAllocs
212212
if (op->getDialect() == scfDialect ||
213213
mlir::isa<mlir::ViewLikeOpInterface>(op))
214214
continue;
215-
if (mlir::dyn_cast<mlir::memref::AllocOp>(op)) {
216-
gpuBufferAllocs.insert({op, {}});
217-
} else if (mlir::dyn_cast<mlir::memref::GetGlobalOp>(op)) {
215+
if (mlir::isa<mlir::memref::AllocOp,
216+
mlir::memref::GetGlobalOp>(op)) {
218217
gpuBufferAllocs.insert({op, {}});
218+
} else if (mlir::isa<mlir::func::CallOp>(op)) {
219+
// Ignore
219220
} else {
220221
op->emitError("Unhandled memref producer");
221222
return mlir::WalkResult::interrupt();
@@ -942,6 +943,104 @@ class ConvertMemFenceOp
942943
}
943944
};
944945

946+
/// Translates a `gpu_runtime::StorageClassAttr` into the matching SPIR-V
/// storage class.
///
/// Returns `llvm::None` when `src` is null, is not a `StorageClassAttr`, or
/// holds a storage class without a mapping. `local` is the only mapped value
/// and becomes `Workgroup`.
static llvm::Optional<mlir::spirv::StorageClass>
convertStorageClass(mlir::Attribute src) {
  if (auto scAttr = src.dyn_cast_or_null<gpu_runtime::StorageClassAttr>()) {
    if (scAttr.getValue() == gpu_runtime::StorageClass::local)
      return mlir::spirv::StorageClass::Workgroup;
  }

  return llvm::None;
}
958+
959+
/// Same as the single-argument `convertStorageClass`, but falls back to `def`
/// when `src` does not map to a SPIR-V storage class.
static mlir::spirv::StorageClass
convertStorageClass(mlir::Attribute src, mlir::spirv::StorageClass def) {
  if (auto converted = convertStorageClass(src))
    return *converted;

  return def;
}
967+
968+
class ConvertGlobalOp
969+
: public mlir::OpConversionPattern<mlir::memref::GlobalOp> {
970+
public:
971+
using OpConversionPattern::OpConversionPattern;
972+
973+
mlir::LogicalResult
974+
matchAndRewrite(mlir::memref::GlobalOp op,
975+
mlir::memref::GlobalOp::Adaptor adaptor,
976+
mlir::ConversionPatternRewriter &rewriter) const override {
977+
auto memrefType = op.type();
978+
if (!memrefType.hasStaticShape())
979+
return mlir::failure();
980+
981+
auto storageClass = convertStorageClass(memrefType.getMemorySpace());
982+
if (!storageClass)
983+
return mlir::failure();
984+
985+
auto converter = getTypeConverter();
986+
assert(converter);
987+
988+
auto elemType = converter->convertType(memrefType.getElementType());
989+
if (!elemType)
990+
return mlir::failure();
991+
992+
auto elemCount = memrefType.getNumElements();
993+
auto newType = mlir::spirv::ArrayType::get(elemType, elemCount);
994+
auto ptrType = mlir::spirv::PointerType::get(newType, *storageClass);
995+
996+
rewriter.replaceOpWithNewOp<mlir::spirv::GlobalVariableOp>(
997+
op, ptrType, adaptor.sym_name());
998+
return mlir::success();
999+
}
1000+
};
1001+
1002+
class ConvertGetGlobalOp
1003+
: public mlir::OpConversionPattern<mlir::memref::GetGlobalOp> {
1004+
public:
1005+
using OpConversionPattern::OpConversionPattern;
1006+
1007+
mlir::LogicalResult
1008+
matchAndRewrite(mlir::memref::GetGlobalOp op,
1009+
mlir::memref::GetGlobalOp::Adaptor adaptor,
1010+
mlir::ConversionPatternRewriter &rewriter) const override {
1011+
auto memrefType = op.getType().dyn_cast<mlir::MemRefType>();
1012+
if (!memrefType)
1013+
return mlir::failure();
1014+
1015+
auto storageClass = convertStorageClass(memrefType.getMemorySpace());
1016+
if (!storageClass)
1017+
return mlir::failure();
1018+
1019+
auto converter = getTypeConverter();
1020+
assert(converter);
1021+
auto resType = converter->convertType(memrefType);
1022+
if (!resType)
1023+
return mlir::failure();
1024+
1025+
auto elemType = converter->convertType(memrefType.getElementType());
1026+
if (!elemType)
1027+
return mlir::failure();
1028+
1029+
auto elemCount = memrefType.getNumElements();
1030+
auto newType = mlir::spirv::ArrayType::get(elemType, elemCount);
1031+
auto ptrType = mlir::spirv::PointerType::get(newType, *storageClass);
1032+
1033+
auto loc = op->getLoc();
1034+
mlir::Value res =
1035+
rewriter.create<mlir::spirv::AddressOfOp>(loc, ptrType, adaptor.name());
1036+
if (res.getType() != resType)
1037+
res = rewriter.create<mlir::spirv::BitcastOp>(loc, resType, res);
1038+
1039+
rewriter.replaceOp(op, res);
1040+
return mlir::success();
1041+
}
1042+
};
1043+
9451044
// TODO: something better
9461045
class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
9471046
public:
@@ -1024,12 +1123,18 @@ struct GPUToSpirvPass
10241123
mlir::RewritePatternSet patterns(context);
10251124

10261125
typeConverter.addConversion(
1027-
[](mlir::MemRefType type) -> llvm::Optional<mlir::Type> {
1028-
if (type.hasRank() && type.getElementType().isIntOrFloat())
1029-
return mlir::spirv::PointerType::get(
1030-
type.getElementType(),
1031-
mlir::spirv::StorageClass::CrossWorkgroup);
1032-
return mlir::Type(nullptr);
1126+
[&typeConverter](mlir::MemRefType type) -> llvm::Optional<mlir::Type> {
1127+
if (!type.hasRank() || !type.getElementType().isIntOrFloat())
1128+
return mlir::Type(nullptr);
1129+
1130+
auto elemType = typeConverter.convertType(type.getElementType());
1131+
if (!elemType)
1132+
return mlir::Type(nullptr);
1133+
1134+
auto sc = convertStorageClass(
1135+
type.getMemorySpace(), mlir::spirv::StorageClass::CrossWorkgroup);
1136+
1137+
return mlir::spirv::PointerType::get(elemType, sc);
10331138
});
10341139

10351140
mlir::ScfToSPIRVContext scfToSpirvCtx;
@@ -1044,8 +1149,8 @@ struct GPUToSpirvPass
10441149
.insert<ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
10451150
ConvertCastOp<mlir::memref::ReinterpretCastOp>, ConvertLoadOp,
10461151
ConvertStoreOp, ConvertAtomicOps, ConvertFunc, ConvertAssert,
1047-
ConvertBarrierOp, ConvertMemFenceOp, ConvertUndef>(
1048-
typeConverter, context);
1152+
ConvertBarrierOp, ConvertMemFenceOp, ConvertUndef,
1153+
ConvertGlobalOp, ConvertGetGlobalOp>(typeConverter, context);
10491154

10501155
if (failed(
10511156
applyFullConversion(kernelModules, *target, std::move(patterns))))

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ struct SignCastCastPropagate : public mlir::OpRewritePattern<CastOp> {
11801180
if (!signCast)
11811181
return mlir::failure();
11821182

1183-
auto srcType = op.source().getType().template cast<mlir::ShapedType>();
1183+
auto srcType = signCast.getType().template cast<mlir::ShapedType>();
11841184
auto dstType = op.getType().template cast<mlir::ShapedType>();
11851185
if (srcType.getElementType() != dstType.getElementType() ||
11861186
!srcType.hasRank() || !dstType.hasRank())
@@ -1200,6 +1200,88 @@ struct SignCastCastPropagate : public mlir::OpRewritePattern<CastOp> {
12001200
}
12011201
};
12021202

1203+
struct SignCastReinterpretPropagate
1204+
: public mlir::OpRewritePattern<mlir::memref::ReinterpretCastOp> {
1205+
using OpRewritePattern::OpRewritePattern;
1206+
1207+
mlir::LogicalResult
1208+
matchAndRewrite(mlir::memref::ReinterpretCastOp op,
1209+
mlir::PatternRewriter &rewriter) const override {
1210+
auto signCast = op.source().getDefiningOp<plier::SignCastOp>();
1211+
if (!signCast)
1212+
return mlir::failure();
1213+
1214+
auto srcType = signCast.getType().cast<mlir::ShapedType>();
1215+
auto dstType = op.getType().cast<mlir::MemRefType>();
1216+
if (srcType.getElementType() != dstType.getElementType())
1217+
return mlir::failure();
1218+
1219+
auto src = signCast.value();
1220+
auto finalType = src.getType().cast<mlir::MemRefType>();
1221+
1222+
auto newDstType =
1223+
mlir::MemRefType::get(dstType.getShape(), dstType.getElementType(),
1224+
dstType.getLayout(), finalType.getMemorySpace());
1225+
1226+
auto loc = op.getLoc();
1227+
auto offset = op.getMixedOffsets().front();
1228+
auto sizes = op.getMixedSizes();
1229+
auto strides = op.getMixedStrides();
1230+
auto cast = rewriter.createOrFold<mlir::memref::ReinterpretCastOp>(
1231+
loc, newDstType, src, offset, sizes, strides);
1232+
rewriter.replaceOpWithNewOp<plier::SignCastOp>(op, dstType, cast);
1233+
1234+
return mlir::success();
1235+
}
1236+
};
1237+
1238+
struct SignCastLoadPropagate
1239+
: public mlir::OpRewritePattern<mlir::memref::LoadOp> {
1240+
using OpRewritePattern::OpRewritePattern;
1241+
1242+
mlir::LogicalResult
1243+
matchAndRewrite(mlir::memref::LoadOp op,
1244+
mlir::PatternRewriter &rewriter) const override {
1245+
auto signCast = op.memref().getDefiningOp<plier::SignCastOp>();
1246+
if (!signCast)
1247+
return mlir::failure();
1248+
1249+
auto loc = op.getLoc();
1250+
auto src = signCast.value();
1251+
auto newOp =
1252+
rewriter.createOrFold<mlir::memref::LoadOp>(loc, src, op.indices());
1253+
1254+
if (newOp.getType() != op.getType())
1255+
newOp = rewriter.create<plier::SignCastOp>(loc, op.getType(), newOp);
1256+
1257+
rewriter.replaceOp(op, newOp);
1258+
return mlir::success();
1259+
}
1260+
};
1261+
1262+
struct SignCastStorePropagate
1263+
: public mlir::OpRewritePattern<mlir::memref::StoreOp> {
1264+
using OpRewritePattern::OpRewritePattern;
1265+
1266+
mlir::LogicalResult
1267+
matchAndRewrite(mlir::memref::StoreOp op,
1268+
mlir::PatternRewriter &rewriter) const override {
1269+
auto signCast = op.memref().getDefiningOp<plier::SignCastOp>();
1270+
if (!signCast)
1271+
return mlir::failure();
1272+
1273+
auto src = signCast.value();
1274+
auto srcElemType = src.getType().cast<mlir::MemRefType>().getElementType();
1275+
auto val = op.value();
1276+
if (val.getType() != srcElemType)
1277+
val = rewriter.create<plier::SignCastOp>(op.getLoc(), srcElemType, val);
1278+
1279+
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, val, src,
1280+
op.indices());
1281+
return mlir::success();
1282+
}
1283+
};
1284+
12031285
template <typename Op>
12041286
struct SignCastAllocPropagate
12051287
: public mlir::OpRewritePattern<plier::SignCastOp> {
@@ -1223,7 +1305,7 @@ struct SignCastAllocPropagate
12231305

12241306
struct SignCastTensorFromElementsPropagate
12251307
: public mlir::OpRewritePattern<plier::SignCastOp> {
1226-
using mlir::OpRewritePattern<plier::SignCastOp>::OpRewritePattern;
1308+
using OpRewritePattern::OpRewritePattern;
12271309

12281310
mlir::LogicalResult
12291311
matchAndRewrite(plier::SignCastOp op,
@@ -1422,7 +1504,8 @@ void SignCastOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
14221504
SignCastCastPropagate<mlir::tensor::CastOp>,
14231505
SignCastCastPropagate<mlir::memref::CastOp>,
14241506
SignCastCastPropagate<plier::ChangeLayoutOp>,
1425-
SignCastAllocPropagate<mlir::memref::AllocOp>,
1507+
SignCastReinterpretPropagate, SignCastLoadPropagate,
1508+
SignCastStorePropagate, SignCastAllocPropagate<mlir::memref::AllocOp>,
14261509
SignCastAllocPropagate<mlir::memref::AllocaOp>,
14271510
SignCastTensorFromElementsPropagate, SignCastTensorCollapseShapePropagate,
14281511
SignCastBuferizationPropagate<mlir::bufferization::ToMemrefOp>,

numba_dpcomp/numba_dpcomp/mlir/kernel_impl.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from numba import prange
1818
from numba.core import types
19+
from numba.core.typing.npydecl import parse_dtype, parse_shape
20+
from numba.core.types.npytypes import Array
1921
from numba.core.typing.templates import (
2022
AbstractTemplate,
2123
ConcreteTemplate,
@@ -338,3 +340,40 @@ def _memf_fence_impl(builder, flags=None):
338340
@infer_global(mem_fence)
339341
class _MemFenceId(ConcreteTemplate):
340342
cases = [signature(types.void, types.int64), signature(types.void)]
343+
344+
345+
class local(Stub):
    """Stub namespace; `local_array` is attached to it as `local.array`."""


def local_array(shape, dtype):
    """Stub for `local.array(shape, dtype)`; the body only calls `_stub_error()`."""
    _stub_error()


# Expose the stub function as `local.array`.
local.array = local_array
354+
355+
356+
@infer_global(local_array)
class _LocalId(AbstractTemplate):
    """Typing template for `local_array(shape, dtype)`."""

    def generic(self, args, kws):
        """Resolve the call signature of `local_array`.

        `shape` and `dtype` are taken from `kws` when given by keyword,
        otherwise from positions 0 and 1 of `args`.

        Returns a signature whose return type is a C-layout `Array`, or
        `None` (no match) when an argument is missing or when `parse_shape` /
        `parse_dtype` cannot interpret it.
        """
        if "shape" in kws:
            shape = kws["shape"]
        elif len(args) > 0:
            shape = args[0]
        else:
            return None

        if "dtype" in kws:
            dtype = kws["dtype"]
        elif len(args) > 1:
            dtype = args[1]
        else:
            return None

        # Both parsers return None for inputs they do not understand; report
        # "no match" instead of building an Array type from None.
        ndim = parse_shape(shape)
        if ndim is None:
            return None

        dtype = parse_dtype(dtype)
        if dtype is None:
            return None

        arr_type = Array(dtype=dtype, ndim=ndim, layout="C")
        # NOTE(review): the second argument type recorded here is the parsed
        # element type, not the type of the original `dtype` argument. This is
        # kept as in the original; confirm it is intended.
        return signature(arr_type, shape, dtype)
366+
367+
368+
@registry.register_func("local_array", local_array)
def _local_array_impl(builder, shape, dtype):
    """Lower `local_array(shape, dtype)` to an external runtime call.

    A `shape` that does not support `len()` is wrapped into a 1-tuple. The
    callee name is `local_array_<dtype>_<ndim>`, where `<dtype>` comes from
    `dtype_str(builder, dtype)` and `<ndim>` is `len(shape)`. The result of
    `builder.init_tensor(shape, dtype)` is passed as the call output.
    """
    try:
        len(shape)  # will raise if not available
    except Exception:
        # Any ordinary failure means "not a sequence". Unlike a bare
        # `except:`, this lets KeyboardInterrupt / SystemExit propagate.
        shape = (shape,)

    func_name = f"local_array_{dtype_str(builder, dtype)}_{len(shape)}"
    res = builder.init_tensor(shape, dtype)
    return builder.external_call(
        func_name, inputs=shape, outputs=res, return_tensor=True
    )

0 commit comments

Comments
 (0)