Skip to content

Commit 385bcd2

Browse files
authored
Use arith/math ops instead of spirv.CL.* ops in tests. (#702)
1 parent b081001 commit 385bcd2

File tree

6 files changed

+47
-26
lines changed

6 files changed

+47
-26
lines changed

include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef IMEX_CONVERSION_XEGPUTOSPIRV_H
1515
#define IMEX_CONVERSION_XEGPUTOSPIRV_H
1616

17+
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
1718
#include <mlir/Dialect/SPIRV/IR/SPIRVDialect.h>
1819
#include <mlir/Dialect/SPIRV/IR/SPIRVOps.h>
1920
#include <mlir/Transforms/DialectConversion.h>
@@ -25,7 +26,9 @@ class Pass;
2526

2627
namespace imex {
2728

28-
template <typename SPIRVOp> std::string getVCIntrinsicName(SPIRVOp op);
29+
// helper to check the legal vector lengths for arith/math ops
30+
bool isGenericVectorTy(mlir::Type type);
31+
2932
// XeGPU to VC Intrinsics pattern
3033
void populateXeGPUToVCIntrinsicsPatterns(
3134
mlir::SPIRVTypeConverter &typeConverter, mlir::RewritePatternSet &patterns);

lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ void GPUXToSPIRVPass::runOnOperation() {
211211
mlir::RewritePatternSet patterns(context);
212212
mlir::SPIRVConversionOptions options;
213213
options.use64bitIndex = true;
214+
// FIXME: activate fast math on a per-operator basis.
215+
options.enableFastMathMode = true;
214216

215217
mlir::SPIRVTypeConverter typeConverter(targetAttr, options);
216218

@@ -360,9 +362,24 @@ void GPUXToSPIRVPass::runOnOperation() {
360362
});
361363
}
362364

365+
// SPIR-V elementwise arith/math ops require special handling if they operate
366+
// on large vectors. We dynamically legalize these ops based on the vector
367+
// size they consume.
368+
// FIXME: this is not an exhaustive list of arith/math ops that need special
369+
// handling.
370+
target->addDynamicallyLegalOp<mlir::spirv::CLExpOp>(
371+
[&](mlir::spirv::CLExpOp op) {
372+
return imex::isGenericVectorTy(op.getType());
373+
});
374+
target->addDynamicallyLegalOp<mlir::spirv::CLFMaxOp>(
375+
[&](mlir::spirv::CLFMaxOp op) {
376+
return imex::isGenericVectorTy(op.getType());
377+
});
378+
363379
//------- Upstream Conversion------------
364380
mlir::populateGPUToSPIRVPatterns(typeConverter, patterns);
365381
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
382+
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
366383
mlir::populateMemRefToSPIRVPatterns(typeConverter, patterns);
367384
mlir::populateFuncToSPIRVPatterns(typeConverter, patterns);
368385
// ---------------------------------------

lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern<UpdateNDOffsetOp> {
363363
auto desc = adaptor.getTensorDesc();
364364
for (size_t i = 0; i < offsets.size(); i++) {
365365
auto offset = offsets[i];
366-
if (auto cst = dyn_cast<spirv::ConstantOp>(offset.getDefiningOp()))
366+
if (auto cst = offset.getDefiningOp<arith::ConstantOp>())
367367
if (auto attr = dyn_cast<mlir::IntegerAttr>(cst.getValue());
368368
attr && attr.getInt() == 0)
369369
continue;
@@ -1538,21 +1538,13 @@ struct SPIRVElementwiseToVC : public OpConversionPattern<SPIRVOp> {
15381538
if (!dstType)
15391539
return failure();
15401540

1541+
// This lowering pattern is needed only for spirv ops with large vector
1542+
// lengths.
1543+
assert(
1544+
!imex::isGenericVectorTy(dstType) &&
1545+
"Vector size is considered generic and op does not require lowering to "
1546+
"VC intrinsic. Consider marking this op + vector length as legal.");
15411547
auto vecSize = dstType.template dyn_cast<VectorType>().getNumElements();
1542-
auto hasGenericVecSize = [&]() -> bool {
1543-
// if the input is scalar, we keep the operation as is.
1544-
if (isa<spirv::ScalarType>(dstType))
1545-
return true;
1546-
// or, if the vector size is 2, 3, 4, 8, or 16, we keep the operation.
1547-
return vecSize == 2 || vecSize == 3 || vecSize == 4 || vecSize == 8 ||
1548-
vecSize == 16;
1549-
};
1550-
1551-
if (hasGenericVecSize()) {
1552-
rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
1553-
return success();
1554-
}
1555-
15561548
// for larger vector lengths, "llvm.genx.exp" returns the base 2
15571549
// exponentiation of the input. To get the base e exponentiation, we need to
15581550
// scale the input by log2(e)
@@ -1588,6 +1580,14 @@ struct SPIRVElementwiseToVC : public OpConversionPattern<SPIRVOp> {
15881580
};
15891581
} // namespace
15901582

1583+
bool imex::isGenericVectorTy(mlir::Type type) {
1584+
if (isa<spirv::ScalarType>(type))
1585+
return true;
1586+
auto vecSize = type.dyn_cast<VectorType>().getNumElements();
1587+
return vecSize == 2 || vecSize == 3 || vecSize == 4 || vecSize == 8 ||
1588+
vecSize == 16;
1589+
}
1590+
15911591
void imex::populateXeGPUToVCIntrinsicsPatterns(
15921592
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
15931593
patterns.add<CreateNdDescToSPIRV, CreateDescToVCPattern, DpasToVCPattern,

test/Integration/Dialect/XeGPU/exp_f32.vc.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module @gemm attributes {gpu.container_module} {
3535
// do DPAS
3636
%val4 = xegpu.dpas %val0, %val2 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
3737
// take exp
38-
%t6 = spirv.CL.exp %val4 : vector<8x16xf32>
38+
%t6 = math.exp %val4 : vector<8x16xf32>
3939
// store
4040
%out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] { mode = vc } : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
4141
xegpu.store_nd %t6, %out_tile { mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -65,14 +65,14 @@ module @gemm attributes {gpu.container_module} {
6565
%v6 = vector.extract %val4[6] : vector<16xf32> from vector<8x16xf32>
6666
%v7 = vector.extract %val4[7] : vector<16xf32> from vector<8x16xf32>
6767
// do generic size exp
68-
%v0_exp = spirv.CL.exp %v0 : vector<16xf32>
69-
%v1_exp = spirv.CL.exp %v1 : vector<16xf32>
70-
%v2_exp = spirv.CL.exp %v2 : vector<16xf32>
71-
%v3_exp = spirv.CL.exp %v3 : vector<16xf32>
72-
%v4_exp = spirv.CL.exp %v4 : vector<16xf32>
73-
%v5_exp = spirv.CL.exp %v5 : vector<16xf32>
74-
%v6_exp = spirv.CL.exp %v6 : vector<16xf32>
75-
%v7_exp = spirv.CL.exp %v7 : vector<16xf32>
68+
%v0_exp = math.exp %v0 : vector<16xf32>
69+
%v1_exp = math.exp %v1 : vector<16xf32>
70+
%v2_exp = math.exp %v2 : vector<16xf32>
71+
%v3_exp = math.exp %v3 : vector<16xf32>
72+
%v4_exp = math.exp %v4 : vector<16xf32>
73+
%v5_exp = math.exp %v5 : vector<16xf32>
74+
%v6_exp = math.exp %v6 : vector<16xf32>
75+
%v7_exp = math.exp %v7 : vector<16xf32>
7676
%v0_exp_cast = vector.shape_cast %v0_exp : vector<16xf32> to vector<1x16xf32>
7777
%v1_exp_cast = vector.shape_cast %v1_exp : vector<16xf32> to vector<1x16xf32>
7878
%v2_exp_cast = vector.shape_cast %v2_exp : vector<16xf32> to vector<1x16xf32>

test/Integration/Dialect/XeGPU/fmax_f32.vc.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ module @gemm attributes {gpu.container_module} {
3737
%val4 = xegpu.dpas %val0, %val2 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
3838
%val5 = xegpu.dpas %val1, %val3 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
3939
// take fmax
40-
%val6 = spirv.CL.fmax %val4, %val5 : vector<8x16xf32>
40+
%val6 = arith.maximumf %val4, %val5 fastmath<nnan> : vector<8x16xf32>
4141
// store fmax
4242
%out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] { mode = vc } : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
4343
xegpu.store_nd %val6, %out_tile { mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>

test/Integration/Dialect/XeGPU/xegpu-to-llvm.pp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// linalg dialect to gpu dialect lowering pipeline
22
// Ready for vulkan runner or narrow scope l0/sycl runner starting from GPU dialect.
33
builtin.module(
4+
imex-vector-linearize
45
imex-convert-gpu-to-spirv{enable-vc-intrinsic=true}
56
spirv.module(spirv-lower-abi-attrs
67
spirv-update-vce)

0 commit comments

Comments
 (0)