-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][Math][XeVM] Add MathToXeVM (math-to-xevm
) pass
#159878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
@llvm/pr-subscribers-mlir Author: Ian Li (ianayl) ChangesThis PR introduces a These intrinsic functions are supported by the SPIRV backend, and are automatically converted to Note:
Patch is 29.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159878.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000000000..91d3c92fd6296
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,27 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..5817babf68ddb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,6 +796,32 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}
+//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+ let summary =
+ "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+ let description = [{
+ This pass converts supported math ops marked with the `afn` fastmath flag
+ to function calls for OpenCL `native_` math intrinsics: These intrinsics
+ are typically mapped directly to native device instructions, often resulting
+ in better performance. However, the precision/error of these intrinsics
+ are implementation-defined, and thus math ops are only converted when they
+ have the `afn` fastmath flag enabled.
+ }];
+ let options = [Option<
+ "convertArith", "convert-arith", "bool", /*default=*/"true",
+ "Convert supported Arith ops (e.g. arith.divf) as well.">];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "func::FuncDialect",
+ "xevm::XeVMDialect",
+ "vector::VectorDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..bebf1b8fff3f9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..711c6876bb168
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+# TODO check if everything here is needed
+add_mlir_conversion_library(MLIRMathToXeVM
+ MathToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000000000..46833735a79dd
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,188 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+// GPUCommon/OpToFunctionCallLowering is not used here, as it doesn't handle
+// native functions/intrinsics that take vector operands.
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+ return failure();
+
+ arith::FastMathFlags fastFlags = op.getFastmath();
+ if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
+ return failure();
+
+ SmallVector<Type, 1> operandTypes;
+ for (auto operand : adaptor.getOperands()) {
+ // This pass only supports operations on vectors that are already in SPIRV
+ // supported vector sizes: Distributing unsupported vector sizes to SPIRV
+ // supported vetor sizes are done in other blocking optimization passes.
+ if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+ return failure();
+ operandTypes.push_back(operand.getType());
+ }
+ LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+ auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, funcOp, adaptor.getOperands());
+ arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+ mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+ callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+ return success();
+ }
+
+ inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+ if (type.isFloat()) {
+ return true;
+ } else if (auto vecType = dyn_cast<VectorType>(type)) {
+ if (!vecType.getElementType().isFloat())
+ return false;
+ // SPIRV distinguishes between vectors and matrices: OpenCL native math
+ // intrsinics are not compatible with matrices.
+ ArrayRef<int64_t> shape = vecType.getShape();
+ if (shape.size() != 1)
+ return false;
+ // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+ shape[0] == 16)
+ return true;
+ }
+ return false;
+ }
+
+ LLVM::LLVMFuncOp
+ appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+ // This function assumes op types have already been validated using
+ // isSPIRVCompatibleFloatOrVec.
+ using LLVM::LLVMFuncOp;
+
+ std::string mangledNativeFunc =
+ "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+ auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+ if (type.isF32())
+ mangledNativeFunc += "f";
+ else if (type.isF16())
+ mangledNativeFunc += "Dh";
+ else if (type.isF64())
+ mangledNativeFunc += "d";
+ };
+
+ for (auto type : operandTypes) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ appendFloatToMangledFunc(vecType.getElementType());
+ } else
+ appendFloatToMangledFunc(type);
+ }
+
+ auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+ if (funcOp)
+ return funcOp;
+
+ auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
+
+ // Create a valid global location removing any metadata attached to the
+ // location, as debug info metadata inside of a function cannot be used
+ // outside of that function.
+ auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
+ auto globalloc =
+ op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+ return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+ }
+
+ const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith) {
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+ "__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+ "__spirv_ocl_native_cos");
+ patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_exp2");
+ patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+ "__spirv_ocl_native_log");
+ patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log2");
+ patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log10");
+ patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_powr");
+ patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_rsqrt");
+ patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+ "__spirv_ocl_native_sin");
+ patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_sqrt");
+ patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+ "__spirv_ocl_native_tan");
+ if (convertArith)
+ patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+ : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+ using Base::Base;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+ auto m = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateMathToXeVMConversionPatterns(patterns, convertArith);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000000000..ba5de228da411
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,158 @@
+// RUN: mlir-opt %s -convert-math-to-xevm \
+// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH'
+// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
+// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
+
+module @test_module {
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ //
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+ //
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+ // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+ // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+ // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+
+ // CHECK-LABEL: func @math_ops
+ func.func @math_ops() {
+
+ %c1_f16 = arith.constant 1. : f16
+ %c1_f32 = arith.constant 1. : f32
+ %c1_f64 = arith.constant 1. : f64
+
+ // CHECK: math.exp
+ %exp_normal_f16 = math.exp %c1_f16 : f16
+ // CHECK: math.exp
+ %exp_normal_f32 = math.exp %c1_f32 : f32
+ // CHECK: math.exp
+ %exp_normal_f64 = math.exp %c1_f64 : f64
+
+ // Check float operations are converted properly:
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+ %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+ %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+ %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
+
+ // CHECK: math.exp
+ %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+ // CHECK: math.exp
+ %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+ // CHECK: math.exp
+ %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+ // Check vector operations:
+
+ %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
+ %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
+ %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
+ %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
+ %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
+ %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
+ %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
+ %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+ %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
+ %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+
+ %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
+ %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<16xf32>) -> vector<16xf32>
+ %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf16>) -> vector<4xf16>
+ %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+
+ // Check unsupported vector sizes are not converted:
+
+ %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
+ %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
+
+ // CHECK: math.exp
+ %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+ // CHECK: math.exp
+ %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+
+ // Check fastmath flags propagate properly:
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+ %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf, nsz, arcp, contract, afn>} : (f32) -> f32
+ %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32
+ %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32
+
+ // Check all other math operations:
+
+ // native_divide(gentype x, gentype y)
+ // TODO: convert arith.divf to arith/native_divide if option is enabled
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
+
+ // CHECK: l...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a few comments
afn
option using native OpenCL intrinsicsmath-to-xevm
) pass
CI failure seems to be an infrastructure issue, I don't seem to have permissions to retrigger the CI. If needed I'll make a dummy commit to restart the CI.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/24710 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/25898 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/24687 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/22411 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/20278 Here is the relevant piece of the build log for the reference
|
)" This reverts commit fabd1c4.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/19569 Here is the relevant piece of the build log for the reference
|
Looking into the build issues: I am currently unable to recreate this build error, it didn't seem to occur in the precommit either |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/31239 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/140/builds/32234 Here is the relevant piece of the build log for the reference
|
Able to recreate the issue, looking into fix now |
Fixed version reintroduced at #162934 |
… pass" (#162923) Reverts llvm/llvm-project#159878
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/207/builds/8303 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/117/builds/14134 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/80/builds/16628 Here is the relevant piece of the build log for the reference
|
Is there a plan to get the failing build bots working again, or should this PR be reverted? |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/15862 Here is the relevant piece of the build log for the reference
|
This PR introduces a `MathToXeVM` pass, which implements support for the `afn` fastmath flag for SPIRV/XeVM targets - It takes supported `Math` Ops with the `afn` flag, and converts them to function calls to OpenCL `native_` intrinsics. These intrinsic functions are supported by the SPIRV backend, and are automatically converted to `OpExtInst` calls to `native_` ops from the OpenCL SPIRV ext. inst. set when outputting to SPIRV/XeVM. Note: - This pass also supports converting `arith.divf` to native equivalents. There is an option provided in the pass to turn this behavior off. - This pass preserves fastmath flags, but these flags are currently ignored by the SPIRV backend. Thus, in order to generate SPIRV that truly preserves fastmath flags, support needs to be added to the SPIRV backend.
This PR is a fix for #159878, which failed in postcommit testing due to linker errors that were not caught in precommit. Original PR: --- This PR introduces a `MathToXeVM` pass, which implements support for the `afn` fastmath flag for SPIRV/XeVM targets - It takes supported `Math` Ops with the `afn` flag, and converts them to function calls to OpenCL `native_` intrinsics. These intrinsic functions are supported by the SPIRV backend, and are automatically converted to `OpExtInst` calls to `native_` ops from the OpenCL SPIRV ext. inst. set when outputting to SPIRV/XeVM. Note: - This pass also supports converting `arith.divf` to native equivalents. There is an option provided in the pass to turn this behavior off. - This pass preserves fastmath flags, but these flags are currently ignored by the SPIRV backend. Thus, in order to generate SPIRV that truly preserves fastmath flags, support needs to be added to the SPIRV backend.
…ass (#162934) This PR is a fix for llvm/llvm-project#159878, which failed in postcommit testing due to linker errors that were not caught in precommit. Original PR: --- This PR introduces a `MathToXeVM` pass, which implements support for the `afn` fastmath flag for SPIRV/XeVM targets - It takes supported `Math` Ops with the `afn` flag, and converts them to function calls to OpenCL `native_` intrinsics. These intrinsic functions are supported by the SPIRV backend, and are automatically converted to `OpExtInst` calls to `native_` ops from the OpenCL SPIRV ext. inst. set when outputting to SPIRV/XeVM. Note: - This pass also supports converting `arith.divf` to native equivalents. There is an option provided in the pass to turn this behavior off. - This pass preserves fastmath flags, but these flags are currently ignored by the SPIRV backend. Thus, in order to generate SPIRV that truly preserves fastmath flags, support needs to be added to the SPIRV backend.
This PR introduces a `MathToXeVM` pass, which implements support for the `afn` fastmath flag for SPIRV/XeVM targets - It takes supported `Math` Ops with the `afn` flag, and converts them to function calls to OpenCL `native_` intrinsics. These intrinsic functions are supported by the SPIRV backend, and are automatically converted to `OpExtInst` calls to `native_` ops from the OpenCL SPIRV ext. inst. set when outputting to SPIRV/XeVM. Note: - This pass also supports converting `arith.divf` to native equivalents. There is an option provided in the pass to turn this behavior off. - This pass preserves fastmath flags, but these flags are currently ignored by the SPIRV backend. Thus, in order to generate SPIRV that truly preserves fastmath flags, support needs to be added to the SPIRV backend.
…62934) This PR is a fix for llvm#159878, which failed in postcommit testing due to linker errors that were not caught in precommit. Original PR: --- This PR introduces a `MathToXeVM` pass, which implements support for the `afn` fastmath flag for SPIRV/XeVM targets - It takes supported `Math` Ops with the `afn` flag, and converts them to function calls to OpenCL `native_` intrinsics. These intrinsic functions are supported by the SPIRV backend, and are automatically converted to `OpExtInst` calls to `native_` ops from the OpenCL SPIRV ext. inst. set when outputting to SPIRV/XeVM. Note: - This pass also supports converting `arith.divf` to native equivalents. There is an option provided in the pass to turn this behavior off. - This pass preserves fastmath flags, but these flags are currently ignored by the SPIRV backend. Thus, in order to generate SPIRV that truly preserves fastmath flags, support needs to be added to the SPIRV backend.
This PR introduces a
MathToXeVM
pass, which implements support for theafn
fastmath flag for SPIRV/XeVM targets - It takes supportedMath
Ops with theafn
flag, and converts them to function calls to OpenCLnative_
intrinsics.These intrinsic functions are supported by the SPIRV backend, and are automatically converted to
OpExtInst
calls tonative_
ops from the OpenCL SPIRV ext. inst. set when outputting to SPIRV/XeVM.Note:
arith.divf
to native equivalents. There is an option provided in the pass to turn this behavior off.