Skip to content

Commit 5d1636d

Browse files
authored
[MLIR][XeVM] XeVM to LLVM: Add conversion patterns for id ops (#162536)
XeVM to LLVM pass: Add conversion patterns for XeVM id ops. Target OpenCL functions described here: https://registry.khronos.org/OpenCL/sdk/3.0/docs/man/html/get_group_id.html
1 parent c265d7a commit 5d1636d

File tree

2 files changed

+291
-2
lines changed

2 files changed

+291
-2
lines changed

mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
714714
}
715715
};
716716

717+
//===----------------------------------------------------------------------===//
718+
// GPU index id operations
719+
//===----------------------------------------------------------------------===//
720+
/*
721+
// Launch Config ops
722+
// dimidx - x, y, z - is fixed to i32
723+
// return type is set by XeVM type converter
724+
// get_local_id
725+
xevm::WorkitemIdXOp;
726+
xevm::WorkitemIdYOp;
727+
xevm::WorkitemIdZOp;
728+
// get_local_size
729+
xevm::WorkgroupDimXOp;
730+
xevm::WorkgroupDimYOp;
731+
xevm::WorkgroupDimZOp;
732+
// get_group_id
733+
xevm::WorkgroupIdXOp;
734+
xevm::WorkgroupIdYOp;
735+
xevm::WorkgroupIdZOp;
736+
// get_num_groups
737+
xevm::GridDimXOp;
738+
xevm::GridDimYOp;
739+
xevm::GridDimZOp;
740+
// get_global_id : to be added if needed
741+
*/
742+
743+
// Helpers to get the OpenCL function name and dimension argument for each op.
744+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
745+
return {"get_local_id", 0};
746+
}
747+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
748+
return {"get_local_id", 1};
749+
}
750+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
751+
return {"get_local_id", 2};
752+
}
753+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
754+
return {"get_local_size", 0};
755+
}
756+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
757+
return {"get_local_size", 1};
758+
}
759+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
760+
return {"get_local_size", 2};
761+
}
762+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
763+
return {"get_group_id", 0};
764+
}
765+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
766+
return {"get_group_id", 1};
767+
}
768+
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
769+
return {"get_group_id", 2};
770+
}
771+
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
772+
return {"get_num_groups", 0};
773+
}
774+
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
775+
return {"get_num_groups", 1};
776+
}
777+
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
778+
return {"get_num_groups", 2};
779+
}
780+
/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
781+
/// a constant argument for the dimension - x, y or z.
782+
template <typename OpType>
783+
class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
784+
using OpConversionPattern<OpType>::OpConversionPattern;
785+
LogicalResult
786+
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
787+
ConversionPatternRewriter &rewriter) const override {
788+
Location loc = op->getLoc();
789+
auto [baseName, dim] = getConfig(op);
790+
Type dimTy = rewriter.getI32Type();
791+
Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
792+
static_cast<int64_t>(dim));
793+
std::string func = mangle(baseName, {dimTy}, {true});
794+
Type resTy = op.getType();
795+
auto call =
796+
createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
797+
noUnwindWillReturnAttrs, op.getOperation());
798+
constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
799+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
800+
/*other=*/noModRef,
801+
/*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
802+
call.setMemoryEffectsAttr(memAttr);
803+
rewriter.replaceOp(op, call);
804+
return success();
805+
}
806+
};
807+
808+
/*
809+
// Subgroup ops
810+
// get_sub_group_local_id
811+
xevm::LaneIdOp;
812+
// get_sub_group_id
813+
xevm::SubgroupIdOp;
814+
// get_sub_group_size
815+
xevm::SubgroupSizeOp;
816+
// get_num_sub_groups : to be added if needed
817+
*/
818+
819+
// Helpers to get the OpenCL function name for each op.
820+
static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
821+
static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
822+
static StringRef getConfig(xevm::SubgroupSizeOp) {
823+
return "get_sub_group_size";
824+
}
825+
template <typename OpType>
826+
class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
827+
using OpConversionPattern<OpType>::OpConversionPattern;
828+
LogicalResult
829+
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
830+
ConversionPatternRewriter &rewriter) const override {
831+
std::string func = mangle(getConfig(op).str(), {});
832+
Type resTy = op.getType();
833+
auto call =
834+
createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
835+
noUnwindWillReturnAttrs, op.getOperation());
836+
constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
837+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
838+
/*other=*/noModRef,
839+
/*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
840+
call.setMemoryEffectsAttr(memAttr);
841+
rewriter.replaceOp(op, call);
842+
return success();
843+
}
844+
};
845+
717846
//===----------------------------------------------------------------------===//
718847
// Pass Definition
719848
//===----------------------------------------------------------------------===//
@@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
775904
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
776905
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
777906
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
778-
BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
907+
BlockLoadStore1DToOCLPattern<BlockStoreOp>,
908+
LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
909+
LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
910+
LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
911+
LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
912+
LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
913+
LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
914+
LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
915+
LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
916+
LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
917+
LaunchConfigOpToOCLPattern<GridDimXOp>,
918+
LaunchConfigOpToOCLPattern<GridDimYOp>,
919+
LaunchConfigOpToOCLPattern<GridDimZOp>,
920+
SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
921+
SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
922+
SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
779923
patterns.getContext());
780924
}
781925

mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
3535
// -----
3636
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
3737
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
38-
// CHECK: xevm.DecorationCacheControl =
38+
// CHECK: xevm.DecorationCacheControl =
3939
// CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
4040
// CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
4141
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
@@ -345,3 +345,148 @@ llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) {
345345
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, i64)
346346
llvm.return
347347
}
348+
349+
// -----
350+
// CHECK-LABEL: llvm.func @local_id.x() -> i32 {
351+
llvm.func @local_id.x() -> i32 {
352+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
353+
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
354+
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
355+
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
356+
// CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
357+
%1 = xevm.local_id.x : i32
358+
llvm.return %1 : i32
359+
}
360+
361+
// -----
362+
// CHECK-LABEL: llvm.func @local_id.y() -> i32 {
363+
llvm.func @local_id.y() -> i32 {
364+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
365+
%1 = xevm.local_id.y : i32
366+
llvm.return %1 : i32
367+
}
368+
369+
// -----
370+
// CHECK-LABEL: llvm.func @local_id.z() -> i32 {
371+
llvm.func @local_id.z() -> i32 {
372+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
373+
%1 = xevm.local_id.z : i32
374+
llvm.return %1 : i32
375+
}
376+
377+
// -----
378+
// CHECK-LABEL: llvm.func @local_size.x() -> i32 {
379+
llvm.func @local_size.x() -> i32 {
380+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
381+
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
382+
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
383+
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
384+
// CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32
385+
%1 = xevm.local_size.x : i32
386+
llvm.return %1 : i32
387+
}
388+
389+
// -----
390+
// CHECK-LABEL: llvm.func @local_size.y() -> i32 {
391+
llvm.func @local_size.y() -> i32 {
392+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
393+
%1 = xevm.local_size.y : i32
394+
llvm.return %1 : i32
395+
}
396+
397+
// -----
398+
// CHECK-LABEL: llvm.func @local_size.z() -> i32 {
399+
llvm.func @local_size.z() -> i32 {
400+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
401+
%1 = xevm.local_size.z : i32
402+
llvm.return %1 : i32
403+
}
404+
405+
// -----
406+
// CHECK-LABEL: llvm.func @group_id.x() -> i32 {
407+
llvm.func @group_id.x() -> i32 {
408+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
409+
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
410+
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
411+
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
412+
// CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
413+
%1 = xevm.group_id.x : i32
414+
llvm.return %1 : i32
415+
}
416+
417+
// -----
418+
// CHECK-LABEL: llvm.func @group_id.y() -> i32 {
419+
llvm.func @group_id.y() -> i32 {
420+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
421+
%1 = xevm.group_id.y : i32
422+
llvm.return %1 : i32
423+
}
424+
425+
// -----
426+
// CHECK-LABEL: llvm.func @group_id.z() -> i32 {
427+
llvm.func @group_id.z() -> i32 {
428+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
429+
%1 = xevm.group_id.z : i32
430+
llvm.return %1 : i32
431+
}
432+
433+
// -----
434+
// CHECK-LABEL: llvm.func @group_count.x() -> i32 {
435+
llvm.func @group_count.x() -> i32 {
436+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
437+
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
438+
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
439+
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
440+
// CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32
441+
%1 = xevm.group_count.x : i32
442+
llvm.return %1 : i32
443+
}
444+
445+
// -----
446+
// CHECK-LABEL: llvm.func @group_count.y() -> i32 {
447+
llvm.func @group_count.y() -> i32 {
448+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
449+
%1 = xevm.group_count.y : i32
450+
llvm.return %1 : i32
451+
}
452+
453+
// -----
454+
// CHECK-LABEL: llvm.func @group_count.z() -> i32 {
455+
llvm.func @group_count.z() -> i32 {
456+
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
457+
%1 = xevm.group_count.z : i32
458+
llvm.return %1 : i32
459+
}
460+
461+
// -----
462+
// CHECK-LABEL: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return}
463+
llvm.func @lane_id() -> i32 {
464+
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
465+
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
466+
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
467+
// CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32
468+
%1 = xevm.lane_id : i32
469+
llvm.return %1 : i32
470+
}
471+
472+
// -----
473+
// CHECK-LABEL: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}
474+
llvm.func @subgroup_size() -> i32 {
475+
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size()
476+
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
477+
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
478+
// CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32
479+
%1 = xevm.subgroup_size : i32
480+
llvm.return %1 : i32
481+
}
482+
483+
// -----
484+
// CHECK-LABEL: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return}
485+
llvm.func @subgroup_id() -> i32 {
486+
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
487+
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
488+
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
489+
// CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32
490+
%1 = xevm.subgroup_id : i32
491+
llvm.return %1 : i32
492+
}

0 commit comments

Comments
 (0)