[mlir] AMDGPUToROCDL: lower amdgpu.swizzle_bitmode
#136223
Repack `amdgpu.swizzle_bitmode` arguments and lower it to `rocdl.ds_swizzle`.

The repacking logic is as follows:
* `sizeof(arg) < sizeof(i32)`: bitcast to an integer, zext to i32, and then trunc and bitcast back.
* `sizeof(arg) == sizeof(i32)`: just bitcast to i32 and back if the type is not already i32.
* `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract the individual elements, do a series of `rocdl.ds_swizzle` ops, and then compose the vector and bitcast back.

Added the repacking logic to LLVM utils so it can be used elsewhere. I'm planning to use it for `gpu.shuffle` later.
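The three repacking cases can be modelled on the host with plain integers. The sketch below is only an illustration, not code from this PR: `decomposeToWords` and `composeFromWords` are hypothetical names, and the model assumes a little-endian layout so that copying the low bytes behaves like zext/trunc.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical host-side model of the repacking (not the MLIR helpers).
// Splits `sizeBytes` bytes at `data` into 32-bit words. Assumes little-endian.
inline std::vector<uint32_t> decomposeToWords(const void *data,
                                              std::size_t sizeBytes) {
  if (sizeBytes < sizeof(uint32_t)) {
    // Narrow case: the upper bytes stay zero, which models bitcast + zext.
    uint32_t word = 0;
    std::memcpy(&word, data, sizeBytes);
    return {word};
  }
  // Same-size and wide cases: reinterpret the bytes as N x i32.
  assert(sizeBytes % sizeof(uint32_t) == 0 &&
         "size must be a multiple of 4 bytes");
  std::vector<uint32_t> words(sizeBytes / sizeof(uint32_t));
  std::memcpy(words.data(), data, sizeBytes);
  return words;
}

// Inverse step: copies the low `sizeBytes` bytes back, which models
// trunc + bitcast for narrow types and a plain bitcast otherwise.
inline void composeFromWords(const std::vector<uint32_t> &words, void *out,
                             std::size_t sizeBytes) {
  assert(words.size() * sizeof(uint32_t) >= sizeBytes &&
         "not enough words for the destination");
  std::memcpy(out, words.data(), sizeBytes);
}
```

In the actual lowering each 32-bit word would go through one `rocdl.ds_swizzle` between the two steps; here the words are passed through unchanged.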
@llvm/pr-subscribers-mlir-amdgpu @llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
Full diff: https://github.com/llvm/llvm-project/pull/136223.diff
4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 66c8731ec2bf4..7a58e4fc2f984 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
} // namespace detail
+
+/// Decomposes a `src` value into a set of values of type `dstType` through
+/// series of bitcasts and vector ops. Src and dst types are expected to be int
+/// or float types or vector types of them.
+SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
+ Type dstType);
+
+/// Composes a set of `src` values into a single value of type `dstType` through
+/// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
+/// function is used to combine multiple values into a single value.
+Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
+ Type dstType);
} // namespace LLVM
/// Base class for operation conversions targeting the LLVM IR dialect. It
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5f697bdeef566..5c4c95699142f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1377,6 +1377,38 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
}
};
+struct AMDGPUSwizzleBitModeLowering
+ : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type i32 = rewriter.getI32Type();
+ Value src = adaptor.getSrc();
+ SmallVector<Value> decomposed =
+ LLVM::decomposeValue(rewriter, loc, src, i32);
+ unsigned andMask = op.getAndMask();
+ unsigned orMask = op.getOrMask();
+ unsigned xorMask = op.getXorMask();
+
+ // bit 15 is 0 for the BitMode swizzle.
+ unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
+ Value maskValue = createI32Constant(rewriter, loc, mask);
+ SmallVector<Value> swizzled;
+ for (Value v : decomposed) {
+ Value res =
+ rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+ swizzled.emplace_back(res);
+ }
+
+ Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -1444,4 +1476,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering>(converter, chipset);
+ patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 32bfd72475569..d2737c56369d1 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -381,3 +381,95 @@ LogicalResult LLVM::detail::oneToOneRewrite(
rewriter.replaceOp(op, results);
return success();
}
+
+static unsigned getBitWidth(Type type) {
+ if (type.isIntOrFloat())
+ return type.getIntOrFloatBitWidth();
+
+ auto vec = cast<VectorType>(type);
+ return vec.getNumElements() * getBitWidth(vec.getElementType());
+}
+
+static Value createI32Constant(OpBuilder &builder, Location loc,
+ int32_t value) {
+ Type i32 = builder.getI32Type();
+ return builder.create<LLVM::ConstantOp>(loc, i32, value);
+}
+
+SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
+ Value src, Type dstType) {
+ Type srcType = src.getType();
+ if (srcType == dstType)
+ return {src};
+
+ unsigned srcBitWidth = getBitWidth(srcType);
+ unsigned dstBitWidth = getBitWidth(dstType);
+ if (srcBitWidth == dstBitWidth) {
+ Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+ return {cast};
+ }
+
+ if (dstBitWidth > srcBitWidth) {
+ auto smallerInt = builder.getIntegerType(srcBitWidth);
+ if (srcType != smallerInt)
+ src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
+
+ auto largerInt = builder.getIntegerType(dstBitWidth);
+ Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
+ return {res};
+ }
+ assert(srcBitWidth % dstBitWidth == 0 &&
+ "src bit width must be a multiple of dst bit width");
+ int64_t numElements = srcBitWidth / dstBitWidth;
+ auto vecType = VectorType::get(numElements, dstType);
+
+ src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+
+ SmallVector<Value> res;
+ for (auto i : llvm::seq<int64_t>(0, numElements)) {
+ Value idx = createI32Constant(builder, loc, i);
+ Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+ res.emplace_back(elem);
+ }
+
+ return res;
+}
+
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+ Type dstType) {
+ assert(!src.empty() && "src range must not be empty");
+ if (src.size() == 1) {
+ Value res = src.front();
+ if (res.getType() == dstType)
+ return res;
+
+ unsigned srcBitWidth = getBitWidth(res.getType());
+ unsigned dstBitWidth = getBitWidth(dstType);
+ if (dstBitWidth < srcBitWidth) {
+ auto largerInt = builder.getIntegerType(srcBitWidth);
+ if (res.getType() != largerInt)
+ res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
+
+ auto smallerInt = builder.getIntegerType(dstBitWidth);
+ res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
+ }
+
+ if (res.getType() != dstType)
+ res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+
+ return res;
+ }
+
+ int64_t numElements = src.size();
+ auto srcType = VectorType::get(numElements, src.front().getType());
+ Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
+ for (auto &&[i, elem] : llvm::enumerate(src)) {
+ Value idx = createI32Constant(builder, loc, i);
+ res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
+ }
+
+ if (res.getType() != dstType)
+ res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+
+ return res;
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
new file mode 100644
index 0000000000000..ef439efde1bd0
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @test_swizzle_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_swizzle_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_swizzle_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_swizzle_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[V3]] : vector<2xi32>
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_swizzle_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
unsigned orMask = op.getOrMask();
unsigned xorMask = op.getXorMask();

// bit 15 is 0 for the BitMode swizzle.
Could we link to the ISA manual or llvm intrinsics here?
done
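For reference, the offset packing used by the lowering can be checked on the host. The sketch below mirrors the `andMask | (orMask << 5) | (xorMask << 10)` expression from the diff; the lane-selection function reflects my understanding of the BitMode semantics (each group of 32 lanes reads from lane `((lane & and) | or) ^ xor`) and should be verified against the ISA manual. Both function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Packs the three 5-bit masks exactly as the lowering in this PR does.
// Bit 15 is left at 0, which selects BitMode.
inline uint32_t packBitModeOffset(uint32_t andMask, uint32_t orMask,
                                  uint32_t xorMask) {
  assert(andMask < 32 && orMask < 32 && xorMask < 32 &&
         "each mask must fit in 5 bits");
  return andMask | (orMask << 5) | (xorMask << 10);
}

// Assumed BitMode semantics: within each group of 32 lanes, a lane reads
// from lane ((lane & and) | or) ^ xor of the same group.
inline uint32_t bitModeSourceLane(uint32_t lane, uint32_t offset) {
  uint32_t andMask = offset & 0x1F;
  uint32_t orMask = (offset >> 5) & 0x1F;
  uint32_t xorMask = (offset >> 10) & 0x1F;
  uint32_t groupBase = lane & ~0x1Fu;  // first lane of this 32-lane group
  uint32_t laneInGroup = lane & 0x1F;
  return groupBase | (((laneInGroup & andMask) | orMask) ^ xorMask);
}
```

With the masks `1 2 4` used in `swizzle.mlir`, `packBitModeOffset` gives 4161, which matches the constant the tests check for.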
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();

auto vec = cast<VectorType>(type);
I think we need to assert this is not a scalable vector or bail out in some other way
Good catch, updated verifier to reject scalable vectors, added assert here.
src = builder.create<LLVM::BitcastOp>(loc, vecType, src);

SmallVector<Value> res;
for (auto i : llvm::seq<int64_t>(0, numElements)) {
nit: I'd use a plain for loop here...
I like llvm::seq more than C loops. But switched to the 1-arg version.
}

if (res.getType() != dstType)
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
very cool, didn't know llvm::bitcast can transform from scalar to vector types with same bitwidth
};

struct AMDGPUSwizzleBitModeLowering
: public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
unrelated to this PR, but I think it'd be more intuitive if the op name is dsSwizzleOp, and we can have the QDMode and the BitMode as an attribute of this op. Additionally we can also reference https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations. :)
The QDMode and BitMode variants use different offset formats. My idea was to have them as separate ops in AMDGPU, so the user won't need to bother with offset bitpacking even if they are both lowered to ds_swizzle eventually.
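To illustrate how the two offset formats differ, here is a host-side sketch of the QDMode (quad-permute) encoding. It is not part of this PR and reflects my understanding of the ISA rather than anything stated here: bit 15 is set, and bits [7:0] hold four 2-bit source-lane selectors, one per lane of each quad. The function names are hypothetical, and the layout should be checked against the ISA manual before relying on it.

```cpp
#include <cassert>
#include <cstdint>

// Assumed QDMode layout: bit 15 = 1, bits [7:0] = four 2-bit selectors.
// sel0 is the source lane (within the quad) for lane 0, sel1 for lane 1, etc.
inline uint32_t packQuadPermOffset(uint32_t sel0, uint32_t sel1,
                                   uint32_t sel2, uint32_t sel3) {
  assert(sel0 < 4 && sel1 < 4 && sel2 < 4 && sel3 < 4 &&
         "each selector must fit in 2 bits");
  return 0x8000u | sel0 | (sel1 << 2) | (sel2 << 4) | (sel3 << 6);
}

// Source lane for `lane` under the assumed QDMode semantics: the selector
// picks a lane inside the same group of four lanes.
inline uint32_t quadPermSourceLane(uint32_t lane, uint32_t offset) {
  uint32_t quadBase = lane & ~3u;
  uint32_t selector = (offset >> (2 * (lane & 3u))) & 3u;
  return quadBase | selector;
}
```

Under these assumptions the identity permutation encodes as `0x80E4`, whereas BitMode keeps bit 15 clear and packs three 5-bit masks, so a single op exposing both modes would still need mode-specific packing.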
kuhar left a comment
lgtm