1- // ===- GPUToAMDGPU .cpp - GPU to AMDGPU dialect conversion ------ -===//
1+ // ===- DecomposeSubgroupReduceToDPP .cpp - Decompose subgroup reduce pass -===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
8+ //
9+ // This file implements decompose subgroup reduce to DPP pass.
10+ //
11+ // ===----------------------------------------------------------------------===//
812
9- #include " mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU .h"
13+ #include " mlir/Dialect/GPU/Transforms/Passes .h"
1014
11- #include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
12- #include " mlir/Conversion/LLVMCommon/Pattern.h"
13- #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
1415#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1516#include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
1617#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
2223#include " mlir/Conversion/GPUCommon/GPUCommonPass.h"
2324#include " mlir/Dialect/GPU/IR/GPUDialect.h"
2425#include " mlir/Dialect/Vector/IR/VectorOps.h"
26+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
27+ #include " mlir/Dialect/GPU/Transforms/Passes.h"
2528
2629#include " llvm/Support/FormatVariadic.h"
27- #include " llvm/Support/MathExtras.h"
28- #include < cassert>
29- #include < cstdint>
30-
31- #include " ../LLVMCommon/MemRefDescriptor.h"
32-
33- #include " llvm/ADT/STLExtras.h"
34- #include < optional>
3530
3631namespace mlir {
37- #define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
38- #include " mlir/Conversion /Passes.h.inc"
32+ #define GEN_PASS_DEF_GPUDECOMPOSESUBGROUPREDUCETODPPPASS
33+ #include " mlir/Dialect/GPU/Transforms /Passes.h.inc"
3934} // namespace mlir
4035
4136using namespace mlir ;
@@ -144,8 +139,8 @@ Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
144139struct ScalarSubgroupReduceToShuffles final
145140 : OpRewritePattern<gpu::SubgroupReduceOp> {
146141 ScalarSubgroupReduceToShuffles (MLIRContext *ctx, unsigned subgroupSize,
147- bool matchClustered, PatternBenefit benefit )
148- : OpRewritePattern(ctx, benefit ), subgroupSize(subgroupSize),
142+ bool matchClustered)
143+ : OpRewritePattern(ctx), subgroupSize(subgroupSize),
149144 matchClustered (matchClustered) {}
150145
151146 LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
@@ -174,30 +169,24 @@ struct ScalarSubgroupReduceToShuffles final
174169 bool matchClustered = false ;
175170};
176171
177- struct ConvertGPUToAMDGPUPass
178- : public impl::ConvertGPUToAMDGPUPassBase<ConvertGPUToAMDGPUPass> {
172+ struct GpuDecomposeSubgroupReduceToDppPass
173+ : public impl::GpuDecomposeSubgroupReduceToDppPassBase<
174+ GpuDecomposeSubgroupReduceToDppPass> {
179175 using Base::Base;
180176
181177 void runOnOperation () override {
182178 RewritePatternSet patterns (&getContext ());
183- LLVMTypeConverter converter (&getContext ());
184- LLVMConversionTarget target (getContext ());
185- target.addLegalDialect <::mlir::LLVM::LLVMDialect>();
186- target.addLegalDialect <::mlir::amdgpu::AMDGPUDialect>();
187- target.addLegalDialect <::mlir::ROCDL::ROCDLDialect>();
188-
189- int subgroupSizeInt = static_cast <int >(subgroupSize);
190- populateSubgroupReduceLoweringPatterns (converter, patterns, subgroupSizeInt,
191- PatternBenefit (1 ));
192- if (failed (applyPartialConversion (getOperation (), target,
193- std::move (patterns))))
194- signalPassFailure ();
179+ // int subgroupSizeInt = static_cast<int>(subgroupSize);
180+ populateGpuDecomposeSubgroupReduceToDppPatterns (patterns, subgroupSize);
181+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns))))
182+ return signalPassFailure ();
195183 }
196184};
185+
197186} // namespace
198187
199- void mlir::populateSubgroupReduceLoweringPatterns (
200- LLVMTypeConverter &converter, RewritePatternSet &patterns, unsigned subgroupSize, PatternBenefit benefit ) {
188+ void mlir::populateGpuDecomposeSubgroupReduceToDppPatterns (
189+ RewritePatternSet &patterns, unsigned subgroupSize) {
201190 patterns.add <ScalarSubgroupReduceToShuffles>(
202- patterns.getContext (), subgroupSize, /* matchClustered=*/ true , benefit );
203- }
191+ patterns.getContext (), subgroupSize, /* matchClustered=*/ true );
192+ }
0 commit comments