@Hardcode84
Contributor

Repack amdgpu.swizzle_bitmode arguments and lower it to rocdl.ds_swizzle.

The repacking logic is as follows:

  • sizeof(arg) < sizeof(i32): bitcast to an integer of the same width, zext to i32, swizzle, then trunc and bitcast back.
  • sizeof(arg) == sizeof(i32): bitcast to i32 and back if the type is not already i32.
  • sizeof(arg) > sizeof(i32): bitcast to vector<Nxi32>, extract the individual elements, run rocdl.ds_swizzle on each of them, then compose the results into a vector and bitcast back.

Added the repacking logic to the LLVM utils so it can be used elsewhere. I'm planning to use it for gpu.shuffle later.
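
The three repacking cases can be modelled outside MLIR on raw bytes. The sketch below is only an illustration: `decomposeToWords` and `composeFromWords` are hypothetical names, the payload is assumed to be stored least significant byte first, and the real helpers in this PR emit LLVM dialect ops rather than manipulating bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical byte-level model of decomposeValue with a 32-bit destination.
// `bytes` holds the raw payload, least significant byte first (assumption).
std::vector<uint32_t> decomposeToWords(const std::vector<uint8_t> &bytes) {
  assert(!bytes.empty() && "payload must not be empty");
  // Payloads wider than 32 bits must split evenly into 32-bit words.
  assert((bytes.size() <= 4 || bytes.size() % 4 == 0) &&
         "wide payload must be a multiple of 32 bits");
  std::size_t numWords = (bytes.size() + 3) / 4;
  // Zero-initialised words model the zext used for narrow payloads.
  std::vector<uint32_t> words(numWords, 0);
  for (std::size_t i = 0; i < bytes.size(); ++i)
    words[i / 4] |= static_cast<uint32_t>(bytes[i]) << (8 * (i % 4));
  return words;
}

// Hypothetical byte-level model of composeValue: rebuilds `numBytes` bytes.
// When numBytes < 4 the high bits are dropped, which models the trunc.
std::vector<uint8_t> composeFromWords(const std::vector<uint32_t> &words,
                                      std::size_t numBytes) {
  assert(numBytes <= words.size() * 4 && "not enough words for the payload");
  std::vector<uint8_t> bytes(numBytes);
  for (std::size_t i = 0; i < numBytes; ++i)
    bytes[i] = static_cast<uint8_t>(words[i / 4] >> (8 * (i % 4)));
  return bytes;
}
```

In this model a 2-byte payload becomes one zero-extended word, a 4-byte payload becomes one word unchanged, and an 8-byte payload becomes two words; each word would be swizzled independently before `composeFromWords` reassembles the result.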

@llvmbot
Member

llvmbot commented Apr 17, 2025

@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Full diff: https://github.com/llvm/llvm-project/pull/136223.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+12)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+33)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+92)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir (+75)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 66c8731ec2bf4..7a58e4fc2f984 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
     IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 
 } // namespace detail
+
+/// Decomposes a `src` value into a set of values of type `dstType` through
+/// series of bitcasts and vector ops. Src and dst types are expected to be int
+/// or float types or vector types of them.
+SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
+                                  Type dstType);
+
+/// Composes a set of `src` values into a single value of type `dstType` through
+/// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
+/// function is used to combine multiple values into a single value.
+Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
+                   Type dstType);
 } // namespace LLVM
 
 /// Base class for operation conversions targeting the LLVM IR dialect. It
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 5f697bdeef566..5c4c95699142f 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1377,6 +1377,38 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
   }
 };
 
+struct AMDGPUSwizzleBitModeLowering
+    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+    Value src = adaptor.getSrc();
+    SmallVector<Value> decomposed =
+        LLVM::decomposeValue(rewriter, loc, src, i32);
+    unsigned andMask = op.getAndMask();
+    unsigned orMask = op.getOrMask();
+    unsigned xorMask = op.getXorMask();
+
+    // bit 15 is 0 for the BitMode swizzle.
+    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
+    Value maskValue = createI32Constant(rewriter, loc, mask);
+    SmallVector<Value> swizzled;
+    for (Value v : decomposed) {
+      Value res =
+          rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+      swizzled.emplace_back(res);
+    }
+
+    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
   using Base::Base;
@@ -1444,4 +1476,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
            GatherToLDSOpLowering>(converter, chipset);
+  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 32bfd72475569..d2737c56369d1 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -381,3 +381,95 @@ LogicalResult LLVM::detail::oneToOneRewrite(
   rewriter.replaceOp(op, results);
   return success();
 }
+
+static unsigned getBitWidth(Type type) {
+  if (type.isIntOrFloat())
+    return type.getIntOrFloatBitWidth();
+
+  auto vec = cast<VectorType>(type);
+  return vec.getNumElements() * getBitWidth(vec.getElementType());
+}
+
+static Value createI32Constant(OpBuilder &builder, Location loc,
+                               int32_t value) {
+  Type i32 = builder.getI32Type();
+  return builder.create<LLVM::ConstantOp>(loc, i32, value);
+}
+
+SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
+                                              Value src, Type dstType) {
+  Type srcType = src.getType();
+  if (srcType == dstType)
+    return {src};
+
+  unsigned srcBitWidth = getBitWidth(srcType);
+  unsigned dstBitWidth = getBitWidth(dstType);
+  if (srcBitWidth == dstBitWidth) {
+    Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+    return {cast};
+  }
+
+  if (dstBitWidth > srcBitWidth) {
+    auto smallerInt = builder.getIntegerType(srcBitWidth);
+    if (srcType != smallerInt)
+      src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
+
+    auto largerInt = builder.getIntegerType(dstBitWidth);
+    Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
+    return {res};
+  }
+  assert(srcBitWidth % dstBitWidth == 0 &&
+         "src bit width must be a multiple of dst bit width");
+  int64_t numElements = srcBitWidth / dstBitWidth;
+  auto vecType = VectorType::get(numElements, dstType);
+
+  src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+
+  SmallVector<Value> res;
+  for (auto i : llvm::seq<int64_t>(0, numElements)) {
+    Value idx = createI32Constant(builder, loc, i);
+    Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+    res.emplace_back(elem);
+  }
+
+  return res;
+}
+
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+                               Type dstType) {
+  assert(!src.empty() && "src range must not be empty");
+  if (src.size() == 1) {
+    Value res = src.front();
+    if (res.getType() == dstType)
+      return res;
+
+    unsigned srcBitWidth = getBitWidth(res.getType());
+    unsigned dstBitWidth = getBitWidth(dstType);
+    if (dstBitWidth < srcBitWidth) {
+      auto largerInt = builder.getIntegerType(srcBitWidth);
+      if (res.getType() != largerInt)
+        res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
+
+      auto smallerInt = builder.getIntegerType(dstBitWidth);
+      res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
+    }
+
+    if (res.getType() != dstType)
+      res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+
+    return res;
+  }
+
+  int64_t numElements = src.size();
+  auto srcType = VectorType::get(numElements, src.front().getType());
+  Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
+  for (auto &&[i, elem] : llvm::enumerate(src)) {
+    Value idx = createI32Constant(builder, loc, i);
+    res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
+  }
+
+  if (res.getType() != dstType)
+    res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+
+  return res;
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
new file mode 100644
index 0000000000000..ef439efde1bd0
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swizzle.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @test_swizzle_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
+// CHECK:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK:  %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
+// CHECK:  return %[[RES]] : i32
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: func @test_swizzle_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
+// CHECK:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK:  %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK:  return %[[RES_CAST]] : f32
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func @test_swizzle_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
+// CHECK:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK:  %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK:  %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK:  return %[[RES_CAST]] : f16
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: func @test_swizzle_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG:  %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:  %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  return %[[V3]] : vector<2xi32>
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_swizzle_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG:  %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG:  %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
+// CHECK-DAG:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-DAG:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:  %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK:  %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
+// CHECK:  %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK:  %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
+// CHECK:  %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
+// CHECK:  return %[[CAST2]] : vector<4xf16>
+  %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
+  return %0 : vector<4xf16>
+}

@llvmbot
Member

llvmbot commented Apr 17, 2025

@llvm/pr-subscribers-mlir-gpu


unsigned orMask = op.getOrMask();
unsigned xorMask = op.getXorMask();

// bit 15 is 0 for the BitMode swizzle.
Member

Could we link to the ISA manual or llvm intrinsics here?

Contributor Author

done
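
For reference, the mask packing under discussion can be checked in isolation. This is a minimal sketch: `packBitModeOffset` is a hypothetical helper name, and the 5-bit range check is an assumption inferred from the shift amounts in the lowering (5 and 10, with bit 15 left clear), not something the PR states.

```cpp
#include <cassert>
#include <cstdint>

// Packs the three BitMode masks into the ds_swizzle offset, mirroring the
// expression used in the lowering: and | (or << 5) | (xor << 10).
// Bit 15 stays 0, which selects the BitMode swizzle.
// `packBitModeOffset` is a hypothetical name, not a function from the PR.
uint32_t packBitModeOffset(uint32_t andMask, uint32_t orMask,
                           uint32_t xorMask) {
  // Assumption: each mask is a 5-bit field; a larger value would spill into
  // the neighbouring field or into bit 15.
  assert(andMask < 32 && orMask < 32 && xorMask < 32 &&
         "masks are assumed to be 5 bits wide");
  return andMask | (orMask << 5) | (xorMask << 10);
}
```

With the masks used in the tests of this PR (`1 2 4`), the helper yields 1 + 64 + 4096 = 4161, which is the constant the FileCheck lines expect.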

if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();

auto vec = cast<VectorType>(type);
Member

I think we need to assert this is not a scalable vector or bail out in some other way

Contributor Author

Good catch. I updated the verifier to reject scalable vectors and added an assert here.
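
The concern with scalable vectors is that their total bit width is not a compile-time constant, so a helper like `getBitWidth` cannot return a meaningful number for them. A toy model of the "bail out" alternative is sketched below; `ToyType` and `fixedBitWidth` are hypothetical stand-ins and do not use the real `mlir::Type` API.

```cpp
#include <cassert>
#include <optional>

// Toy stand-in for an MLIR scalar or vector type (not the real mlir::Type).
struct ToyType {
  unsigned elementBits; // bit width of the scalar element
  unsigned numElements; // 1 for scalars
  bool scalable;        // true models a scalable vector
};

// Returns the total bit width, or std::nullopt when the width is not known
// statically (scalable vectors), so callers can bail out instead of asserting.
std::optional<unsigned> fixedBitWidth(const ToyType &type) {
  if (type.scalable)
    return std::nullopt;
  return type.elementBits * type.numElements;
}
```

The PR itself takes the other route mentioned in the thread: the op verifier rejects scalable vectors and the helper asserts.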

src = builder.create<LLVM::BitcastOp>(loc, vecType, src);

SmallVector<Value> res;
for (auto i : llvm::seq<int64_t>(0, numElements)) {
Member

nit: I'd use a plain for loop here...

Contributor Author

I like llvm::seq more than C loops, but I switched to the 1-arg version.

}

if (res.getType() != dstType)
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
Member

very cool, didn't know llvm::bitcast can transform from scalar to vector types with same bitwidth

};

struct AMDGPUSwizzleBitModeLowering
: public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
Member

Unrelated to this PR, but I think it'd be more intuitive if the op name is dsSwizzleOp, and we can have the QDMode and the BitMode as an attribute of this op. Additionally, we can also reference https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations. :)

Contributor Author

The QDMode and BitMode variants have different offset formats. My idea was to have them as separate ops in AMDGPU, so users won't need to bother with offset bit-packing even though both are eventually lowered to ds_swizzle.
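
To illustrate why the two modes pack their offsets differently, here is a hedged sketch of a QDMode encoder. The layout is an assumption taken from AMD's public description of `ds_swizzle` (bit 15 set, bits [7:0] holding four 2-bit lane selectors, one per lane in each group of four); it is not part of this PR, and `encodeQDMode` and `isQDMode` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Assumed QDMode layout (not from this PR): bit 15 = 1, and bits [7:0] hold
// four 2-bit selectors; selector i picks the source lane, within each group
// of four lanes, for destination lane i.
uint32_t encodeQDMode(const std::array<uint32_t, 4> &lanes) {
  uint32_t offset = 0x8000u; // bit 15 selects QDMode
  for (unsigned i = 0; i < 4; ++i) {
    assert(lanes[i] < 4 && "each selector is 2 bits wide");
    offset |= lanes[i] << (2 * i);
  }
  return offset;
}

// Bit 15 distinguishes the two modes: 1 for QDMode, 0 for BitMode.
bool isQDMode(uint32_t offset) { return (offset & 0x8000u) != 0; }
```

Under this assumed layout the identity permutation `{0, 1, 2, 3}` encodes to `0x80E4`, whereas the BitMode constant from this PR's tests (4161) has bit 15 clear.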

Member

@kuhar kuhar left a comment

lgtm

@Hardcode84 Hardcode84 merged commit dda4b96 into llvm:main Apr 18, 2025
12 checks passed
@Hardcode84 Hardcode84 deleted the amdgpu_swizzle_lowering branch April 18, 2025 14:19
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025