
Commit aec6648

Merge pull request #725 from mshahneo/xegpu-to-intel-intrinsic-to-spirv_for_upstream
[func][spirv] Add necessary patch, pass-pipeline, and test case for n…
2 parents 7b5d13c + 4e437b4 commit aec6648

4 files changed: +407, -0 lines changed
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
Upstream PR in progress: https://github.com/llvm/llvm-project/pull/86750


diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 57d8e894a24b..3fc68c65de05 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -306,6 +306,18 @@ public:
   }
 };
 
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
+class ExtractAlignedPointerAsIndexOpPattern
+    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -905,6 +917,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
+    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  Type indexType = typeConverter.getIndexType();
+  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
+                                                      adaptor.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -912,10 +938,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
-               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
-               ReinterpretCastPattern, CastPattern>(typeConverter,
-                                                    patterns.getContext());
+  patterns
+      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
+           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
+          typeConverter, patterns.getContext());
 }
 } // namespace mlir
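
For context, the new pattern lowers memref.extract_aligned_pointer_as_index into a SPIR-V pointer-to-integer cast. A minimal before/after sketch follows; it is illustrative IR only, not part of the commit, and the exact pointer type and i64 index width are assumptions that depend on how the SPIRVTypeConverter is configured:

  // Input IR:
  %intptr = memref.extract_aligned_pointer_as_index %src : memref<128xf16> -> index

  // After MemRefToSPIRV conversion (assuming %src has been converted to a
  // CrossWorkgroup pointer and the type converter maps index to i64):
  %intptr = spirv.ConvertPtrToU %src : !spirv.ptr<!spirv.array<128 x f16>, CrossWorkgroup> to i64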
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
// gpu dialect with intel intrinsic functions (func dialect) to
// llvm dialect (for host code) and
// spirv dialect (for device code) lowering pipeline.
// Ready for imex runner starting from GPU dialect.
builtin.module(
    gpu.module(convert-func-to-spirv)
    gpu.module(convert-vector-to-spirv)
    imex-convert-gpu-to-spirv{enable-vc-intrinsic=true}
    spirv.module(spirv-lower-abi-attrs
                 spirv-update-vce)
    func.func(llvm-request-c-wrappers)
    serialize-spirv
    convert-vector-to-scf
    convert-gpu-to-gpux
    convert-scf-to-cf
    convert-cf-to-llvm
    convert-vector-to-llvm
    convert-index-to-llvm
    convert-arith-to-llvm
    convert-func-to-llvm
    convert-math-to-llvm
    convert-gpux-to-llvm
    convert-index-to-llvm
    expand-strided-metadata
    lower-affine
    finalize-memref-to-llvm
    reconcile-unrealized-casts)
// End
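
This pipeline file is consumed by the imex runner through its --pass-pipeline-file option, as the RUN lines of the test case below show (the %-prefixed tokens are lit substitutions resolved at test time):

  %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/func-to-llvm.pp \
      --runner imex-cpu-runner -e main --entry-point-result=void \
      --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck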
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/func-to-llvm.pp \
// RUN:   --runner imex-cpu-runner -e main \
// RUN:   --entry-point-result=void \
// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/func-to-llvm.pp \
// RUN:   --runner imex-cpu-runner -e main \
// RUN:   --entry-point-result=void \
// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck

module @gemm attributes {gpu.container_module,
    spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
  memref.global "private" constant @__constant_8x16xf16 : memref<8x16xf16> = dense<5.000000e-01>
  memref.global "private" constant @__constant_16x16xf16 : memref<16x16xf16> = dense<1.099610e+00>
  func.func @test(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %memref = gpu.alloc host_shared () : memref<8x16xf16>
    memref.copy %arg0, %memref : memref<8x16xf16> to memref<8x16xf16>
    %memref_0 = gpu.alloc host_shared () : memref<16x16xf16>
    memref.copy %arg1, %memref_0 : memref<16x16xf16> to memref<16x16xf16>
    %memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_0 : memref<16x16xf16>, %memref_1 : memref<8x16xf32>)
    gpu.dealloc %memref : memref<8x16xf16>
    gpu.dealloc %memref_0 : memref<16x16xf16>
    return %memref_1 : memref<8x16xf32>
  }

  gpu.module @test_kernel {
    func.func private @llvm.genx.raw.sends2.noresult.i1.v8i32.v64i64(i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<64xi64>) attributes{
      linkage_attributes=#spirv.linkage_attributes<
        linkage_name="llvm.genx.raw.sends2.noresult.i1.v8i32.v64i64",
        linkage_type=<Import>
      >,
      VectorComputeFunctionINTEL}
    func.func private @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32> attributes{
      linkage_attributes=#spirv.linkage_attributes<
        linkage_name="llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32",
        linkage_type=<Import>
      >,
      VectorComputeFunctionINTEL}
    func.func private @llvm.genx.raw.send2.v128i32.i1.v8i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<128xi32>) -> vector<128xi32> attributes{
      linkage_attributes=#spirv.linkage_attributes<
        linkage_name="llvm.genx.raw.send2.v128i32.i1.v8i32",
        linkage_type=<Import>
      >,
      VectorComputeFunctionINTEL}
    func.func private @llvm.genx.raw.send2.v32i64.i1.v8i32(i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<32xi64>) -> vector<32xi64> attributes{
      linkage_attributes=#spirv.linkage_attributes<
        linkage_name="llvm.genx.raw.send2.v32i64.i1.v8i32",
        linkage_type=<Import>
      >,
      VectorComputeFunctionINTEL}
    gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] : memref<8x16xf16> to memref<128xf16>
      %cst = arith.constant dense<0> : vector<4xi64>
      %intptr = memref.extract_aligned_pointer_as_index %reinterpret_cast : memref<128xf16> -> index
      %0 = arith.index_castui %intptr : index to i64
      %1 = vector.insert %0, %cst[0] : i64 into vector<4xi64>
      %2 = vector.bitcast %1 : vector<4xi64> to vector<8xi32>
      %cst_0 = arith.constant dense<0> : vector<4xi64>
      %cst_to_non_cst = arith.addi %cst_0, %cst_0 : vector<4xi64>
      %intptr_1 = memref.extract_aligned_pointer_as_index %arg1 : memref<16x16xf16> -> index
      %3 = arith.index_castui %intptr_1 : index to i64
      %4 = vector.insert %3, %cst_to_non_cst[0] : i64 into vector<4xi64>
      %5 = vector.bitcast %4 : vector<4xi64> to vector<8xi32>
      %c31_i32 = arith.constant 31 : i32
      %c15_i32 = arith.constant 15 : i32
      %c31_i32_2 = arith.constant 31 : i32
      %6 = vector.insert %c31_i32, %5 [2] : i32 into vector<8xi32>
      %7 = vector.insert %c15_i32, %6 [3] : i32 into vector<8xi32>
      %8 = vector.insert %c31_i32_2, %7 [4] : i32 into vector<8xi32>
      %c0_i32 = arith.constant 0 : i32
      %c0_i32_3 = arith.constant 0 : i32
      %9 = vector.insert %c0_i32, %8 [5] : i32 into vector<8xi32>
      %10 = vector.insert %c0_i32_3, %9 [6] : i32 into vector<8xi32>
      %c3855_i32 = arith.constant 3855 : i32
      %11 = vector.insert %c3855_i32, %10 [7] : i32 into vector<8xi32>
      %reinterpret_cast_4 = memref.reinterpret_cast %arg2 to offset: [0], sizes: [128], strides: [1] : memref<8x16xf32> to memref<128xf32>
      %cst_5_t = arith.constant dense<0> : vector<4xi64>
      %cst_5 = arith.addi %cst_5_t, %cst_5_t : vector<4xi64>
      %intptr_6 = memref.extract_aligned_pointer_as_index %reinterpret_cast_4 : memref<128xf32> -> index
      %12 = arith.index_castui %intptr_6 : index to i64
      %13 = vector.insert %12, %cst_5 [0] : i64 into vector<4xi64>
      %14 = vector.bitcast %13 : vector<4xi64> to vector<8xi32>
      %c0_i8 = arith.constant 0 : i8
      %c0_i8_7 = arith.constant 0 : i8
      %true = arith.constant true
      %c1_i8 = arith.constant 1 : i8
      %c4_i8 = arith.constant 4 : i8
      %c15_i8 = arith.constant 15 : i8
      %c0_i32_8 = arith.constant 0 : i32
      %c42133376_i32 = arith.constant 42133376 : i32
      %cst_9 = arith.constant dense<0> : vector<32xi64>
      %15 = func.call @llvm.genx.raw.send2.v32i64.i1.v8i32(%c0_i8, %c0_i8_7, %true, %c1_i8, %c4_i8, %c15_i8, %c0_i32_8, %c42133376_i32, %2, %cst_9) : (i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<32xi64>) -> vector<32xi64>
      %16 = vector.bitcast %15 : vector<32xi64> to vector<128xf16>
      %c0_i8_10 = arith.constant 0 : i8
      %c0_i8_11 = arith.constant 0 : i8
      %true_12 = arith.constant true
      %c1_i8_13 = arith.constant 1 : i8
      %c8_i8 = arith.constant 8 : i8
      %c15_i8_14 = arith.constant 15 : i8
      %c0_i32_15 = arith.constant 0 : i32
      %c42074755_i32 = arith.constant 42074755 : i32
      %cst_16 = arith.constant dense<0> : vector<128xi32>
      %17 = func.call @llvm.genx.raw.send2.v128i32.i1.v8i32(%c0_i8_10, %c0_i8_11, %true_12, %c1_i8_13, %c8_i8, %c15_i8_14, %c0_i32_15, %c42074755_i32, %11, %cst_16) : (i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<128xi32>) -> vector<128xi32>
      %18 = vector.bitcast %16 : vector<128xf16> to vector<64xi32>
      %c134744586_i32 = arith.constant 134744586 : i32
      %19 = func.call @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(%17, %18, %c134744586_i32) : (vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32>
      %c0_i8_17 = arith.constant 0 : i8
      %c0_i8_18 = arith.constant 0 : i8
      %true_19 = arith.constant true
      %c1_i8_20 = arith.constant 1 : i8
      %c8_i8_21 = arith.constant 8 : i8
      %c15_i8_22 = arith.constant 15 : i8
      %c0_i32_23 = arith.constant 0 : i32
      %c33748868_i32 = arith.constant 33748868 : i32
      %20 = vector.bitcast %19 : vector<128xf32> to vector<64xi64>
      func.call @llvm.genx.raw.sends2.noresult.i1.v8i32.v64i64(%c0_i8_17, %c0_i8_18, %true_19, %c1_i8_20, %c8_i8_21, %c15_i8_22, %c0_i32_23, %c33748868_i32, %14, %20) : (i8, i8, i1, i8, i8, i8, i32, i32, vector<8xi32>, vector<64xi64>) -> ()
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    %0 = memref.get_global @__constant_8x16xf16 : memref<8x16xf16>
    %1 = memref.get_global @__constant_16x16xf16 : memref<16x16xf16>
    %2 = call @test(%0, %1) : (memref<8x16xf16>, memref<16x16xf16>) -> memref<8x16xf32>
    %cast = memref.cast %2 : memref<8x16xf32> to memref<*xf32>
    call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
    // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
    // CHECK-COUNT-128: 8.79688
    return
  }
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
}
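
The CHECK-COUNT-128 value follows from the constant inputs: each of the 8x16 output elements is a dot product of 16 terms 0.5 * 1.099610 (1.099610e+00 rounds to 1.099609375 in f16), i.e. 16 * 0.5 * 1.099609375 = 8.796875, which printMemrefF32 prints as 8.79688.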
