Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>

Expand All @@ -21,8 +20,7 @@ class Pass;

/// Populate the given list with patterns that convert from Math to ROCDL calls.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
amdgpu::Chipset chipset);
RewritePatternSet &patterns);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
7 changes: 0 additions & 7 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,20 +778,13 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.

The chipset option specifies the target AMDGPU architecture. If the chipset
is empty, none of the chipset-dependent patterns are added and the pass
will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
let options = [Option<"chipset", "chipset", "std::string",
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">];
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);

populateMathToROCDLConversionPatterns(converter, patterns, chipset);
populateMathToROCDLConversionPatterns(converter, patterns);
}
89 changes: 9 additions & 80 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
Expand Down Expand Up @@ -44,65 +42,8 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}

struct ClampFOpConversion final
: public ConvertOpToLLVMPattern<math::ClampFOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
ClampFOpConversion(const LLVMTypeConverter &converter,
amdgpu::Chipset chipset)
: ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}

LogicalResult
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only f16 and f32 types are supported by fmed3
Type opTy = op.getType();
auto resultType = getTypeConverter()->convertType(opTy);

if (auto vectorType = dyn_cast<VectorType>(opTy)) {
opTy = vectorType.getElementType();
}

if (!isa<Float16Type, Float32Type>(opTy)) {
return rewriter.notifyMatchFailure(
op, "fmed3 only supports f16 and f32 types");
}

// Handle multi-dimensional vectors (converted to LLVM arrays)
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) {
// Handle multi-dimensional vectors (converted to LLVM arrays)
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
typename math::ClampFOp::Adaptor adaptor(operands);
return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getValue(), adaptor.getMin(),
adaptor.getMax());
},
rewriter);
}

// Handle 1D vectors and scalars directly
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
op.getMin(), op.getMax());
return success();
}

amdgpu::Chipset chipset;
};

static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
amdgpu::Chipset chipset) {

// V_MED3_F16/F32 only exists in gfx9+ architectures
if (chipset.majorVersion >= 9) {
patterns.add<ClampFOpConversion>(converter, chipset);
}
}

void mlir::populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
amdgpu::Chipset chipset) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
Expand Down Expand Up @@ -177,17 +118,15 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");

addChipsetDependentPatterns(converter, patterns, chipset);
}

struct ConvertMathToROCDLPass final
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
using impl::ConvertMathToROCDLBase<
ConvertMathToROCDLPass>::ConvertMathToROCDLBase;

namespace {
struct ConvertMathToROCDLPass
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
ConvertMathToROCDLPass() = default;
void runOnOperation() override;
};
} // namespace

void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
Expand All @@ -196,20 +135,10 @@ void ConvertMathToROCDLPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);

// Only populate chipset-dependent patterns if chipset is specified
if (!chipset.empty()) {
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
if (failed(maybeChipset)) {
return signalPassFailure();
}
populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
}

populateMathToROCDLConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
target
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
Expand Down
76 changes: 1 addition & 75 deletions mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9
// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s

module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
Expand Down Expand Up @@ -597,76 +596,3 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}

// -----

// f16 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_f16
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
%r = math.clampf %x to [%lo, %hi] : f16
return %r : f16
// POST9: rocdl.fmed3 {{.*}} : f16
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f16
}

// f32 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_f32
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
%r = math.clampf %x to [%lo, %hi] : f32
return %r : f32
// POST9: rocdl.fmed3 {{.*}} : f32
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f32
}

// -----

// Vector f16 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_vector_f16
func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
%r = math.clampf %x to [%lo, %hi] : vector<2xf16>
return %r : vector<2xf16>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2xf16>
}

// -----

// Vector f32 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_vector_f32
func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
%r = math.clampf %x to [%lo, %hi] : vector<2xf32>
return %r : vector<2xf32>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2xf32>
}

// -----

// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
// CHECK-LABEL: func.func @clampf_vector_2d_f16
func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
%r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
return %r : vector<2x2xf16>
// POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2x2xf16>
}

// -----
// CHECK-LABEL: func.func @clampf_bf16
func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
%r = math.clampf %x to [%lo, %hi] : bf16
return %r : bf16
// CHECK: math.clampf {{.*}} : bf16
// CHECK-NOT: rocdl.fmed3
}
Loading