Skip to content

Commit 8824668

Browse files
authored
Add generic interface for lowering elementwise math operations to SPIRV (#696)
1 parent 08c371d commit 8824668

File tree

3 files changed

+215
-23
lines changed

3 files changed

+215
-23
lines changed

include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class Pass;
2424
} // namespace mlir
2525

2626
namespace imex {
27+
28+
// Declares a lookup from a SPIR-V op type to a VC intrinsic name.
// NOTE(review): the definition added in XeGPUToSPIRV.cpp by this change is
// `getVCIntrinsicName<SPIRVOp>()` (no parameter) and is not visibly inside
// namespace imex, so this one-argument overload appears to have no matching
// definition -- confirm which signature is intended.
template <typename SPIRVOp> std::string getVCIntrinsicName(SPIRVOp op);
2729
// XeGPU to VC Intrinsics pattern
2830
void populateXeGPUToVCIntrinsicsPatterns(
2931
mlir::SPIRVTypeConverter &typeConverter, mlir::RewritePatternSet &patterns);

lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "imex/Dialect/XeGPU/IR/XeGPU.h"
1717

1818
#include "../PassDetail.h"
19+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/BuiltinAttributes.h"
2122
#include "mlir/IR/BuiltinTypes.h"
@@ -44,6 +45,7 @@
4445
#include <mlir/Transforms/DialectConversion.h>
4546

4647
#include <numeric>
48+
#include <type_traits>
4749

4850
using namespace imex;
4951
using namespace imex::xegpu;
@@ -1458,39 +1460,70 @@ struct VectorShuffle final : public OpConversionPattern<vector::ShuffleOp> {
14581460
}
14591461
};
14601462

