diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp index 67cef8af1e3b5..01a6c93965f96 100644 --- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp @@ -11,12 +11,13 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" - #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/PatternMatch.h" #include @@ -85,7 +86,7 @@ struct PromoteShuffleToPermlanePattern int64_t offsetValue = *offset; if (offsetValue != 16 && offsetValue != 32) - return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31"); + return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32"); Location loc = op.getLoc(); Value res = amdgpu::PermlaneSwapOp::create( @@ -96,13 +97,151 @@ struct PromoteShuffleToPermlanePattern } }; +static Value getLaneId(RewriterBase &rewriter, Location loc) { + auto int32Type = IntegerType::get(rewriter.getContext(), 32); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); + NamedAttribute noundef = {LLVM::LLVMDialect::getNoUndefAttrName(), + rewriter.getUnitAttr()}; + NamedAttribute lowRange = {LLVM::LLVMDialect::getRangeAttrName(), + LLVM::ConstantRangeAttr::get(rewriter.getContext(), + APInt::getZero(32), + APInt(32, 32))}; + NamedAttribute highRange = { + LLVM::LLVMDialect::getRangeAttrName(), + LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32), + APInt(32, 64))}; + Value 
mbcntLo = ROCDL::MbcntLoOp::create( + rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{}, + /*res_attrs=*/ + rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange}))); + Value laneId = ROCDL::MbcntHiOp::create( + rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{}, + rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange}))); + return laneId; +} + +/// Try to promote `gpu.shuffle` to `amdgpu.dpp`. Width must be a constant in +/// {4, 8, 12, 16, 32, 48, 64}; offset must be 1 or 2 (xor) or 1 (up/down). +struct PromoteShuffleToDPPPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::ShuffleOp op, + PatternRewriter &rewriter) const override { + std::optional width = getConstantIntValue(op.getWidth()); + if (!width) + return rewriter.notifyMatchFailure(op, + "width must be a constant integer"); + int64_t widthValue = *width; + if (!llvm::is_contained({4, 8, 12, 16, 32, 48, 64}, widthValue)) + return rewriter.notifyMatchFailure( + op, "width must be 4, 8, 12, 16, 32, 48 or 64"); + + std::optional offset = getConstantIntValue(op.getOffset()); + if (!offset) + return rewriter.notifyMatchFailure(op, + "offset must be a constant integer"); + + int64_t offsetValue = *offset; + Location loc = op.getLoc(); + auto int32Type = IntegerType::get(rewriter.getContext(), 32); + + amdgpu::DPPPerm kind; + Attribute permAttr = rewriter.getUnitAttr(); + Value srcLane; + Value dstLane; + switch (op.getMode()) { + case gpu::ShuffleMode::XOR: { + if (offsetValue != 1 && offsetValue != 2) + return rewriter.notifyMatchFailure( + op, "xor shuffle mode is only supported for offsets of 1 or 2"); + kind = amdgpu::DPPPerm::quad_perm; + srcLane = getLaneId(rewriter, loc); + dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLane, + op.getOffset()); + + if (offsetValue == 1) + permAttr = rewriter.getI32ArrayAttr({1, 0, 3, 2}); + else if (offsetValue == 2) + permAttr = rewriter.getI32ArrayAttr({2, 3, 0, 1}); + 
break; + } + case gpu::ShuffleMode::UP: { + if (offsetValue != 1) + return rewriter.notifyMatchFailure( + op, "up shuffle mode is only supported for offset 1"); + kind = amdgpu::DPPPerm::wave_shr; + srcLane = getLaneId(rewriter, loc); + dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLane, + op.getOffset()); + break; + } + case gpu::ShuffleMode::DOWN: { + if (offsetValue != 1) + return rewriter.notifyMatchFailure( + op, "down shuffle mode is only supported for offset 1"); + kind = amdgpu::DPPPerm::wave_shl; + srcLane = getLaneId(rewriter, loc); + dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLane, + op.getOffset()); + break; + } + case gpu::ShuffleMode::IDX: + return rewriter.notifyMatchFailure(op, + "idx shuffle mode is not supported"); + } + + unsigned bankMask = 0xF; + if (widthValue == 4) + bankMask = 0x1; + else if (widthValue == 8) + bankMask = 0x3; + else if (widthValue == 12) + bankMask = 0x7; + + unsigned rowMask = 0xF; + if (widthValue == 16) + rowMask = 0x1; + else if (widthValue == 32) + rowMask = 0x3; + else if (widthValue == 48) + rowMask = 0x7; + + constexpr bool boundCtrl = false; + + Value negwidth = + arith::ConstantIntOp::create(rewriter, loc, int32Type, -widthValue); + Value add = + arith::AddIOp::create(rewriter, loc, int32Type, srcLane, op.getWidth()); + Value widthOrZeroIfOutside = + arith::AndIOp::create(rewriter, loc, int32Type, add, negwidth); + Value isActiveSrcLane = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, dstLane, + widthOrZeroIfOutside); + + Value dpp = amdgpu::DPPOp::create(rewriter, loc, op.getResult(0).getType(), + op.getValue(), op.getValue(), kind, + permAttr, rowMask, bankMask, boundCtrl); + Value poison = + LLVM::PoisonOp::create(rewriter, loc, op.getResult(0).getType()); + + Value selectResult = + arith::SelectOp::create(rewriter, loc, isActiveSrcLane, dpp, poison); + + rewriter.replaceOp(op, {selectResult, isActiveSrcLane}); + return success(); + } +}; + } // namespace void 
mlir::populateGpuPromoteShuffleToAMDGPUPatterns( RewritePatternSet &patterns, std::optional maybeChipset) { patterns.add(patterns.getContext(), /*benefit*/ 1); + patterns.add(patterns.getContext(), + /*benefit*/ 2); if (maybeChipset && *maybeChipset >= kGfx950) patterns.add(patterns.getContext(), - /*benefit*/ 2); + /*benefit*/ 3); } diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir index ef631ce8a12e5..5b5caf3dc0e8f 100755 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -735,13 +735,18 @@ gpu.module @test_module { } // CHECK-LABEL: func @gpu_shuffle_promote() - func.func @gpu_shuffle_promote() -> (f32, f32, f32) { + func.func @gpu_shuffle_promote() -> (f32, f32, f32, f32, f32) { + // CHECK: %[[#POISON:]] = llvm.mlir.poison : f32 + // CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32 // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 %arg0 = arith.constant 1.0 : f32 %arg1 = arith.constant 4 : i32 %arg2 = arith.constant 16 : i32 %arg3 = arith.constant 32 : i32 + // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32 %arg4 = arith.constant 64 : i32 + // CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32 + %arg5 = arith.constant 1 : i32 // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32 // CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32 // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32 @@ -763,7 +768,84 @@ gpu.module @test_module { // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32 // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32 %shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32 - func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32 + // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi + // CHECK: %[[#SUB:]] = llvm.sub %[[#LANE_ID]], %[[#C1]] : i32 + // CHECK: %[[#ADD:]] = 
llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32 + // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32 + // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#SUB]], %[[#AND]] : i32 + // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 312, 15, 15, false : f32 + // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32 + %shflu, %predu = gpu.shuffle up %arg0, %arg5, %arg4 : f32 + // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi + // CHECK: %[[#OP:]] = llvm.add %[[#LANE_ID]], %[[#C1]] : i32 + // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32 + // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32 + // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#OP]], %[[#AND]] : i32 + // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 304, 15, 15, false : f32 + // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32 + %shfld, %predd = gpu.shuffle down %arg0, %arg5, %arg4 : f32 + func.return %shfl1, %shfl2, %shfl3, %shflu, %shfld : f32, f32, f32, f32, f32 + } + + // CHECK-LABEL: func @gpu_butterfly_shuffle() + func.func @gpu_butterfly_shuffle() -> (f32, f32, f32, f32, f32, f32) { + // CHECK: %[[#POISON:]] = llvm.mlir.poison : f32 + // CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32 + // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + %arg0 = arith.constant 1.0 : f32 + // CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32 + %c1 = arith.constant 1 : i32 + // CHECK: %[[#C2:]] = llvm.mlir.constant(2 : i32) : i32 + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c16 = arith.constant 16 : i32 + %c32 = arith.constant 32 : i32 + // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32 + %c64 = arith.constant 64 : i32 + // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi + // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C1]] : i32 + // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32 
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32 + // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32 + // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 177, 15, 15, false : f32 + // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32 + %shfl1, %pred1 = gpu.shuffle xor %arg0, %c1, %c64 : f32 + // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi + // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C2]] : i32 + // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32 + // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32 + // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32 + // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 78, 15, 15, false : f32 + // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32 + %shfl2, %pred2 = gpu.shuffle xor %arg0, %c2, %c64 : f32 + // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32 + // CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32 + // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32 + // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32 + %shfl3, %pred3 = gpu.shuffle xor %arg0, %c4, %c64 : f32 + // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32 + // CHECK: %[[#MASK:]] = llvm.mlir.constant(8223 : i32) : i32 + // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32 + // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32 + %shfl4, %pred4 = gpu.shuffle xor %arg0, %c8, %c64 : f32 + // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32 + // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)> + // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)> + // CHECK: %[[#EXTRACT1:]] = 
llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)> + // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32 + // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32 + // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32 + %shfl5, %pred5 = gpu.shuffle xor %arg0, %c16, %c64 : f32 + // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32 + // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)> + // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)> + // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)> + // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32 + // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32 + // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32 + %shfl6, %pred6 = gpu.shuffle xor %arg0, %c32, %c64 : f32 + func.return %shfl1, %shfl2, %shfl3, %shfl4, %shfl5, %shfl6 : f32, f32, f32, f32, f32, f32 } // CHECK-LABEL: func @gpu_shuffle_vec