1- // ===- DecomposeSubgroupReduceToDPP .cpp - Decompose subgroup reduce pass -===//
1+ // ===- GPUToAMDGPU .cpp - GPU to AMDGPU dialect conversion ------ -===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
8- //
9- // This file implements decompose subgroup reduce to DPP pass.
10- //
11- // ===----------------------------------------------------------------------===//
128
13- #include " mlir/Dialect/GPU/Transforms/Passes .h"
9+ #include " mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU .h"
1410
11+ #include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
12+ #include " mlir/Conversion/LLVMCommon/Pattern.h"
13+ #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
1514#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1615#include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1716#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
2322#include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
2423#include " mlir/Dialect/GPU/IR/GPUDialect.h"
2524#include " mlir/Dialect/Vector/IR/VectorOps.h"
26- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
27- #include " mlir/Dialect/GPU/Transforms/Passes.h"
2825
2926#include " llvm/Support/FormatVariadic.h"
27+ #include " llvm/Support/MathExtras.h"
28+ #include < cassert>
29+ #include < cstdint>
30+
31+ #include " ../LLVMCommon/MemRefDescriptor.h"
32+
33+ #include " llvm/ADT/STLExtras.h"
34+ #include < optional>
3035
3136namespace mlir {
32- #define GEN_PASS_DEF_GPUDECOMPOSESUBGROUPREDUCETODPPPASS
33- #include " mlir/Dialect/GPU/Transforms /Passes.h.inc"
37+ #define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
38+ #include " mlir/Conversion /Passes.h.inc"
3439} // namespace mlir
3540
3641using namespace mlir ;
@@ -139,8 +144,8 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
139144struct ScalarSubgroupReduceToShuffles final
140145 : OpRewritePattern<gpu::SubgroupReduceOp> {
141146 ScalarSubgroupReduceToShuffles (MLIRContext *ctx, unsigned subgroupSize,
142- bool matchClustered)
143- : OpRewritePattern(ctx), subgroupSize(subgroupSize),
147+ bool matchClustered, PatternBenefit benefit )
148+ : OpRewritePattern(ctx, benefit ), subgroupSize(subgroupSize),
144149 matchClustered (matchClustered) {}
145150
146151 LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
@@ -169,24 +174,30 @@ struct ScalarSubgroupReduceToShuffles final
169174 bool matchClustered = false ;
170175};
171176
172- struct GpuDecomposeSubgroupReduceToDppPass
173- : public impl::GpuDecomposeSubgroupReduceToDppPassBase<
174- GpuDecomposeSubgroupReduceToDppPass> {
177+ struct ConvertGPUToAMDGPUPass
178+ : public impl::ConvertGPUToAMDGPUPassBase<ConvertGPUToAMDGPUPass> {
175179 using Base::Base;
176180
177181 void runOnOperation () override {
178182 RewritePatternSet patterns (&getContext ());
179- // int subgroupSizeInt = static_cast<int>(subgroupSize);
180- populateGpuDecomposeSubgroupReduceToDppPatterns (patterns, subgroupSize);
181- if (failed (applyPatternsGreedily (getOperation (), std::move (patterns))))
182- return signalPassFailure ();
183+ LLVMTypeConverter converter (&getContext ());
184+ LLVMConversionTarget target (getContext ());
185+ target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
186+ target.addLegalDialect <::mlir::amdgpu::AMDGPUDialect>();
187+ target.addLegalDialect <::mlir::ROCDL::ROCDLDialect>();
188+
189+ int subgroupSizeInt = static_cast <int >(subgroupSize);
190+ populateSubgroupReduceLoweringPatterns (converter, patterns, subgroupSizeInt,
191+ PatternBenefit (1 ));
192+ if (failed (applyPartialConversion (getOperation (), target,
193+ std::move (patterns))))
194+ signalPassFailure ();
183195 }
184196};
185-
186197} // namespace
187198
188- void mlir::populateGpuDecomposeSubgroupReduceToDppPatterns (
189- RewritePatternSet &patterns, unsigned subgroupSize) {
199+ void mlir::populateSubgroupReduceLoweringPatterns (
200+ LLVMTypeConverter &converter, RewritePatternSet &patterns, unsigned subgroupSize, PatternBenefit benefit ) {
190201 patterns.add <ScalarSubgroupReduceToShuffles>(
191- patterns.getContext (), subgroupSize, /* matchClustered=*/ true );
192- }
202+ patterns.getContext (), subgroupSize, /* matchClustered=*/ true , benefit );
203+ }
0 commit comments