Skip to content
Merged
27 changes: 27 additions & 0 deletions mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTMATHTOXEVM
#include "mlir/Conversion/Passes.h.inc"

/// Populate the given list with patterns that convert from Math to XeVM calls.
void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
bool convertArith);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,32 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// MathToXeVM
//===----------------------------------------------------------------------===//

def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
let summary =
"Convert (fast) math operations to native XeVM/SPIRV equivalents";
let description = [{
This pass converts supported math ops marked with the `afn` fastmath flag
to function calls for OpenCL `native_` math intrinsics: These intrinsics
are typically mapped directly to native device instructions, often resulting
in better performance. However, the precision/error of these intrinsics
are implementation-defined, and thus math ops are only converted when they
have the `afn` fastmath flag enabled.
}];
let options = [Option<
"convertArith", "convert-arith", "bool", /*default=*/"true",
"Convert supported Arith ops (e.g. arith.divf) as well.">];
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
"xevm::XeVMDialect",
"vector::VectorDialect",
];
}

//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRMathToXeVM
MathToXeVM.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRMathDialect
MLIRLLVMCommonConversion
MLIRPass
MLIRTransformUtils
MLIRVectorDialect
)
178 changes: 178 additions & 0 deletions mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/GPUOpsLowering.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOXEVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "math-to-xevm"

/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
template <typename Op>
struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {

ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
PatternBenefit benefit = 1)
: OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}

LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isSPIRVCompatibleFloatOrVec(op.getType()))
return failure();

arith::FastMathFlags fastFlags = op.getFastmath();
if (!(static_cast<uint32_t>(fastFlags) &
static_cast<uint32_t>(arith::FastMathFlags::afn)))
return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");

SmallVector<Type, 1> operandTypes;
for (auto operand : adaptor.getOperands()) {
Type opTy = operand.getType();
// This pass only supports operations on vectors that are already in SPIRV
// supported vector sizes: Distributing unsupported vector sizes to SPIRV
// supported vector sizes are done in other blocking optimization passes.
if (!isSPIRVCompatibleFloatOrVec(opTy))
return rewriter.notifyMatchFailure(
op, llvm::formatv("incompatible operand type: '{0}'", opTy));
operandTypes.push_back(opTy);
}

auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
auto funcOpRes = LLVM::lookupOrCreateFn(
rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
operandTypes, op.getType());
assert(!failed(funcOpRes));
LLVM::LLVMFuncOp funcOp = funcOpRes.value();

auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, funcOp, adaptor.getOperands());
// Preserve fastmath flags in our MLIR op when converting to llvm function
// calls, in order to allow further fastmath optimizations: We thus need to
// convert arith fastmath attrs into attrs recognized by llvm.
arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
return success();
}

inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
if (type.isFloat()) {
return true;
} else if (auto vecType = dyn_cast<VectorType>(type)) {
if (!vecType.getElementType().isFloat())
return false;
// SPIRV distinguishes between vectors and matrices: OpenCL native math
// intrsinics are not compatible with matrices.
ArrayRef<int64_t> shape = vecType.getShape();
if (shape.size() != 1)
return false;
// SPIRV only allows vectors of size 2, 3, 4, 8, 16.
if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
shape[0] == 16)
return true;
}
return false;
}

inline std::string
getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
std::string mangledFuncName =
"_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();

auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
if (type.isF32())
mangledFuncName += "f";
else if (type.isF16())
mangledFuncName += "Dh";
else if (type.isF64())
mangledFuncName += "d";
};

for (auto type : operandTypes) {
if (auto vecType = dyn_cast<VectorType>(type)) {
mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
appendFloatToMangledFunc(vecType.getElementType());
} else
appendFloatToMangledFunc(type);
}

return mangledFuncName;
}

const StringRef nativeFunc;
};

void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
bool convertArith) {
patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
"__spirv_ocl_native_exp");
patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
"__spirv_ocl_native_cos");
patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
patterns.getContext(), "__spirv_ocl_native_exp2");
patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
"__spirv_ocl_native_log");
patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
patterns.getContext(), "__spirv_ocl_native_log2");
patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
patterns.getContext(), "__spirv_ocl_native_log10");
patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
patterns.getContext(), "__spirv_ocl_native_powr");
patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
patterns.getContext(), "__spirv_ocl_native_rsqrt");
patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
"__spirv_ocl_native_sin");
patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
patterns.getContext(), "__spirv_ocl_native_sqrt");
patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
"__spirv_ocl_native_tan");
if (convertArith)
patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
patterns.getContext(), "__spirv_ocl_native_divide");
}

namespace {
struct ConvertMathToXeVMPass
: public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
using Base::Base;
void runOnOperation() override;
};
} // namespace

void ConvertMathToXeVMPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateMathToXeVMConversionPatterns(patterns, convertArith);
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
Loading