Skip to content
Merged
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>

Expand All @@ -20,7 +21,8 @@ class Pass;

/// Populate the given list with patterns that convert from Math to ROCDL calls.
/// `chipset` is the target chipset; it gates chipset-dependent patterns.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
amdgpu::Chipset chipset);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
8 changes: 8 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,14 @@ def ConvertMathToLibmPass : Pass<"convert-math-to-libm", "ModuleOp"> {
"func::FuncDialect",
"vector::VectorDialect",
];
let options = [
Option<"chipset", "chipset", "std::string",


/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">
];

}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);

populateMathToROCDLConversionPatterns(converter, patterns);
populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
54 changes: 45 additions & 9 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
Expand Down Expand Up @@ -42,8 +43,39 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}

/// Conversion pattern that rewrites `math::ClampFOp` into `ROCDL::FMed3Op`.
///
/// The rewrite is only performed when the target chipset's major version is
/// at least 9; for older chipsets the pattern reports a match failure and
/// does not modify the op.
struct ClampFOpConversion final
    : public ConvertOpToLLVMPattern<math::ClampFOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// \param converter type converter forwarded to the base pattern.
  /// \param chipset   target chipset consulted by matchAndRewrite.
  ClampFOpConversion(const LLVMTypeConverter &converter,
                     amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}

  /// Replaces `op` with an `ROCDL::FMed3Op` built from the op's value, min
  /// and max operands, keeping the op's result type.
  ///
  /// \returns success() after the replacement, or a match failure carrying a
  ///          diagnostic when the chipset predates gfx9.
  LogicalResult
  matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // V_MED3_F16/F32 only exists in gfx9+ architectures.
    if (chipset.majorVersion < 9) {
      return rewriter.notifyMatchFailure(
          op, ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
               "): V_MED3_F16 / V_MED3_F32 not supported."));
    }
    rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
                                                op.getMin(), op.getMax());
    return success();
  }

  /// Chipset the lowering targets; compared against gfx9 in matchAndRewrite.
  amdgpu::Chipset chipset;
};

/// Adds to `patterns` the conversion patterns whose construction takes the
/// target `chipset` in addition to the type `converter`. At present this is
/// only ClampFOpConversion.
static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
                                        RewritePatternSet &patterns,
                                        amdgpu::Chipset chipset) {
  patterns.add<ClampFOpConversion>(converter, chipset);
}

void mlir::populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
amdgpu::Chipset chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
Expand Down Expand Up @@ -118,27 +150,31 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");

addChipsetDependentPatterns(converter, patterns, chipset);
}

namespace {
struct ConvertMathToROCDLPass
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
ConvertMathToROCDLPass() = default;
struct ConvertMathToROCDLPass final
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
using impl::ConvertMathToROCDLBase<
ConvertMathToROCDLPass>::ConvertMathToROCDLBase;

void runOnOperation() override;
};
} // namespace

void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);

RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
populateMathToROCDLConversionPatterns(converter, patterns);
populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
target
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
Expand Down
32 changes: 31 additions & 1 deletion mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9

module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
Expand Down Expand Up @@ -596,3 +597,32 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}

// -----

// f16 clamp → rocdl.fmed3 on gfx9+
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
%r = math.clampf %x to [%lo, %hi] : f16
return %r : f16
}

// f32 clamp → rocdl.fmed3 on gfx9+
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
%r = math.clampf %x to [%lo, %hi] : f32
return %r : f32
}

// POST9-LABEL: func.func @clampf_f16
// POST9: rocdl.fmed3 {{.*}} : f16
// POST9: return

// POST9-LABEL: func.func @clampf_f32
// POST9: rocdl.fmed3 {{.*}} : f32
// POST9: return

// PRE9-LABEL: func.func @clampf_f16
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f16

// PRE9-LABEL: func.func @clampf_f32
// PRE9-NOT: rocdl.fmed3
Loading