Skip to content

Conversation

@nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Oct 14, 2025

This PR changes the pack/unpack method used for unrolling to allow for lower rank slice to be extracted and inserted from and to src vector by adding reshapes. It also removes leading unit dims from inst_data if there are any.

@nbpatel nbpatel requested a review from Jianhui-Li October 15, 2025 14:01
@nbpatel nbpatel marked this pull request as ready for review October 15, 2025 14:01
@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR changes the pack/unpack method used for unrolling to allow for lower rank slice to be extracted and inserted from and to src vector by adding reshapes. It also removes leading unit dims from inst_data if there are any.


Full diff: https://github.com/llvm/llvm-project/pull/163459.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp (+8-3)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+1-5)
  • (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+23-4)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+43)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784abaf0b2..48831728ad624 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getEffectiveInstDataAsInt().empty())
-      return layout.getEffectiveInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty()) {
+      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove all leading unit dimensions from inst_data
+      while (!instData.empty() && instData.front() == 1)
+        instData.erase(instData.begin());
+      return instData;
+    }
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -363,7 +368,7 @@ void XeGPUBlockingPass::runOnOperation() {
           xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                      tdescTy.getLayoutAttr().dropInstData());
     } else {
-      newTy = type.clone(tileShape, elemTy);
+      newTy = VectorType::get(tileShape, elemTy);
     }
 
     if (returnSingleType)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0fe4b0b0..75b215c320e54 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
                Location loc, PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(destTy)) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
       return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
@@ -93,8 +91,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
                           ArrayRef<int64_t> blockSize, Location loc,
                           PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of src.");
       return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                      blockSize);
     }
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
     VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
-    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    VectorType newValueTy = VectorType::get(*targetShape, elemTy);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a438ea62c..40013eb161678 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,30 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   if (!computeShapeRatio(srcShape, shape))
     return {value};
 
+  int64_t srcShapeRank = srcShape.size();
+  int64_t targetShapeRank = shape.size();
+
+  SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+  int64_t rankDiff = srcShapeRank - targetShapeRank;
+  std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+            1);
+  std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
+  int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
+
   SmallVector<Value> result;
-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    result.push_back(vector::ExtractStridedSliceOp::create(
-        builder, loc, value, offsets, shape, staticStrides));
+    Value slice = vector::ExtractStridedSliceOp::create(
+        builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+    // Reshape to remove leading unit dims if needed
+    if (adjustedTargetShapeRank > targetShapeRank) {
+      auto targetTy = VectorType::get(shape, vecTy.getElementType());
+      slice = builder.create<vector::ShapeCastOp>(loc, targetTy, slice);
+    }
+    result.push_back(slice);
   }
 
   return result;
@@ -274,7 +293,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 
   for (auto [src, offsets] :
        llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
-    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    SmallVector<int64_t> staticStrides(tileShape.size(), 1);
     result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
                                                   offsets, staticStrides);
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index fe4f44c0b02ab..6301533da640d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -682,3 +682,46 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_gather
+  // CHECK-COUNT-2: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
+      %cst = arith.constant dense<[[
+      [0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248]
+      ]]> : vector<1x1x32xindex>
+
+      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+      %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+
+      gpu.return %ld : vector<1x1x32xf32>
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: store_scatter
+  // CHECK-COUNT-2: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_scatter(%src: ui64) {
+      %cst = arith.constant dense<[[
+      [0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248]
+      ]]> : vector<1x1x32xindex>
+
+      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+
+      %st_vec = arith.constant dense<1023.0>: vector<1x1x32xf32>
+      xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              layout_operand_2 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              layout_operand_3 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
+
+      gpu.return
+  }
+}

Copy link
Contributor

@akroviakov akroviakov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +735 to +737
// CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
// CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
// CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: canonicalize should remove these extracts and inserts. adding canonicalize may simplify your tests.

@nbpatel nbpatel merged commit 621ed04 into llvm:main Oct 24, 2025
10 checks passed
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
This PR changes the pack/unpack method used for unrolling to allow for
lower rank slice to be extracted and inserted from and to src vector by
adding reshapes. It also removes leading unit dims from inst_data if
there are any.
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
This PR changes the pack/unpack method used for unrolling to allow for
lower rank slice to be extracted and inserted from and to src vector by
adding reshapes. It also removes leading unit dims from inst_data if
there are any.
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
This PR changes the pack/unpack method used for unrolling to allow for
lower rank slice to be extracted and inserted from and to src vector by
adding reshapes. It also removes leading unit dims from inst_data if
there are any.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants