Skip to content

Commit fabd1c4

Browse files
authored
[MLIR][Math][XeVM] Add MathToXeVM (math-to-xevm) pass (#159878)
This PR introduces a `MathToXeVM` pass, which implements support for the `afn` fastmath flag for SPIRV/XeVM targets - It takes supported `Math` Ops with the `afn` flag, and converts them to function calls to OpenCL `native_` intrinsics. These intrinsic functions are supported by the SPIRV backend, and are automatically converted to `OpExtInst` calls to `native_` ops from the OpenCL SPIRV ext. inst. set when outputting to SPIRV/XeVM. Note: - This pass also supports converting `arith.divf` to native equivalents. There is an option provided in the pass to turn this behavior off. - This pass preserves fastmath flags, but these flags are currently ignored by the SPIRV backend. Thus, in order to generate SPIRV that truly preserves fastmath flags, support needs to be added to the SPIRV backend.
1 parent d78c930 commit fabd1c4

File tree

8 files changed

+520
-0
lines changed

8 files changed

+520
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
9+
#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
10+
11+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12+
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include <memory>
15+
16+
namespace mlir {
17+
class Pass;
18+
19+
#define GEN_PASS_DECL_CONVERTMATHTOXEVM
20+
#include "mlir/Conversion/Passes.h.inc"
21+
22+
/// Populate the given list with patterns that convert from Math to XeVM calls.
23+
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
24+
bool convertArith);
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
5050
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
5151
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
52+
#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
5253
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
5354
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
5455
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,31 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
796796
let dependentDialects = ["spirv::SPIRVDialect"];
797797
}
798798

