diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e6dd6f135884e..83c5700e00263 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -113,6 +113,60 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+/// Lowers gpu.shuffle xor to ds_swizzle if possible.
+struct GPUShuffleOpLoweringSwizzle
+    : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
+  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getMode() != gpu::ShuffleMode::XOR)
+      return rewriter.notifyMatchFailure(op, "only xor mode is supported");
+
+    // Check the unconverted width and offset.
+    if (!isConstantIntValue(op.getWidth(), 64))
+      return rewriter.notifyMatchFailure(op, "width must be 64");
+
+    std::optional<int64_t> offsetVal = getConstantIntValue(op.getOffset());
+    if (!offsetVal)
+      return rewriter.notifyMatchFailure(op, "offset must be a constant");
+
+    int64_t offset = *offsetVal;
+    if (offset < 0 || offset >= (1 << 5))
+      return rewriter.notifyMatchFailure(op, "unsupported offset value");
+
+    Location loc = op.getLoc();
+    Value initShflValue = adaptor.getValue();
+
+    auto int32Type = rewriter.getI32Type();
+
+    // TODO: It may be possible to lower specific xor patterns to DPP ops.
+
+    // Bit 15 is 0 to select the BitMode swizzle, see:
+    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
+    int64_t mask = ((1 << 5) - 1) | (offset << 10);
+    Value maskValue = rewriter.create<LLVM::ConstantOp>(loc, int32Type, mask);
+
+    SmallVector<Value> decomposed =
+        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
+    SmallVector<Value> swizzled;
+    for (Value v : decomposed) {
+      Value res =
+          rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+      swizzled.emplace_back(res);
+    }
+    Value shflValue =
+        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
+
+    // We checked that the width is 64, so the source lane is always active.
+    Value isActiveSrcLane =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), 1);
+    rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+    return success();
+  }
+};
+
 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
 
@@ -135,13 +189,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    Location loc = op.getLoc();
     Value initShflValue = adaptor.getValue();
 
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
     Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
 
-    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+    auto int32Type = rewriter.getI32Type();
     Value width = adaptor.getWidth();
     Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
     Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
@@ -177,14 +231,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
 
     SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
-    SmallVector<Value> swizzled;
+    SmallVector<Value> permuted;
     for (Value v : decomposed) {
       Value res = rewriter.create<ROCDL::DSBpermOp>(loc, int32Type,
                                                     dwordAlignedDstLane, v);
-      swizzled.emplace_back(res);
+      permuted.emplace_back(res);
     }
     Value shflValue =
-        LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
+        LLVM::composeValue(rewriter, loc, permuted, initShflValue.getType());
     rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
     return success();
   }
@@ -405,6 +459,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
   // TODO: Add alignment for workgroup memory
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
 
+  // Try to lower to swizzle first.
+  patterns.add<GPUShuffleOpLoweringSwizzle>(converter, /*benefit=*/10);
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
 
   populateMathToROCDLConversionPatterns(converter, patterns);
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 071cae9d5789f..6072c21e59bb3 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -659,52 +659,47 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
-  // CHECK-LABEL: func @gpu_shuffle()
-  func.func @gpu_shuffle() -> (f32, f32, f32) {
-    // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
-    %arg0 = arith.constant 1.0 : f32
-    // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
-    %arg1 = arith.constant 4 : i32
-    // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
-    %arg2 = arith.constant 23 : i32
+  // CHECK-LABEL: func @gpu_shuffle
+  // CHECK-SAME: (%[[VALUE:.*]]: f32, %[[OFFSET:.*]]: i32, %[[WIDTH:.*]]: i32)
+  func.func @gpu_shuffle(%arg0: f32, %arg1: i32, %arg2: i32) -> (f32, f32, f32) {
     // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
     // CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
-    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
+    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
     // CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
     // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %{{.*}} : i32
     // CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#XOR]], %[[#WARP_OR_ZERO]] : i32
     // CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#XOR]], %{{.*}} : i1, i32
     // CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
-    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
     // CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
     // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
     %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
     // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
     // CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
-    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
+    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
     // CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
-    // CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#OFFSET]], %[[#WARP_OR_ZERO]] : i32
-    // CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#OFFSET]], %{{.*}} : i1, i32
+    // CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[OFFSET]], %[[#WARP_OR_ZERO]] : i32
+    // CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[OFFSET]], %{{.*}} : i1, i32
     // CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
-    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
     // CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
     // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
     %shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
     // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
     // CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
-    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+    // CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
+    // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
     // CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
     // CHECK: %[[#DOWN:]] = llvm.add %[[#LANE_ID]], %{{.*}} : i32
     // CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#DOWN]], %[[#WARP_OR_ZERO]] : i32
     // CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#DOWN]], %{{.*}} : i1, i32
     // CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
-    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+    // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
     // CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
     // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
     %shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
@@ -731,6 +726,18 @@ gpu.module @test_module {
     %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<4xf16>
     func.return %shfl : vector<4xf16>
   }
+
+  // CHECK-LABEL: func @gpu_shuffle_swizzle
+  // CHECK-SAME: (%[[ARG:.*]]: i32)
+  func.func @gpu_shuffle_swizzle(%arg0: i32) -> i32 {
+    // CHECK: %[[MASK:.*]] = llvm.mlir.constant(4127 : i32) : i32
+    // CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG]], %[[MASK]] : (i32, i32) -> i32
+    // CHECK: llvm.return %[[RES]] : i32
+    %width = arith.constant 64 : i32
+    %offset = arith.constant 4 : i32
+    %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
+    func.return %shfl : i32
+  }
 }
 
 // -----
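
Note: the mask packing in `GPUShuffleOpLoweringSwizzle` can be sanity-checked in isolation. Below is a minimal standalone sketch, assuming the BitMode field layout from the GCN cross-lane operations article linked in the patch (and_mask in bits [4:0], or_mask in bits [9:5], xor_mask in bits [14:10], bit 15 clear); `encodeBitModeXorMask` is a hypothetical helper for illustration, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>

// BitMode ds_swizzle offset encoding (bit 15 clear selects BitMode):
//   [14:10] xor_mask, [9:5] or_mask, [4:0] and_mask.
// A pure xor shuffle keeps all five lane-ID bits (and_mask = 0b11111),
// ORs in nothing (or_mask = 0), and XORs the lane ID with `offset`.
static int32_t encodeBitModeXorMask(int64_t offset) {
  assert(offset >= 0 && offset < (1 << 5) && "xor_mask is only 5 bits");
  return static_cast<int32_t>(((1 << 5) - 1) | (offset << 10));
}

int main() {
  // offset = 4 gives 31 | (4 << 10) = 4127, the constant the
  // @gpu_shuffle_swizzle test above expects.
  assert(encodeBitModeXorMask(4) == 4127);
  return 0;
}
```

Because the BitMode source lane is `((lane & and_mask) | or_mask) ^ xor_mask`, every source lane stays within the wavefront, which is why the pattern can materialize a constant `true` for the `isActiveSrcLane` result once the width-64 check has passed.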