1461-
struct SpirvCLFMax : public OpConversionPattern<spirv::CLFMaxOp> {
1462-
using OpConversionPattern::OpConversionPattern;
1463+
template <typename SPIRVOp> std::string getVCIntrinsicName() {
1464+
constexpr bool isFMaxOp = std::is_same_v<SPIRVOp, spirv::CLFMaxOp>;
1465+
constexpr bool isExpOp = std::is_same_v<SPIRVOp, spirv::CLExpOp>;
1466+
if (isFMaxOp)
1467+
return "llvm.genx.fmax.";
1468+
else if (isExpOp)
1469+
return "llvm.genx.exp.";
1470+
else
1471+
assert(0 && "Unsupported SPIRV Op. Add more support!");
1472+
}
1473+
1474+
template <typename SPIRVOp>
1475+
struct SPIRVElementwiseToVC : public OpConversionPattern<SPIRVOp> {
1476+
using OpConversionPattern<SPIRVOp>::OpConversionPattern;
14631477
LogicalResult
1464-
matchAndRewrite(spirv::CLFMaxOp op, OpAdaptor adaptor,
1478+
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
14651479
ConversionPatternRewriter &rewriter) const override {
1466-
auto dstType = getTypeConverter()->convertType(op.getType());
1480+
auto dstType = this->getTypeConverter()->convertType(op.getType());
14671481
if (!dstType)
14681482
return failure();
14691483

1470-
// For scalar case, we keep the operation as is.
1471-
if (isa<spirv::ScalarType>(dstType)) {
1472-
rewriter.replaceOpWithNewOp<spirv::CLFMaxOp>(
1473-
op, dstType, adaptor.getLhs(), adaptor.getRhs());
1484+
auto vecSize = dstType.template dyn_cast<VectorType>().getNumElements();
1485+
auto hasGenericVecSize = [&]() -> bool {
1486+
// if the input is scalar, we keep the operation as is.
1487+
if (isa<spirv::ScalarType>(dstType))
1488+
return true;
1489+
// or, if the vector size is 2, 3, 4, 8, or 16, we keep the operation.
1490+
return vecSize == 2 || vecSize == 3 || vecSize == 4 || vecSize == 8 ||
1491+
vecSize == 16;
1492+
};
1493+
1494+
if (hasGenericVecSize()) {
1495+
rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
14741496
return success();
14751497
}
14761498

1477-
// For vector case, corresponding VC intrinsics is used.
1478-
// Currently only f16/f32/bf16 vectors are supported.
1479-
auto dstVecElemTy = dstType.dyn_cast<VectorType>().getElementType();
1480-
if (!dstVecElemTy.isF32() && !dstVecElemTy.isF16() &&
1481-
!dstVecElemTy.isBF16())
1482-
return rewriter.notifyMatchFailure(
1483-
op,
1484-
"Unsupported type in Spirv.CL.FMax. Only f16/f32/bf16 vectors are "
1485-
"currently supported.");
1499+
// for larger vector lengths, "llvm.genx.exp" returns the base 2
1500+
// exponentiation of the input. To get the base e exponentiation, we need to
1501+
// scale the input by log2(e)
1502+
bool isExpOp = std::is_same_v<SPIRVOp, spirv::CLExpOp>;
1503+
SmallVector<Value> args{adaptor.getOperands()};
1504+
auto operands = adaptor.getOperands();
1505+
if (isExpOp) {
1506+
SmallVector<float> log2e(vecSize, 1.442695040888963);
1507+
auto log2eConstVec = rewriter.create<spirv::ConstantOp>(
1508+
op.getLoc(), dstType, rewriter.getF32VectorAttr(log2e));
1509+
auto input = operands[0];
1510+
auto scaledInput =
1511+
rewriter.create<spirv::FMulOp>(op.getLoc(), input, log2eConstVec);
1512+
args.clear();
1513+
args.push_back(scaledInput);
1514+
}
14861515

1487-
std::string funcName = "llvm.genx.fmax.";
1488-
auto funcType = rewriter.getFunctionType(
1489-
{adaptor.getLhs().getType(), adaptor.getRhs().getType()}, {dstType});
1516+
// for large vectors, generate the corresponding VC intrinsic.
1517+
auto funcName = getVCIntrinsicName<SPIRVOp>();
1518+
SmallVector<Type> operandTypes;
1519+
for (auto operand : adaptor.getOperands())
1520+
operandTypes.push_back(operand.getType());
1521+
auto funcType = rewriter.getFunctionType(operandTypes, {dstType});
14901522
funcName +=
1491-
encodeVectorType(rewriter, dstType.dyn_cast<VectorType>()).first;
1523+
encodeVectorType(rewriter, dstType.template dyn_cast<VectorType>())
1524+
.first;
14921525
lookupOrInsertIntrinsic(rewriter, op, funcName, funcType);
1493-
SmallVector<Value> args{adaptor.getLhs(), adaptor.getRhs()};
1526+
14941527
rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>(op, dstType, funcName,
14951528
args);
14961529
return success();
@@ -1505,7 +1538,9 @@ void imex::populateXeGPUToVCIntrinsicsPatterns(
15051538
NbarrierArriveToVCPattern, NbarrierWaitToVCPattern,
15061539
CompilerHintToVCPattern, MfenceToVCPattern, VectorShapeCast,
15071540
VectorExtract, VectorExtractStridedSlice, VectorShuffle,
1508-
SpirvCLFMax, GatherScatterToRawSend<LoadGatherOp>,
1541+
SPIRVElementwiseToVC<spirv::CLFMaxOp>,
1542+
SPIRVElementwiseToVC<spirv::CLExpOp>,
1543+
GatherScatterToRawSend<LoadGatherOp>,
15091544
GatherScatterToRawSend<StoreScatterOp>, AtomicToLsc,
15101545
UpdateNDOffsetToVCPattern>(typeConverter, patterns.getContext());
15111546
if (getenv("IMEX_NOT_PREFER_RAWSEND"))
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
2+
// RUN: --runner imex-cpu-runner -e main \
3+
// RUN: --entry-point-result=void \
4+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5+
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
6+
// RUN: --runner imex-cpu-runner -e main \
7+
// RUN: --entry-point-result=void \
8+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
9+
module @gemm attributes {gpu.container_module} {
10+
func.func @test(%A: memref<8x16xf16>, %B: memref<16x16xf16> ) -> (memref<8x16xf32>, memref<8x16xf32>) attributes {llvm.emit_c_interface} {
11+
%c1 = arith.constant 1 : index
12+
%memref = gpu.alloc host_shared () : memref<8x16xf16>
13+
%memref_1 = gpu.alloc host_shared () : memref<16x16xf16>
14+
memref.copy %A, %memref : memref<8x16xf16> to memref<8x16xf16>
15+
memref.copy %B, %memref_1 : memref<16x16xf16> to memref<16x16xf16>
16+
%memref_2 = gpu.alloc host_shared () : memref<8x16xf32>
17+
%memref_3 = gpu.alloc host_shared () : memref<8x16xf32>
18+
gpu.launch_func @module0::@test_exp_larger_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_1 : memref<16x16xf16>, %memref_2 : memref<8x16xf32>)
19+
gpu.launch_func @module1::@test_exp_generic_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_1 : memref<16x16xf16>, %memref_3 : memref<8x16xf32>)
20+
gpu.dealloc %memref : memref<8x16xf16>
21+
gpu.dealloc %memref_1 : memref<16x16xf16>
22+
return %memref_2, %memref_3 : memref<8x16xf32>, memref<8x16xf32>
23+
}
24+
25+
gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
26+
gpu.func @test_exp_larger_vec(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
27+
%c0 = arith.constant 0 : index
28+
%c16 = arith.constant 16 : index
29+
// load A tile
30+
%a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] { mode = vc } : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
31+
%val0 = xegpu.load_nd %a_tile0 { mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
32+
// load B tile
33+
%b_tile0 = xegpu.create_nd_tdesc %B [%c0, %c0] { mode = vc } : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
34+
%val2 = xegpu.load_nd %b_tile0 { mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
35+
// do DPAS
36+
%val4 = xegpu.dpas %val0, %val2 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
37+
// take exp
38+
%t6 = spirv.CL.exp %val4 : vector<8x16xf32>
39+
// store
40+
%out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] { mode = vc } : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
41+
xegpu.store_nd %t6, %out_tile { mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
42+
gpu.return
43+
}
44+
}
45+
gpu.module @module1 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
46+
gpu.func @test_exp_generic_vec(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
47+
%c0 = arith.constant 0 : index
48+
%c16 = arith.constant 16 : index
49+
// load A tile
50+
%a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] { mode = vc } : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
51+
%val0 = xegpu.load_nd %a_tile0 { mode = vc, vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
52+
// load B tile
53+
%b_tile0 = xegpu.create_nd_tdesc %B [%c0, %c0] { mode = vc } : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
54+
%val2 = xegpu.load_nd %b_tile0 { mode = vc, vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
55+
// do DPAS
56+
%val4 = xegpu.dpas %val0, %val2 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
57+
// extract dpas out into 16xf32 vectors
58+
%cst1 = arith.constant dense<1.4426950408889634> : vector<128xf32>
59+
%v0 = vector.extract %val4[0] : vector<16xf32> from vector<8x16xf32>
60+
%v1 = vector.extract %val4[1] : vector<16xf32> from vector<8x16xf32>
61+
%v2 = vector.extract %val4[2] : vector<16xf32> from vector<8x16xf32>
62+
%v3 = vector.extract %val4[3] : vector<16xf32> from vector<8x16xf32>
63+
%v4 = vector.extract %val4[4] : vector<16xf32> from vector<8x16xf32>
64+
%v5 = vector.extract %val4[5] : vector<16xf32> from vector<8x16xf32>
65+
%v6 = vector.extract %val4[6] : vector<16xf32> from vector<8x16xf32>
66+
%v7 = vector.extract %val4[7] : vector<16xf32> from vector<8x16xf32>
67+
// do generic size exp
68+
%v0_exp = spirv.CL.exp %v0 : vector<16xf32>
69+
%v1_exp = spirv.CL.exp %v1 : vector<16xf32>
70+
%v2_exp = spirv.CL.exp %v2 : vector<16xf32>
71+
%v3_exp = spirv.CL.exp %v3 : vector<16xf32>
72+
%v4_exp = spirv.CL.exp %v4 : vector<16xf32>
73+
%v5_exp = spirv.CL.exp %v5 : vector<16xf32>
74+
%v6_exp = spirv.CL.exp %v6 : vector<16xf32>
75+
%v7_exp = spirv.CL.exp %v7 : vector<16xf32>
76+
%v0_exp_cast = vector.shape_cast %v0_exp : vector<16xf32> to vector<1x16xf32>
77+
%v1_exp_cast = vector.shape_cast %v1_exp : vector<16xf32> to vector<1x16xf32>
78+
%v2_exp_cast = vector.shape_cast %v2_exp : vector<16xf32> to vector<1x16xf32>
79+
%v3_exp_cast = vector.shape_cast %v3_exp : vector<16xf32> to vector<1x16xf32>
80+
%v4_exp_cast = vector.shape_cast %v4_exp : vector<16xf32> to vector<1x16xf32>
81+
%v5_exp_cast = vector.shape_cast %v5_exp : vector<16xf32> to vector<1x16xf32>
82+
%v6_exp_cast = vector.shape_cast %v6_exp : vector<16xf32> to vector<1x16xf32>
83+
%v7_exp_cast = vector.shape_cast %v7_exp : vector<16xf32> to vector<1x16xf32>
84+
// construct the 8x16xf32 vector from the smaller ones
85+
%t0 = vector.shuffle %v0_exp_cast, %v1_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32>
86+
%t1 = vector.shuffle %v2_exp_cast, %v3_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32>
87+
%t2 = vector.shuffle %v4_exp_cast, %v5_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32>
88+
%t3 = vector.shuffle %v6_exp_cast, %v7_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32>
89+
%t4 = vector.shuffle %t0, %t1 [0, 1, 2, 3] : vector<2x16xf32>, vector<2x16xf32>
90+
%t5 = vector.shuffle %t2, %t3 [0, 1, 2, 3] : vector<2x16xf32>, vector<2x16xf32>
91+
%t6 = vector.shuffle %t4, %t5 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x16xf32>, vector<4x16xf32>
92+
// store
93+
%out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] { mode = vc } : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
94+
xegpu.store_nd %t6, %out_tile { mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
95+
gpu.return
96+
}
97+
}
98+
func.func @main() attributes {llvm.emit_c_interface} {
99+
// init constants
100+
%c0 = arith.constant 0 : index
101+
%c1 = arith.constant 1 : index
102+
%c8 = arith.constant 8 : index
103+
%c16 = arith.constant 16 : index
104+
%rand_lower = arith.constant -1.0 : f32
105+
%rand_upper = arith.constant 1.0 : f32
106+
%gen_int = arith.constant 0 : i1
107+
%A = memref.alloc() : memref<8x16xf16>
108+
%B = memref.alloc() : memref<16x16xf16>
109+
%Out_cpu = memref.alloc() : memref<8x16xf32>
110+
%A_random = memref.cast %A : memref<8x16xf16> to memref<*xf16>
111+
%B_random = memref.cast %B : memref<16x16xf16> to memref<*xf16>
112+
call @fillResource1DRandomF16(%A_random, %rand_lower, %rand_upper, %gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
113+
call @fillResource1DRandomF16(%B_random, %rand_lower, %rand_upper, %gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
114+
// run GPU version
115+
%Out_gpu_large, %Out_gpu_generic = call @test(%A, %B) : (memref<8x16xf16>, memref<16x16xf16>) -> (memref<8x16xf32>, memref<8x16xf32>)
116+
%Out_gpu_generic_cast = memref.cast %Out_gpu_generic : memref<8x16xf32> to memref<*xf32>
117+
%Out_gpu_large_cast = memref.cast %Out_gpu_large : memref<8x16xf32> to memref<*xf32>
118+
// run CPU version
119+
scf.for %i = %c0 to %c8 step %c1 {
120+
scf.for %j = %c0 to %c16 step %c1 {
121+
%v0_init = arith.constant 0.0 : f32
122+
%result:1 = scf.for %k = %c0 to %c16 step %c1 iter_args(%v0 = %v0_init) -> f32 {
123+
%a0 = memref.load %A[%i, %k] : memref<8x16xf16>
124+
%b0 = memref.load %B[%k, %j] : memref<16x16xf16>
125+
%a0_f32 = arith.extf %a0 : f16 to f32
126+
%b0_f32 = arith.extf %b0 : f16 to f32
127+
%t0 = arith.mulf %a0_f32, %b0_f32 : f32
128+
%v0_new = arith.addf %v0, %t0 : f32
129+
scf.yield %v0_new : f32
130+
}
131+
%vexp = math.exp %result#0: f32
132+
memref.store %vexp, %Out_cpu[%i, %j] : memref<8x16xf32>
133+
}
134+
}
135+
%Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32>
136+
// print GPU and CPU outs
137+
// call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> ()
138+
// call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> ()
139+
// CHECK: [ALLCLOSE: TRUE]
140+
// CHECK: [ALLCLOSE: TRUE]
141+
call @printAllcloseF32(%Out_gpu_generic_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> ()
142+
call @printAllcloseF32(%Out_gpu_large_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> ()
143+
// dealloc
144+
memref.dealloc %A : memref<8x16xf16>
145+
memref.dealloc %B : memref<16x16xf16>
146+
memref.dealloc %Out_cpu : memref<8x16xf32>
147+
// gpu dealloc
148+
gpu.dealloc %Out_gpu_generic : memref<8x16xf32>
149+
gpu.dealloc %Out_gpu_large : memref<8x16xf32>
150+
return
151+
}
152+
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
153+
func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
154+
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
155+
}

0 commit comments

Comments
 (0)