Skip to content
Merged
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>

Expand All @@ -20,7 +21,8 @@ class Pass;

/// Populate the given list with patterns that convert from Math to ROCDL calls.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
amdgpu::Chipset chipset);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
7 changes: 7 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,13 +778,20 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.

The chipset option specifies the target AMDGPU architecture. If the chipset
is empty, none of the chipset-dependent patterns are added and the pass
will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
let options = [Option<"chipset", "chipset", "std::string",
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">];
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);

populateMathToROCDLConversionPatterns(converter, patterns);
populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
89 changes: 80 additions & 9 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
Expand Down Expand Up @@ -42,8 +44,65 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}

/// Lowers `math.clampf` to the ROCDL `fmed3` op.
///
/// Only f16 and f32 element types are accepted; any other element type is
/// reported as a match failure and the op is left untouched by this pattern.
/// Scalars and 1-D vectors are rewritten directly. Multi-dimensional vectors,
/// which the LLVM type converter turns into LLVM arrays, are unrolled into one
/// `fmed3` per 1-D vector.
struct ClampFOpConversion final
    : public ConvertOpToLLVMPattern<math::ClampFOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  ClampFOpConversion(const LLVMTypeConverter &converter,
                     amdgpu::Chipset chipset)
      : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}

  LogicalResult
  matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only f16 and f32 element types are supported by fmed3.
    Type opTy = op.getType();
    Type elementTy = opTy;
    if (auto vectorType = dyn_cast<VectorType>(opTy))
      elementTy = vectorType.getElementType();

    if (!isa<Float16Type, Float32Type>(elementTy))
      return rewriter.notifyMatchFailure(
          op, "fmed3 only supports f16 and f32 types");

    // The type converter may fail and hand back a null type; bail out instead
    // of feeding a null type into the casts below.
    Type resultType = getTypeConverter()->convertType(opTy);
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "failed to convert result type");

    // Handle multi-dimensional vectors (converted to LLVM arrays).
    if (isa<LLVM::LLVMArrayType>(resultType)) {
      return LLVM::detail::handleMultidimensionalVectors(
          op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
          [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
            // Re-wrap the per-slice operands so they can be accessed by name.
            math::ClampFOp::Adaptor sliceAdaptor(operands);
            return ROCDL::FMed3Op::create(
                rewriter, op.getLoc(), llvm1DVectorTy, sliceAdaptor.getValue(),
                sliceAdaptor.getMin(), sliceAdaptor.getMax());
          },
          rewriter);
    }

    // Handle 1-D vectors and scalars directly. Use the adaptor's (already
    // converted) operands and the converted result type rather than the
    // original op's values, as required inside a conversion pattern.
    rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, resultType,
                                                adaptor.getValue(),
                                                adaptor.getMin(),
                                                adaptor.getMax());
    return success();
  }

  /// Chipset this pattern was created for. It is stored at construction and is
  /// not read by matchAndRewrite.
  amdgpu::Chipset chipset;
};

/// Registers the conversion patterns whose availability depends on the target
/// AMDGPU chipset.
static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
                                        RewritePatternSet &patterns,
                                        amdgpu::Chipset chipset) {
  // The clamp lowering needs V_MED3_F16/F32, which only exists on gfx9+
  // architectures; older chipsets get no chipset-dependent patterns.
  const bool hasMed3 = chipset.majorVersion >= 9;
  if (!hasMed3)
    return;

  patterns.add<ClampFOpConversion>(converter, chipset);
}

void mlir::populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
amdgpu::Chipset chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
Expand Down Expand Up @@ -118,15 +177,17 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");

addChipsetDependentPatterns(converter, patterns, chipset);
}

namespace {
struct ConvertMathToROCDLPass
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
ConvertMathToROCDLPass() = default;
struct ConvertMathToROCDLPass final
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
using impl::ConvertMathToROCDLBase<
ConvertMathToROCDLPass>::ConvertMathToROCDLBase;

void runOnOperation() override;
};
} // namespace

void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
Expand All @@ -135,10 +196,20 @@ void ConvertMathToROCDLPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
populateMathToROCDLConversionPatterns(converter, patterns);

// Only populate chipset-dependent patterns if chipset is specified
if (!chipset.empty()) {
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
if (failed(maybeChipset)) {
return signalPassFailure();
}
populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
}
Comment on lines +201 to +207
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong, the call to populateMathToROCDLConversionPatterns shouldn't be guarded by an if.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
target
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
Expand Down
76 changes: 75 additions & 1 deletion mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9

module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
Expand Down Expand Up @@ -596,3 +597,76 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}

// -----

// f16 clamp → rocdl.fmed3 on gfx9+
// Scalar f16 case: the gfx942 run expects the clamp to be rewritten to
// rocdl.fmed3, while the gfx803 run expects math.clampf to be left in place.
// CHECK-LABEL: func.func @clampf_f16
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
%r = math.clampf %x to [%lo, %hi] : f16
return %r : f16
// POST9: rocdl.fmed3 {{.*}} : f16
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f16
}

// f32 clamp → rocdl.fmed3 on gfx9+
// Scalar f32 case: the gfx942 run expects the clamp to be rewritten to
// rocdl.fmed3, while the gfx803 run expects math.clampf to be left in place.
// CHECK-LABEL: func.func @clampf_f32
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
%r = math.clampf %x to [%lo, %hi] : f32
return %r : f32
// POST9: rocdl.fmed3 {{.*}} : f32
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f32
}

// -----

// Vector f16 clamp → rocdl.fmed3 on gfx9+
// 1-D vector case: the gfx942 run expects a single rocdl.fmed3 on the whole
// vector<2xf16>; the gfx803 run expects math.clampf to be left in place.
// CHECK-LABEL: func.func @clampf_vector_f16
func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
%r = math.clampf %x to [%lo, %hi] : vector<2xf16>
return %r : vector<2xf16>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2xf16>
}

// -----

// Vector f32 clamp → rocdl.fmed3 on gfx9+
// 1-D vector case: the gfx942 run expects a single rocdl.fmed3 on the whole
// vector<2xf32>; the gfx803 run expects math.clampf to be left in place.
// CHECK-LABEL: func.func @clampf_vector_f32
func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
%r = math.clampf %x to [%lo, %hi] : vector<2xf32>
return %r : vector<2xf32>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2xf32>
}

// -----

// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
// The 2x2 vector is cast to an LLVM array of two vector<2xf16> rows. For the
// gfx942 run, each row is expected to be extracted, clamped with its own
// rocdl.fmed3, and inserted back into the array (hence the two
// extract/fmed3/insert groups below). The gfx803 run expects math.clampf to be
// left in place.
// CHECK-LABEL: func.func @clampf_vector_2d_f16
func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
%r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
return %r : vector<2x2xf16>
// POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2x2xf16>
}

// -----
// bf16 clamp is never lowered to rocdl.fmed3 (only f16 and f32 are supported).
// The run without an explicit chipset uses the default chipset, for which the
// fmed3 pattern is not registered at all, so the gfx942 run is the one that
// actually exercises the element-type rejection. The gfx803 run is checked too
// so that all three configurations agree.
// CHECK-LABEL: func.func @clampf_bf16
func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
%r = math.clampf %x to [%lo, %hi] : bf16
return %r : bf16
// CHECK: math.clampf {{.*}} : bf16
// CHECK-NOT: rocdl.fmed3
// POST9: math.clampf {{.*}} : bf16
// POST9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : bf16
// PRE9-NOT: rocdl.fmed3
}