Skip to content

Commit b09289c

Browse files
silee2svkeerthy
authored andcommitted
[MLIR][XeVM] Add XeVM 1D block operations to OpenCL calls conversion. (#161702)
XeVM 1D block load store operations are converted to OpenCL subgroup operations described here: https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_local_block_io.html
1 parent b44647d commit b09289c

File tree

2 files changed

+168
-1
lines changed

2 files changed

+168
-1
lines changed

mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
214214
return op.getCacheControl();
215215
}
216216

217+
static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
218+
return op.getCacheControl();
219+
}
220+
217221
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
218222
return op.getCacheControl();
219223
}
@@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
222226
return op.getCacheControl();
223227
}
224228

229+
static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
230+
return op.getCacheControl();
231+
}
232+
225233
static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
226234
if (op->hasAttr("cache_control")) {
227235
auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
@@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
263271
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
264272
std::is_same_v<OpType, BlockPrefetch2dOp> ||
265273
std::is_same_v<OpType, LLVM::LoadOp> ||
274+
std::is_same_v<OpType, BlockLoadOp> ||
266275
std::is_same_v<OpType, PrefetchOp>;
267276
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
268277
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
@@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
618627
return success();
619628
}
620629
};
630+
631+
template <typename OpType>
632+
class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
633+
using OpConversionPattern<OpType>::OpConversionPattern;
634+
LogicalResult
635+
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
636+
ConversionPatternRewriter &rewriter) const override {
637+
constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
638+
// Get OpenCL function name
639+
// https://registry.khronos.org/OpenCL/extensions/
640+
// intel/cl_intel_subgroup_local_block_io.html
641+
std::string funcName{"intel_sub_group_block_"};
642+
// Value or Result type can be vector or scalar
643+
Type valOrResTy;
644+
if constexpr (isStore) {
645+
funcName += "write_u";
646+
valOrResTy = op.getVal().getType();
647+
} else {
648+
funcName += "read_u";
649+
valOrResTy = op.getType();
650+
}
651+
// Get element type of the vector/scalar
652+
VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
653+
Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
654+
funcName += getTypeMangling(elemType);
655+
if (vecTy)
656+
funcName += std::to_string(vecTy.getNumElements());
657+
SmallVector<Type, 2> argTypes{};
658+
// XeVM BlockLoad/StoreOp always use signless integer types
659+
// but OpenCL builtins expect unsigned types
660+
// use unsigned types for mangling
661+
SmallVector<bool, 2> isUnsigned{};
662+
// arg0: pointer to the src/dst address
663+
// arg1 - only if store : vector to store
664+
// Prepare arguments
665+
SmallVector<Value, 2> args{};
666+
args.push_back(op.getPtr());
667+
argTypes.push_back(op.getPtr().getType());
668+
isUnsigned.push_back(true);
669+
Type retType;
670+
if constexpr (isStore) {
671+
args.push_back(op.getVal());
672+
argTypes.push_back(op.getVal().getType());
673+
isUnsigned.push_back(true);
674+
retType = LLVM::LLVMVoidType::get(rewriter.getContext());
675+
} else {
676+
retType = valOrResTy;
677+
}
678+
funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
679+
"PU3AS" +
680+
std::to_string(op.getPtr().getType().getAddressSpace());
681+
funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
682+
if constexpr (isStore)
683+
funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
684+
LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
685+
686+
LLVM::CallOp call =
687+
createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
688+
{}, funcAttr, op.getOperation());
689+
if (std::optional<ArrayAttr> optCacheControls =
690+
getCacheControlMetadata(rewriter, op)) {
691+
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
692+
}
693+
if constexpr (isStore)
694+
rewriter.eraseOp(op);
695+
else
696+
rewriter.replaceOp(op, call->getResult(0));
697+
return success();
698+
}
699+
};
700+
621701
template <typename OpType>
622702
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
623703
using OpConversionPattern<OpType>::OpConversionPattern;
@@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
693773
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
694774
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
695775
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
696-
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
776+
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
777+
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
778+
BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
779+
patterns.getContext());
697780
}
698781