799+
//===----------------------------------------------------------------------===//
800+
// MathToXeVM
801+
//===----------------------------------------------------------------------===//
802+
803+
def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
804+
let summary =
805+
"Convert (fast) math operations to native XeVM/SPIRV equivalents";
806+
let description = [{
807+
This pass converts supported math ops marked with the `afn` fastmath flag
808+
to function calls for OpenCL `native_` math intrinsics: These intrinsics
809+
are typically mapped directly to native device instructions, often resulting
810+
in better performance. However, the precision/error of these intrinsics
811+
are implementation-defined, and thus math ops are only converted when they
812+
have the `afn` fastmath flag enabled.
813+
}];
814+
let options = [Option<
815+
"convertArith", "convert-arith", "bool", /*default=*/"true",
816+
"Convert supported Arith ops (e.g. arith.divf) as well.">];
817+
let dependentDialects = [
818+
"arith::ArithDialect",
819+
"xevm::XeVMDialect",
820+
"LLVM::LLVMDialect",
821+
];
822+
}
823+
799824
//===----------------------------------------------------------------------===//
800825
// MathToEmitC
801826
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
4040
add_subdirectory(MathToLLVM)
4141
add_subdirectory(MathToROCDL)
4242
add_subdirectory(MathToSPIRV)
43+
add_subdirectory(MathToXeVM)
4344
add_subdirectory(MemRefToEmitC)
4445
add_subdirectory(MemRefToLLVM)
4546
add_subdirectory(MemRefToSPIRV)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRMathToXeVM
2+
MathToXeVM.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRMathDialect
15+
MLIRLLVMCommonConversion
16+
MLIRPass
17+
MLIRTransformUtils
18+
MLIRVectorDialect
19+
)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
10+
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
11+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12+
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
13+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17+
#include "mlir/Dialect/Math/IR/Math.h"
18+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
19+
#include "mlir/IR/BuiltinDialect.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/Pass/Pass.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
#include "llvm/Support/FormatVariadic.h"
24+
25+
namespace mlir {
26+
#define GEN_PASS_DEF_CONVERTMATHTOXEVM
27+
#include "mlir/Conversion/Passes.h.inc"
28+
} // namespace mlir
29+
30+
using namespace mlir;
31+
32+
#define DEBUG_TYPE "math-to-xevm"
33+
34+
/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
35+
template <typename Op>
36+
struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
37+
38+
ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
39+
PatternBenefit benefit = 1)
40+
: OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
41+
42+
LogicalResult
43+
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
44+
ConversionPatternRewriter &rewriter) const override {
45+
if (!isSPIRVCompatibleFloatOrVec(op.getType()))
46+
return failure();
47+
48+
arith::FastMathFlags fastFlags = op.getFastmath();
49+
if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
50+
return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
51+
52+
SmallVector<Type, 1> operandTypes;
53+
for (auto operand : adaptor.getOperands()) {
54+
Type opTy = operand.getType();
55+
// This pass only supports operations on vectors that are already in SPIRV
56+
// supported vector sizes: Distributing unsupported vector sizes to SPIRV
57+
// supported vector sizes are done in other blocking optimization passes.
58+
if (!isSPIRVCompatibleFloatOrVec(opTy))
59+
return rewriter.notifyMatchFailure(
60+
op, llvm::formatv("incompatible operand type: '{0}'", opTy));
61+
operandTypes.push_back(opTy);
62+
}
63+
64+
auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
65+
auto funcOpRes = LLVM::lookupOrCreateFn(
66+
rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
67+
operandTypes, op.getType());
68+
assert(!failed(funcOpRes));
69+
LLVM::LLVMFuncOp funcOp = funcOpRes.value();
70+
71+
auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
72+
op, funcOp, adaptor.getOperands());
73+
// Preserve fastmath flags in our MLIR op when converting to llvm function
74+
// calls, in order to allow further fastmath optimizations: We thus need to
75+
// convert arith fastmath attrs into attrs recognized by llvm.
76+
arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
77+
mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
78+
callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
79+
return success();
80+
}
81+
82+
inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
83+
if (type.isFloat())
84+
return true;
85+
if (auto vecType = dyn_cast<VectorType>(type)) {
86+
if (!vecType.getElementType().isFloat())
87+
return false;
88+
// SPIRV distinguishes between vectors and matrices: OpenCL native math
89+
// intrsinics are not compatible with matrices.
90+
ArrayRef<int64_t> shape = vecType.getShape();
91+
if (shape.size() != 1)
92+
return false;
93+
// SPIRV only allows vectors of size 2, 3, 4, 8, 16.
94+
if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
95+
shape[0] == 16)
96+
return true;
97+
}
98+
return false;
99+
}
100+
101+
inline std::string
102+
getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
103+
std::string mangledFuncName =
104+
"_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
105+
106+
auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
107+
if (type.isF32())
108+
mangledFuncName += "f";
109+
else if (type.isF16())
110+
mangledFuncName += "Dh";
111+
else if (type.isF64())
112+
mangledFuncName += "d";
113+
};
114+
115+
for (auto type : operandTypes) {
116+
if (auto vecType = dyn_cast<VectorType>(type)) {
117+
mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
118+
appendFloatToMangledFunc(vecType.getElementType());
119+
} else
120+
appendFloatToMangledFunc(type);
121+
}
122+
123+
return mangledFuncName;
124+
}
125+
126+
const StringRef nativeFunc;
127+
};
128+
129+
void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
130+
bool convertArith) {
131+
patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
132+
"__spirv_ocl_native_exp");
133+
patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
134+
"__spirv_ocl_native_cos");
135+
patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
136+
patterns.getContext(), "__spirv_ocl_native_exp2");
137+
patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
138+
"__spirv_ocl_native_log");
139+
patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
140+
patterns.getContext(), "__spirv_ocl_native_log2");
141+
patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
142+
patterns.getContext(), "__spirv_ocl_native_log10");
143+
patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
144+
patterns.getContext(), "__spirv_ocl_native_powr");
145+
patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
146+
patterns.getContext(), "__spirv_ocl_native_rsqrt");
147+
patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
148+
"__spirv_ocl_native_sin");
149+
patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
150+
patterns.getContext(), "__spirv_ocl_native_sqrt");
151+
patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
152+
"__spirv_ocl_native_tan");
153+
if (convertArith)
154+
patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
155+
patterns.getContext(), "__spirv_ocl_native_divide");
156+
}
157+
158+
namespace {
159+
struct ConvertMathToXeVMPass
160+
: public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
161+
using Base::Base;
162+
void runOnOperation() override;
163+
};
164+
} // namespace
165+
166+
void ConvertMathToXeVMPass::runOnOperation() {
167+
RewritePatternSet patterns(&getContext());
168+
populateMathToXeVMConversionPatterns(patterns, convertArith);
169+
ConversionTarget target(getContext());
170+
target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
171+
if (failed(
172+
applyPartialConversion(getOperation(), target, std::move(patterns))))
173+
signalPassFailure();
174+
}

0 commit comments

Comments
 (0)