diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
deleted file mode 100644
index 91d3c92fd6296..0000000000000
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
-#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
-
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
-#include "mlir/IR/PatternMatch.h"
-#include <memory>
-
-namespace mlir {
-class Pass;
-
-#define GEN_PASS_DECL_CONVERTMATHTOXEVM
-#include "mlir/Conversion/Passes.h.inc"
-
-/// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
-                                          bool convertArith);
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..da061b269daf7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,7 +49,6 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 25e9d34f3e653..3c18ecc753d0f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,31 +796,6 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
-//===----------------------------------------------------------------------===//
-// MathToXeVM
-//===----------------------------------------------------------------------===//
-
-def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
-  let summary =
-      "Convert (fast) math operations to native XeVM/SPIRV equivalents";
-  let description = [{
-    This pass converts supported math ops marked with the `afn` fastmath flag
-    to function calls for OpenCL `native_` math intrinsics: These intrinsics
-    are typically mapped directly to native device instructions, often
-    resulting in better performance. However, the precision/error of these
-    intrinsics is implementation-defined, so math ops are only converted when
-    they have the `afn` fastmath flag enabled.
-  }];
-  let options = [Option<
-      "convertArith", "convert-arith", "bool", /*default=*/"true",
-      "Convert supported Arith ops (e.g. arith.divf) as well.">];
-  let dependentDialects = [
-    "arith::ArithDialect",
-    "xevm::XeVMDialect",
-    "LLVM::LLVMDialect",
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // MathToEmitC
 //===----------------------------------------------------------------------===//
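As a quick illustration of what the removed pass did, here is a minimal before/after sketch distilled from the deleted math-to-xevm.mlir test further down (SSA names are placeholders):

    // Only ops carrying the `afn` fastmath flag are rewritten:
    %y = math.exp %x fastmath<afn> : f32
    // becomes a call to the mangled OpenCL native builtin:
    %y = llvm.call @_Z22__spirv_ocl_native_expf(%x) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32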
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..71986f83c4870 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,7 +40,6 @@ add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
-add_subdirectory(MathToXeVM)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
deleted file mode 100644
index 95aaba31a993e..0000000000000
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRMathToXeVM
-  MathToXeVM.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRMathDialect
-  MLIRLLVMCommonConversion
-  MLIRPass
-  MLIRTransformUtils
-  MLIRVectorDialect
-  )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
deleted file mode 100644
index 03053dee5af40..0000000000000
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ /dev/null
@@ -1,174 +0,0 @@
-//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
-#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMATHTOXEVM
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
-using namespace mlir;
-
-#define DEBUG_TYPE "math-to-xevm"
-
-/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
-template <typename Op>
-struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
-
-  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
-                           PatternBenefit benefit = 1)
-      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
-
-  LogicalResult
-  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
-      return failure();
-
-    arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
-      return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
-
-    SmallVector<Type> operandTypes;
-    for (auto operand : adaptor.getOperands()) {
-      Type opTy = operand.getType();
-      // This pass only supports operations on vectors that are already in
-      // SPIRV-supported vector sizes: distributing unsupported vector sizes
-      // to SPIRV-supported vector sizes is done in other blocking
-      // optimization passes.
-      if (!isSPIRVCompatibleFloatOrVec(opTy))
-        return rewriter.notifyMatchFailure(
-            op, llvm::formatv("incompatible operand type: '{0}'", opTy));
-      operandTypes.push_back(opTy);
-    }
-
-    auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
-    auto funcOpRes = LLVM::lookupOrCreateFn(
-        rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
-        operandTypes, op.getType());
-    assert(!failed(funcOpRes));
-    LLVM::LLVMFuncOp funcOp = funcOpRes.value();
-
-    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, funcOp, adaptor.getOperands());
-    // Preserve fastmath flags in our MLIR op when converting to llvm function
-    // calls, in order to allow further fastmath optimizations: we thus need to
-    // convert arith fastmath attrs into attrs recognized by llvm.
-    arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
-    mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
-    callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
-    return success();
-  }
-
-  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
-    if (type.isFloat())
-      return true;
-    if (auto vecType = dyn_cast<VectorType>(type)) {
-      if (!vecType.getElementType().isFloat())
-        return false;
-      // SPIRV distinguishes between vectors and matrices: OpenCL native math
-      // intrinsics are not compatible with matrices.
-      ArrayRef<int64_t> shape = vecType.getShape();
-      if (shape.size() != 1)
-        return false;
-      // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
-      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
-          shape[0] == 16)
-        return true;
-    }
-    return false;
-  }
-
-  inline std::string
-  getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
-    std::string mangledFuncName =
-        "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
-
-    auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
-      if (type.isF32())
-        mangledFuncName += "f";
-      else if (type.isF16())
-        mangledFuncName += "Dh";
-      else if (type.isF64())
-        mangledFuncName += "d";
-    };
-
-    for (auto type : operandTypes) {
-      if (auto vecType = dyn_cast<VectorType>(type)) {
-        mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
-        appendFloatToMangledFunc(vecType.getElementType());
-      } else
-        appendFloatToMangledFunc(type);
-    }
-
-    return mangledFuncName;
-  }
-
-  const StringRef nativeFunc;
-};
-
-void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
-                                                bool convertArith) {
-  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_exp");
-  patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_cos");
-  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
-      patterns.getContext(), "__spirv_ocl_native_exp2");
-  patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_log");
-  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
-      patterns.getContext(), "__spirv_ocl_native_log2");
-  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
-      patterns.getContext(), "__spirv_ocl_native_log10");
-  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
-      patterns.getContext(), "__spirv_ocl_native_powr");
-  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
-      patterns.getContext(), "__spirv_ocl_native_rsqrt");
-  patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_sin");
-  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
-      patterns.getContext(), "__spirv_ocl_native_sqrt");
-  patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_tan");
-  if (convertArith)
-    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
-        patterns.getContext(), "__spirv_ocl_native_divide");
-}
-
-namespace {
-struct ConvertMathToXeVMPass
-    : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
-  using Base::Base;
-  void runOnOperation() override;
-};
-} // namespace
-
-void ConvertMathToXeVMPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  populateMathToXeVMConversionPatterns(patterns, convertArith);
-  ConversionTarget target(getContext());
-  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
-                         vector::VectorDialect, LLVM::LLVMDialect>();
-  if (failed(
-          applyPartialConversion(getOperation(), target, std::move(patterns))))
-    signalPassFailure();
-}
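For reference, the scheme implemented by getMangledNativeFuncName above is Itanium-style mangling: `_Z`, the length of the builtin name, the name itself, then one suffix per operand type (`f` for f32, `Dh` for f16, `d` for f64, and `Dv<N>_` preceding the element suffix for vectors). The declarations checked by the deleted tests below decompose accordingly, for example:

    llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32                          // _Z 22 __spirv_ocl_native_exp f
    llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16> // _Z 22 __spirv_ocl_native_exp Dv4_ Dh
    llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16                 // two f16 operands, so two Dh suffixes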
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
deleted file mode 100644
index d76627bb4201c..0000000000000
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ /dev/null
@@ -1,155 +0,0 @@
-// RUN: mlir-opt %s -convert-math-to-xevm \
-// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH'
-// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
-// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
-
-module @test_module {
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
-  //
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
-  //
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
-  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
-  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
-  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
-  // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
-
-  // CHECK-LABEL: func @math_ops
-  func.func @math_ops() {
-
-    %c1_f16 = arith.constant 1. : f16
-    %c1_f32 = arith.constant 1. : f32
-    %c1_f64 = arith.constant 1. : f64
-
-    // CHECK: math.exp
-    %exp_normal_f16 = math.exp %c1_f16 : f16
-    // CHECK: math.exp
-    %exp_normal_f32 = math.exp %c1_f32 : f32
-    // CHECK: math.exp
-    %exp_normal_f64 = math.exp %c1_f64 : f64
-
-    // Check float operations are converted properly:
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
-    %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
-    %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
-    %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
-
-    // CHECK: math.exp
-    %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
-    // CHECK: math.exp
-    %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
-    // CHECK: math.exp
-    %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
-
-    // Check vector operations:
-
-    %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
-    %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
-    %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
-    %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
-    %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
-    %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
-    %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
-    %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
-    %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
-    %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
-
-    %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
-    %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf32>) -> vector<16xf32>
-    %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<afn> : vector<16xf32>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf16>) -> vector<4xf16>
-    %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<afn> : vector<4xf16>
-
-    // Check unsupported vector sizes are not converted:
-
-    %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
-    %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
-
-    // CHECK: math.exp
-    %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
-    // CHECK: math.exp
-    %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
-
-    // Check fastmath flags propagate properly:
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
-    %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<fast> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<reassoc,nnan,ninf,nsz,arcp,afn>} : (f32) -> f32
-    %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<reassoc,nnan,ninf,nsz,arcp,afn> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn,reassoc,nnan>} : (f32) -> f32
-    %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32
-
-    // Check all other math operations:
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %log_afn_f16 = math.log %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %log2_afn_f32 = math.log2 %c1_f32 fastmath<afn> : f32
-
-    // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %log10_afn_f64 = math.log10 %c1_f64 fastmath<afn> : f64
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16, f16) -> f16
-    %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath<afn> : f64
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-    %sin_afn_f16 = math.sin %c1_f16 fastmath<afn> : f16
-
-    // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-    %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath<afn> : f32
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-    %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
-
-    %c6_9_f32 = arith.constant 6.9 : f32
-    %c7_f32 = arith.constant 7. : f32
-
-    // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
-    // CHECK-NO-ARITH: arith.divf
-    %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath<afn> : f32
-
-    return
-  }
-}
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
deleted file mode 100644
index 2492adafd6a50..0000000000000
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ /dev/null
@@ -1,118 +0,0 @@
-// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
-// RUN:   -debug-only=serialize-to-isa 2> %t
-// RUN: FileCheck --input-file=%t %s
-//
-// The MathToXeVM pass generates OpenCL intrinsic function calls when
-// converting Math ops with the `fastmath` attr to native function calls. It
-// is assumed that the SPIRV backend correctly converts these intrinsic calls
-// to OpenCL ExtInst instructions in SPIRV (see
-// llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
-//
-// To ensure this assumption holds, this test verifies that the SPIRV backend
-// behaves as expected.
-
-module @test_ocl_intrinsics attributes {gpu.container_module} {
-  gpu.module @kernel [#xevm.target] {
-    llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
-      // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
-      // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
-      %c0_f16 = llvm.mlir.constant(0. : f16) : f16
-      // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32
-      // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]]
-      %c0_f32 = llvm.mlir.constant(0. : f32) : f32
-      // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
-      // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
-      %c0_f64 = llvm.mlir.constant(0. : f64) : f64
-
-      // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2
-      // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]]
-      %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64>
-      // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3
-      // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]]
-      %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32>
-      // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4
-      // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]]
-      %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64>
-      // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8
-      // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]]
-      %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64>
-      // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16
-      // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
-      %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>
-
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
-      %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
-      %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
-      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
-      %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
-
-      // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
-      %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
-      // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
-      %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
-      // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
-      %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
-      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
-      %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
-      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
-      %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
-
-      // The SPIRV backend does not currently handle fastmath flags: the SPIRV
-      // backend would need to generate OpDecorate calls to decorate math ops
-      // with FPFastMathMode/FPFastMathModeINTEL decorations.
-      //
-      // FIXME: When support for fastmath flags in the SPIRV backend is added,
-      // add tests here to ensure fastmath flags are converted to the correct
-      // OpDecorate calls.
-      //
-      // See:
-      // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
-      // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
-
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]]
-      %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]]
-      %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]]
-      %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]]
-      %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]]
-      %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
-      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]]
-      %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16>
-      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]]
-      %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]]
-      %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]]
-      %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
-      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
-      %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
-      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]]
-      %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
-
-      llvm.return
-    }
-
-    llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
-    llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
-    llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
-    llvm.func @_Z22__spirv_ocl_native_expDv2_f64(vector<2xf64>) -> vector<2xf64>
-    llvm.func @_Z22__spirv_ocl_native_expDv3_f32(vector<3xf32>) -> vector<3xf32>
-    llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64>
-    llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64>
-    llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16>
-    llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
-    llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
-    llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
-    llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
-    llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64>
-    llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16>
-    llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
-    llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
-    llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
-    llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
-    llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
-  }
-}
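To make the assumption behind native-spirv-builtins.mlir concrete: each mangled call above is expected to be recognized by the SPIR-V backend (llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp) and emitted as an OpenCL extended instruction. A sketch of the correspondence, with the SPIR-V shown as comments and the extended-instruction-set operand name purely illustrative:

    %y = llvm.call @_Z22__spirv_ocl_native_expf(%x) : (f32) -> f32
    // expected lowering, per the CHECK lines in the deleted test:
    //   %y = OpExtInst %f32 %opencl_ext_set native_exp %x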