699782
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {

mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,87 @@ llvm.func @llvm.store(%a: !llvm.ptr<1>, %val: i32) {
261261
llvm.store %val, %a {cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>} : i32, !llvm.ptr<1>
262262
llvm.return
263263
}
264+
265+
// -----
266+
// CHECK-LABEL: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS1t
267+
// CHECK: llvm.func @blockload_as1(%[[ARG0:.*]]: !llvm.ptr<1>)
268+
llvm.func @blockload_as1(%ptr: !llvm.ptr<1>) -> vector<8xi16> {
269+
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS1t(%[[ARG0]])
270+
// CHECK-SAME: {function_type = !llvm.func<vector<8xi16> (ptr<1>)>, linkage = #llvm.linkage<external>,
271+
// CHECK-SAME: no_unwind, sym_name = "_Z30intel_sub_group_block_read_us8PU3AS1t",
272+
// CHECK-SAME: visibility_ = 0 : i64, will_return, xevm.DecorationCacheControl =
273+
// CHECK-SAME: [6442 : i32, 0 : i32, 1 : i32, 0 : i32],
274+
// CHECK-SAME: [6442 : i32, 1 : i32, 1 : i32, 0 : i32]
275+
%loaded_a = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>) -> vector<8xi16>
276+
llvm.return %loaded_a : vector<8xi16>
277+
}
278+
279+
// -----
280+
// CHECK-LABEL: llvm.func spir_funccc @_Z31intel_sub_group_block_read_uc16PU3AS3h(!llvm.ptr<3>)
281+
// CHECK: llvm.func @blockload_as3(%[[ARG0:.*]]: !llvm.ptr<3>)
282+
llvm.func @blockload_as3(%ptr: !llvm.ptr<3>) -> vector<16xi8> {
283+
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z31intel_sub_group_block_read_uc16PU3AS3h(%[[ARG0]])
284+
// CHECK-SAME: {function_type = !llvm.func<vector<16xi8> (ptr<3>)>, linkage = #llvm.linkage<external>,
285+
// CHECK-SAME: no_unwind, sym_name = "_Z31intel_sub_group_block_read_uc16PU3AS3h", visibility_ = 0 : i64,
286+
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
287+
// CHECK-SAME: [6442 : i32, 0 : i32, 1 : i32, 0 : i32],
288+
// CHECK-SAME: [6442 : i32, 1 : i32, 1 : i32, 0 : i32]
289+
%loaded_a = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<3>) -> vector<16xi8>
290+
llvm.return %loaded_a : vector<16xi8>
291+
}
292+
293+
// -----
294+
// CHECK-LABEL: llvm.func spir_funccc @_Z29intel_sub_group_block_read_ucPU3AS3h(!llvm.ptr<3>)
295+
// CHECK: llvm.func @blockload_scalar(%[[ARG0:.*]]: !llvm.ptr<3>)
296+
llvm.func @blockload_scalar(%ptr: !llvm.ptr<3>) -> i8 {
297+
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z29intel_sub_group_block_read_ucPU3AS3h(%[[ARG0]])
298+
// CHECK-SAME: {function_type = !llvm.func<i8 (ptr<3>)>, linkage = #llvm.linkage<external>,
299+
// CHECK-SAME: no_unwind, sym_name = "_Z29intel_sub_group_block_read_ucPU3AS3h", visibility_ = 0 : i64,
300+
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
301+
// CHECK-SAME: [6442 : i32, 0 : i32, 1 : i32, 0 : i32],
302+
// CHECK-SAME: [6442 : i32, 1 : i32, 1 : i32, 0 : i32]
303+
%loaded_a = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<3>) -> i8
304+
llvm.return %loaded_a : i8
305+
}
306+
307+
// -----
308+
// CHECK-LABEL: llvm.func spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS1jDv8_j
309+
// CHECK: llvm.func @blockstore_as1(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: vector<8xi32>) {
310+
llvm.func @blockstore_as1(%ptr: !llvm.ptr<1>, %data: vector<8xi32>) {
311+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS1jDv8_j(%[[ARG0]], %[[ARG1]])
312+
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, vector<8xi32>)>, linkage = #llvm.linkage<external>,
313+
// CHECK-SAME: no_unwind, sym_name = "_Z31intel_sub_group_block_write_ui8PU3AS1jDv8_j", visibility_ = 0 : i64,
314+
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
315+
// CHECK-SAME: [6443 : i32, 0 : i32, 2 : i32, 0 : i32],
316+
// CHECK-SAME: [6443 : i32, 1 : i32, 2 : i32, 0 : i32]
317+
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<1>, vector<8xi32>)
318+
llvm.return
319+
}
320+
321+
// -----
322+
// CHECK-LABEL: llvm.func spir_funccc @_Z31intel_sub_group_block_write_ul2PU3AS3mDv2_m
323+
// CHECK: llvm.func @blockstore_as3(%[[ARG0:.*]]: !llvm.ptr<3>, %[[ARG1:.*]]: vector<2xi64>) {
324+
llvm.func @blockstore_as3(%ptr: !llvm.ptr<3>, %data: vector<2xi64>) {
325+
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul2PU3AS3mDv2_m(%[[ARG0]], %[[ARG1]])
326+
// CHECK-SAME: {function_type = !llvm.func<void (ptr<3>, vector<2xi64>)>, linkage = #llvm.linkage<external>,
327+
// CHECK-SAME: no_unwind, sym_name = "_Z31intel_sub_group_block_write_ul2PU3AS3mDv2_m", visibility_ = 0 : i64,
328+
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
329+
// CHECK-SAME: [6443 : i32, 0 : i32, 2 : i32, 0 : i32],
330+
// CHECK-SAME: [6443 : i32, 1 : i32, 2 : i32, 0 : i32]
331+
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, vector<2xi64>)
332+
llvm.return
333+
}
334+
335+
// -----
336+
// CHECK-LABEL: llvm.func spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm
337+
// CHECK: llvm.func @blockstore_scalar(%[[ARG0:.*]]: !llvm.ptr<3>, %[[ARG1:.*]]: i64) {
338+
llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) {
339+
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(%[[ARG0]], %[[ARG1]])
340+
// CHECK-SAME: {function_type = !llvm.func<void (ptr<3>, i64)>, linkage = #llvm.linkage<external>,
341+
// CHECK-SAME: no_unwind, sym_name = "_Z30intel_sub_group_block_write_ulPU3AS3mm", visibility_ = 0 : i64,
342+
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
343+
// CHECK-SAME: [6443 : i32, 0 : i32, 2 : i32, 0 : i32],
344+
// CHECK-SAME: [6443 : i32, 1 : i32, 2 : i32, 0 : i32]
345+
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, i64)
346+
llvm.return
347+
}

0 commit comments

Comments
 (0)