1212
1313#include " mlir/Dialect/Arith/IR/Arith.h"
1414#include " mlir/Dialect/GPU/IR/GPUDialect.h"
15+ #include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
16+ #include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1517#include " mlir/Dialect/GPU/Transforms/Passes.h"
1618#include " mlir/Dialect/GPU/Utils/GPUUtils.h"
1719#include " mlir/Dialect/Vector/IR/VectorOps.h"
@@ -362,6 +364,106 @@ struct VectorSubgroupReduceToShuffles final
362364 unsigned shuffleBitwidth = 0 ;
363365 bool matchClustered = false ;
364366};
367+
368+ Value createSubgroupDPPReduction (OpBuilder &b, Location loc, Value input,
369+ gpu::AllReduceOperation mode,
370+ const ClusterInfo &ci) {
371+ Value result = input;
372+ if (ci.clusterSize >= 2 ) {
373+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 1 );
374+ Value dppResult =
375+ b.create <amdgpu::DPPOp>(loc, result.getType (), result, result,
376+ amdgpu::DPPPerm::row_shl, permArg);
377+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
378+ result, dppResult);
379+ }
380+
381+ if (ci.clusterSize >= 4 ) {
382+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 2 );
383+ Value dppResult =
384+ b.create <amdgpu::DPPOp>(loc, result.getType (), result, result,
385+ amdgpu::DPPPerm::row_shl, permArg);
386+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
387+ result, dppResult);
388+ }
389+
390+ if (ci.clusterSize >= 8 ) {
391+ Value dppResult = b.create <amdgpu::DPPOp>(
392+ loc, result.getType (), result, result, amdgpu::DPPPerm::row_half_mirror,
393+ b.getUnitAttr ());
394+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
395+ result, dppResult);
396+ }
397+
398+ if (ci.clusterSize >= 16 ) {
399+ Value dppResult =
400+ b.create <amdgpu::DPPOp>(loc, result.getType (), result, result,
401+ amdgpu::DPPPerm::row_mirror, b.getUnitAttr ());
402+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
403+ result, dppResult);
404+ }
405+
406+ const int allRows = 0xf ;
407+ const int allBanks = 0xf ;
408+ auto int32Type = IntegerType::get (b.getContext (), 32 );
409+ if (ci.clusterSize >= 32 ) {
410+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 15 );
411+ Value dppResult = b.create <amdgpu::DPPOp>(
412+ loc, result.getType (), result, result, amdgpu::DPPPerm::row_bcast_15,
413+ b.getUnitAttr (), 0xa , allBanks, false );
414+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
415+ result, dppResult);
416+ if (ci.subgroupSize == 32 ) {
417+ Value lane01 = b.create <LLVM::ConstantOp>(loc, int32Type, 1 );
418+ result =
419+ b.create <ROCDL::ReadlaneOp>(loc, input.getType (), result, lane01);
420+ }
421+ }
422+
423+ if (ci.clusterSize == 64 ) {
424+ auto permArg = b.getIntegerAttr (b.getIntegerType (32 ), 31 );
425+ Value dppResult = b.create <amdgpu::DPPOp>(
426+ loc, result.getType (), result, result, amdgpu::DPPPerm::row_bcast_31,
427+ b.getUnitAttr (), allRows, allBanks, false );
428+ result = vector::makeArithReduction (b, loc, gpu::convertReductionKind (mode),
429+ result, dppResult);
430+ Value lane63 = b.create <LLVM::ConstantOp>(loc, int32Type, 63 );
431+ result = b.create <ROCDL::ReadlaneOp>(loc, input.getType (), result, lane63);
432+ }
433+
434+ assert (result.getType () == input.getType ());
435+ return result;
436+ }
437+
438+ struct ScalarSubgroupReduceToDPP final
439+ : OpRewritePattern<gpu::SubgroupReduceOp> {
440+ ScalarSubgroupReduceToDPP (MLIRContext *ctx, unsigned subgroupSize,
441+ bool matchClustered, PatternBenefit benefit)
442+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
443+ matchClustered (matchClustered) {}
444+
445+ LogicalResult matchAndRewrite (gpu::SubgroupReduceOp op,
446+ PatternRewriter &rewriter) const override {
447+ if (op.getClusterSize ().has_value () != matchClustered) {
448+ return rewriter.notifyMatchFailure (
449+ op, llvm::formatv (" op is {0}clustered but pattern is configured to "
450+ " only match {1}clustered ops" ,
451+ matchClustered ? " non-" : " " ,
452+ matchClustered ? " " : " non-" ));
453+ }
454+ auto ci = getAndValidateClusterInfo (op, subgroupSize);
455+ if (failed (ci))
456+ return failure ();
457+ Location loc = op.getLoc ();
458+ rewriter.replaceOp (op, createSubgroupDPPReduction (
459+ rewriter, loc, op.getValue (), op.getOp (), *ci));
460+ return success ();
461+ }
462+
463+ private:
464+ unsigned subgroupSize = 0 ;
465+ bool matchClustered = false ;
466+ };
365467} // namespace
366468
367469void mlir::populateGpuBreakDownSubgroupReducePatterns (
@@ -372,6 +474,13 @@ void mlir::populateGpuBreakDownSubgroupReducePatterns(
372474 patterns.add <ScalarizeSingleElementReduce>(patterns.getContext (), benefit);
373475}
374476
477+ void mlir::populateGpuLowerSubgroupReduceToDPPPatterns (
478+ RewritePatternSet &patterns, unsigned subgroupSize,
479+ PatternBenefit benefit) {
480+ patterns.add <ScalarSubgroupReduceToDPP>(patterns.getContext (), subgroupSize,
481+ /* matchClustered=*/ true , benefit);
482+ }
483+
375484void mlir::populateGpuLowerSubgroupReduceToShufflePatterns (
376485 RewritePatternSet &patterns, unsigned subgroupSize,
377486 unsigned shuffleBitwidth, PatternBenefit benefit) {
0 commit comments