diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index 66c8731ec2bf4..7a58e4fc2f984 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite( IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); } // namespace detail + +/// Decomposes a `src` value into a set of values of type `dstType` through +/// series of bitcasts and vector ops. Src and dst types are expected to be int +/// or float types or vector types of them. +SmallVector decomposeValue(OpBuilder &builder, Location loc, Value src, + Type dstType); + +/// Composes a set of `src` values into a single value of type `dstType` through +/// series of bitcasts and vector ops. Inversely to `decomposeValue`, this +/// function is used to combine multiple values into a single value. +Value composeValue(OpBuilder &builder, Location loc, ValueRange src, + Type dstType); } // namespace LLVM /// Base class for operation conversions targeting the LLVM IR dialect. It diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index a92ebf6d8e108..f14aa5a2e1564 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -38,7 +38,7 @@ def AMDGPU_Dialect : Dialect { def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">; def AnyIntegerOrFloatOr1DVector : - AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>; + AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>; //===----------------------------------------------------------------------===// // AMDGPU general attribute definitions diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 5f697bdeef566..91dbc2de65c4e 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1377,6 +1377,39 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern { } }; +struct AMDGPUSwizzleBitModeLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Type i32 = rewriter.getI32Type(); + Value src = adaptor.getSrc(); + SmallVector decomposed = + LLVM::decomposeValue(rewriter, loc, src, i32); + unsigned andMask = op.getAndMask(); + unsigned orMask = op.getOrMask(); + unsigned xorMask = op.getXorMask(); + + // bit 15 is 0 for the BitMode swizzle. + // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ + unsigned mask = andMask | (orMask << 5) | (xorMask << 10); + Value maskValue = createI32Constant(rewriter, loc, mask); + SmallVector swizzled; + for (Value v : decomposed) { + Value res = + rewriter.create(loc, v.getType(), v, maskValue); + swizzled.emplace_back(res); + } + + Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; + struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLPassBase { using Base::Base; @@ -1444,4 +1477,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter, chipset); + patterns.add(converter); } diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 32bfd72475569..1ae99561e9d1b 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -381,3 +381,96 @@ LogicalResult LLVM::detail::oneToOneRewrite( rewriter.replaceOp(op, results); return success(); } + +static unsigned getBitWidth(Type type) { + if (type.isIntOrFloat()) + return type.getIntOrFloatBitWidth(); + + auto vec = cast(type); + assert(!vec.isScalable() && "scalable vectors are not supported"); + return vec.getNumElements() * getBitWidth(vec.getElementType()); +} + +static Value createI32Constant(OpBuilder &builder, Location loc, + int32_t value) { + Type i32 = builder.getI32Type(); + return builder.create(loc, i32, value); +} + +SmallVector mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, + Value src, Type dstType) { + Type srcType = src.getType(); + if (srcType == dstType) + return {src}; + + unsigned srcBitWidth = getBitWidth(srcType); + unsigned dstBitWidth = getBitWidth(dstType); + if (srcBitWidth == dstBitWidth) { + Value cast = builder.create(loc, dstType, src); + return {cast}; + } + + if (dstBitWidth > srcBitWidth) { + auto smallerInt = builder.getIntegerType(srcBitWidth); + if (srcType != smallerInt) + src = builder.create(loc, smallerInt, src); + + auto largerInt = builder.getIntegerType(dstBitWidth); + Value res = builder.create(loc, largerInt, src); + return {res}; + } + assert(srcBitWidth % dstBitWidth == 0 && + "src bit width must be a multiple of dst bit width"); + int64_t numElements = srcBitWidth / dstBitWidth; + auto vecType = VectorType::get(numElements, dstType); + + src = builder.create(loc, vecType, src); + + SmallVector res; + for (auto i : llvm::seq(numElements)) { + Value idx = createI32Constant(builder, loc, i); + Value elem = builder.create(loc, src, idx); + res.emplace_back(elem); + } + + return res; +} + +Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, + Type dstType) { + assert(!src.empty() && "src range must not be empty"); + if (src.size() == 1) { + Value res = src.front(); + if (res.getType() == dstType) + return res; + + unsigned srcBitWidth = getBitWidth(res.getType()); + unsigned dstBitWidth = getBitWidth(dstType); + if (dstBitWidth < srcBitWidth) { + auto largerInt = builder.getIntegerType(srcBitWidth); + if (res.getType() != largerInt) + res = builder.create(loc, largerInt, res); + + auto smallerInt = builder.getIntegerType(dstBitWidth); + res = builder.create(loc, smallerInt, res); + } + + if (res.getType() != dstType) + res = builder.create(loc, dstType, res); + + return res; + } + + int64_t numElements = src.size(); + auto srcType = VectorType::get(numElements, src.front().getType()); + Value res = builder.create(loc, srcType); + for (auto &&[i, elem] : llvm::enumerate(src)) { + Value idx = createI32Constant(builder, loc, i); + res = builder.create(loc, srcType, res, elem, idx); + } + + if (res.getType() != dstType) + res = builder.create(loc, dstType, res); + + return res; +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir new file mode 100644 index 0000000000000..ef439efde1bd0 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s + +// CHECK-LABEL: func @test_swizzle_i32 +// CHECK-SAME: (%[[ARG0:.*]]: i32) +func.func @test_swizzle_i32(%arg0 : i32) -> i32 { +// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 +// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32 +// CHECK: return %[[RES]] : i32 + %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32 + return %0 : i32 +} + +// CHECK-LABEL: func @test_swizzle_f32 +// CHECK-SAME: (%[[ARG0:.*]]: f32) +func.func @test_swizzle_f32(%arg0 : f32) -> f32 { +// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 +// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32 +// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32 +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32 +// CHECK: return %[[RES_CAST]] : f32 + %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32 + return %0 : f32 +} + +// CHECK-LABEL: func @test_swizzle_f16 +// CHECK-SAME: (%[[ARG0:.*]]: f16) +func.func @test_swizzle_f16(%arg0 : f16) -> f16 { +// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 +// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16 +// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32 +// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32 +// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16 +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16 +// CHECK: return %[[RES_CAST]] : f16 + %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16 + return %0 : f16 +} + +// CHECK-LABEL: func @test_swizzle_2xi32 +// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>) +func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> { +// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32> +// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 +// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32 +// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32 +// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32> +// CHECK: return %[[V3]] : vector<2xi32> + %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func @test_swizzle_4xf16 +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>) +func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> { +// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32> +// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 +// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32> +// CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32 +// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32 +// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16> +// CHECK: return %[[CAST2]] : vector<4xf16> + %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16> + return %0 : vector<4xf16> +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 40f98ff85688c..73306ba6b3f93 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -154,7 +154,15 @@ func.func @fat_raw_buffer_cast_stripping_offset_affine_map(%m: memref<8xi32, aff // ----- func.func @swizzle_invalid_type(%arg0 : si32) -> si32 { - // expected-error@+1 {{amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1}} + // expected-error@+1 {{'amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1}} %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : si32 func.return %0 : si32 } + +// ----- + +func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> { + // expected-error@+1 {{'amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1}} + %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32> + func.return %0 : vector<[4]xf32> +}