
Conversation

@dchigarev
Contributor

@dchigarev dchigarev commented Oct 14, 2025

As suggested here, the PR adds an optional layout attribute for the LoadGather and StoreScatter ops.

For the load-op, the attribute describes the layout of the result (e.g. layout_result_0), and for the store-op it describes the layout of the vector-to-store operand (e.g. layout_operand_0).
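
A minimal sketch of what this looks like in the IR (shapes and layout values here are illustrative, not taken from the patch): the new attribute is an inherent attribute printed inside `<{...}>`, unlike the discardable `layout_result_0`/`layout_operand_0` names:

```mlir
// Illustrative example; the concrete shapes and layouts are made up.
%v = xegpu.load %src[%offsets], %mask
       <{chunk_size = 8 : i64,
         layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
xegpu.store %v, %src[%offsets], %mask
       <{chunk_size = 8 : i64,
         layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
     : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
```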

The PR also reworks the propagate-layout pass to consider permanent layout attributes and back-propagate them accordingly.

The helper utility function getDistributeLayoutAttr is reworked to return either layout_operand/result_0 or layout for load/store ops (depending on which one is set). After an offline discussion, we decided that the overall layout utilities API is confusing since it tries to mix permanent and temporary layouts; it will need to be changed in the future.

@dchigarev
Contributor Author

dchigarev commented Oct 14, 2025

Some tests are failing after rebasing on the main branch. Fixing those...
fixed

%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
%load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
%load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
Contributor


I don't think you need to touch these tests.
These tests assume the temporary layout attributes are being assigned, so they are not related to the permanent layout.

Contributor Author


Left the existing tests untouched; only added new ones to check that the permanent layouts are handled properly.

@dchigarev dchigarev force-pushed the dchigarev/load_layout branch 2 times, most recently from feb4def to b58eb10 Compare October 23, 2025 08:54
@dchigarev dchigarev force-pushed the dchigarev/load_layout branch from 1ed9416 to 5fd0ba1 Compare October 25, 2025 14:10
/*l2_hint=*/xegpu::CachePolicyAttr{},
/*l3_hint=*/xegpu::CachePolicyAttr{});
/*l3_hint=*/xegpu::CachePolicyAttr{},
/*layout=*/nullptr);
Contributor Author


#163071 to fill the gap

@dchigarev dchigarev marked this pull request as ready for review October 26, 2025 13:02
@dchigarev
Contributor Author

dchigarev commented Oct 26, 2025

cc @charithaintc @Jianhui-Li

@llvmbot
Member

llvmbot commented Oct 26, 2025

@llvm/pr-subscribers-mlir

Author: Dmitry Chigarev (dchigarev)

Changes

As suggested here, the PR adds an optional layout attribute for the LoadGather and StoreScatter ops.

For the load-op, the attribute describes the layout of the result (e.g. layout_result_0), and for the store-op it describes the layout of the vector-to-store operand (e.g. layout_operand_0).

The PR also reworks the xegpu-unroll and wg-to-sg-distribute passes to consider and carry the new layout attribute.

The helper utility function getDistributeLayoutAttr is reworked to return either layout_operand/result_0 or layout for load/store ops (depending on which one is set) to preserve backward compatibility. The setDistributeLayoutAttr function can now only set the new layout attr.
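
As a rough sketch (not verbatim from the patch), a rewrite pattern that recreates one of these ops can now forward the permanent layout through the new builder overload; here `rewriter`, `loc`, `newValueTy`, `offsets`, `mask`, and `chunkSize` are placeholders for values the surrounding pattern would already have:

```cpp
// Sketch only: assumes the usual MLIR rewrite-pattern context around it,
// and the new LoadGatherOp builder overload that accepts a layout attribute.
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); // may be null if unset
auto newOp = xegpu::LoadGatherOp::create(
    rewriter, loc, newValueTy, op.getSource(), offsets, mask,
    /*chunk_size=*/rewriter.getI64IntegerAttr(chunkSize),
    op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
    /*layout=*/layout);
```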


Patch is 47.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163414.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+20-4)
  • (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+8-4)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+37-4)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+10-2)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+4-4)
  • (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+24-2)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+5-5)
  • (modified) mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir (+81)
  • (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+63)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+125)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir (+45)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+63-2)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 426377fcf598f..4c67856b559b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
   let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -895,7 +896,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
@@ -979,7 +987,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
@@ -1030,7 +1039,14 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..b5d9323de47a6 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -435,7 +435,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
       /*chunk_size=*/IntegerAttr{},
       /*l1_hint=*/xegpu::CachePolicyAttr{},
       /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+      /*l3_hint=*/xegpu::CachePolicyAttr{},
+      /*layout=*/nullptr);
 
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
@@ -469,7 +470,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                 /*chunk_size=*/IntegerAttr{},
                                 /*l1_hint=*/xegpu::CachePolicyAttr{},
                                 /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+                                /*l3_hint=*/xegpu::CachePolicyAttr{},
+                                /*layout=*/nullptr);
   rewriter.eraseOp(writeOp);
   return success();
 }
@@ -621,7 +623,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
         /*chunk_size=*/IntegerAttr{},
         /*l1_hint=*/xegpu::CachePolicyAttr{},
         /*l2_hint=*/xegpu::CachePolicyAttr{},
-        /*l3_hint=*/xegpu::CachePolicyAttr{});
+        /*l3_hint=*/xegpu::CachePolicyAttr{},
+        /*layout=*/nullptr);
 
     auto selectOp =
         arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -655,7 +658,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
                                   /*chunk_size=*/IntegerAttr{},
                                   /*l1_hint=*/xegpu::CachePolicyAttr{},
                                   /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
+                                  /*layout=*/nullptr);
     rewriter.eraseOp(scatterOp);
     return success();
   }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..2a7c7ac7e8cde 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -859,7 +859,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
-        l1_hint, l2_hint, l3_hint);
+        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -875,7 +875,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
   auto offset = vector::FromElementsOp::create(builder, loc, type, values);
 
   build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint,
+                         DistributeLayoutAttr layout) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
@@ -926,7 +943,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            xegpu::CachePolicyAttr l2_hint,
                            xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -944,7 +961,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
 
   // Call the correct builder overload that does not expect result types.
   build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
-        l3_hint);
+        l3_hint, /*layout=*/nullptr);
+}
+
+void StoreScatterOp::build(
+    OpBuilder &builder, OperationState &state, Value value, Value dest,
+    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the correct builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index e6e71cc29a80a..c3bf9606693a8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -678,12 +678,16 @@ struct UnrollLoadGatherOpWithOffset
           pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
     }
 
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    if (layout)
+      layout = layout.dropInstData();
+
     SmallVector<Value> newOps;
     for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
       auto newOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newValueTy, op.getSource(), o, m,
           rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
       newOps.push_back(newOp);
     }
 
@@ -774,12 +778,16 @@ struct UnrollStoreScatterOpWithOffsets
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    if (layout)
+      layout = layout.dropInstData();
+
     for (auto [v, o, m] :
          llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
       xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                     rewriter.getI64IntegerAttr(chunkSize),
                                     op.getL1HintAttr(), op.getL2HintAttr(),
-                                    op.getL3HintAttr());
+                                    op.getL3HintAttr(), layout);
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..ceeafbd7bd5e5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -914,9 +914,8 @@ struct WgToSgLoadGatherOpWithOffset
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
-      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          layout.dropSgLayoutAndData());
       newLoadOps.push_back(newLoadOp);
     }
     rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -964,7 +963,8 @@ struct WgToSgStoreScatterOpWithOffset
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
       auto store = xegpu::StoreScatterOp::create(
           rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          layout.dropSgLayoutAndData());
       // Update the layout attribute to drop sg_layout and sg_data.
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty()) {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b4605cd7e94d6..6e918f162a5ea 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -105,12 +105,22 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
 std::string xegpu::getLayoutName(const OpOperand &operand) {
   const StringRef prefix("layout_operand_");
   unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-  return llvm::formatv("{0}{1}", prefix, idx).str();
+  auto owner = operand.getOwner();
+  auto tempLayout = llvm::formatv("{0}{1}", prefix, idx).str();
+  if (isa<StoreScatterOp>(operand.getOwner()) && idx == 0 &&
+      !owner->hasAttr(tempLayout))
+    return "layout";
+  return tempLayout;
 }
 
 std::string xegpu::getLayoutName(const OpResult result) {
   const StringRef prefix = "layout_result_";
-  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+  auto owner = result.getOwner();
+  auto tempLayout =
+      llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+  if (isa<LoadGatherOp>(owner) && !owner->hasAttr(tempLayout))
+    return "layout";
+  return tempLayout;
 }
 
 xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
@@ -144,6 +154,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
       return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+    // check for "permament" layout only after "temporary" layout name lookup
+    // for backward compatibility
+    if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
+      return loadGatherOp.getLayoutAttr();
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -171,6 +186,13 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
     return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+  // check for "permament" layout only after "temporary" layout name lookup
+  // for backward compatibility
+  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+    if (auto layout = storeScatterOp.getLayoutAttr())
+      return layout;
+
   return getDistributeLayoutAttr(opr.get());
 }
 
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 30f785ded975a..3e2644beffc35 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -97,7 +97,7 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -122,7 +122,7 @@ func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
 func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -167,8 +167,8 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64,
+// CHECK-SAME: layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
@@ -186,7 +186,7 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index f233dff609f2b..e602163ff5aaa 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -462,6 +462,47 @@ gpu.module @xevm_module{
   }
 }
 
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize_perm_layout({{.*}}) {
+// CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
+// CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME:    -> (vector<1x8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
+// CHECK:         gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]] :
+// CHECK-SAME:      vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[T1:.*]] = xegpu.load %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
+// CHECK-SAME:    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT:  xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2],...
[truncated]

@llvmbot
Member

llvmbot commented Oct 26, 2025

@llvm/pr-subscribers-mlir-gpu


@dchigarev dchigarev requested a review from Jianhui-Li October 26, 2025 13:03
Contributor

@Jianhui-Li Jianhui-Li left a comment


Mostly looks OK now.
We need to agree on the setDistributeLayoutAttr behavior. @charithaintc


// check for "permament" layout only after "temporary" layout name lookup
// for backward compatibility
if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
Contributor


not sure why there is only storescatter, no loadgather? Note that there are both load_matrix and store_matrix.

Contributor


I guess we should handle the anchor ops first, so all of the following ops should be processed here first.

Load/Store matrix
load/store nd
loadgather/storescatter

I also think store ops can be omitted (but it's fine to have them in the code), because they don't return anything, so they can never be used as an OpOperand.

Contributor Author

@dchigarev dchigarev Nov 3, 2025


not sure why there is only storescatter, no loadgather? Note that there are both load_matrix and store_matrix.

This function handles OpOperands, while the permanent layout attr for the load-op describes an OpResult, so there is no reason to access that layout here.


// check for "permament" layout only after "temporary" layout name lookup
// for backward compatibility
if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
Contributor


no store_matrix here?

Contributor


I think since the store scatter does not return a result it is not needed here. We only need load gather.

@charithaintc charithaintc self-requested a review October 30, 2025 17:33
Contributor

@charithaintc charithaintc left a comment


I think there is some confusion about how the layout utils work. Maybe it's worth having a meeting on this?

Here is a summary of my understanding.

Before layout propagation: all the anchor ops must have valid layouts (permanent). This is the only requirement.
After layout propagation: all anchor ops have layouts (permanent), and in addition all vector op results should have valid layouts in the layout_result_{} attribute notation.

Before any pass that consumes layouts: layouts are forward-propagated to op operands from the defining results. After this, all ops will have layout_operand_{} attributes as well, in addition to the above. (This is done in subgroup distribution already, but not in other passes, so be careful when writing/modifying tests.)

This whole thing has some redundancy that could be optimized, but this is the current status for now.
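
To make the stages concrete, here is a small hypothetical IR fragment (layouts invented for the example, not taken from the patch) showing the state after propagation: the anchor op carries its permanent layout, results get the temporary layout_result_{} names, and the layout_operand_{} names are added just before a consuming pass:

```mlir
// Hypothetical example; layouts are illustrative only.
%v = xegpu.load %src[%off], %m
       <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>        // permanent (anchor op)
       {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} // temporary, added by propagation
     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
%e = math.exp %v {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
     : vector<16xf16>
```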

OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
OptionalAttr<DistributeLayoutAttr>:$layout);
Contributor


Nit: ideally this needs to be xegpu::LayoutAttr. DistributeLayoutAttr is a derived layout interface because it could be either a pure LayoutAttr or a SliceAttr.

Contributor


+1

Contributor Author


Fixed.

BTW, Load/StoreMatrixOps take DistributeLayoutAttr; I believe we should adjust them as well eventually.


// check for "permament" layout only after "temporary" layout name lookup
// for backward compatibility
if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
Contributor


I think since the store scatter does not return a result it is not needed here. We only need load gather.

@dchigarev
Contributor Author

@charithaintc @Jianhui-Li

Updated the PR according to our discussion:

  1. The permanent layout attr never replaces the temporary attr, but only works as an addition:
xegpu.load ... <{layout = ...}> {layout_result_0 = ...}
  2. The propagate-layout pass back-propagates the permanent layout, e.g.:
//////// Case 1 (xegpu.load doesn't have a permanent layout)
// input:
%3 = xegpu.load ...
xegpu.store %3 ... <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}>

// --propagate-layout:
%3 = xegpu.load ... {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>}
xegpu.store %3 ... <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}>

//////// Case 2 (xegpu.load has a permanent layout)
// input:
%3 = xegpu.load ... <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
xegpu.store %3 ... <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}>

// --propagate-layout:
// the pass doesn't overwrite the permanent layout
%3 = xegpu.load ... <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> {
    layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
xegpu.store %3 ... <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}>
  3. All the passes that work with layouts (wg-to-sg, xegpu-unroll, ...) use the temp layout (see the sketch below).
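
For example, a hedged sketch using the utility signatures visible in the diff above (`loadOp` is a hypothetical xegpu::LoadGatherOp handled by a pass): a consuming pass keeps querying layouts through the helpers, which now fall back to the permanent attribute when no temporary one is present.

```cpp
// Sketch only; `loadOp` is a hypothetical xegpu::LoadGatherOp seen by a pass.
// Returns the layout_result_0 attribute if present, otherwise the op's
// permanent `layout` attribute (per the XeGPUUtils.cpp changes above).
xegpu::DistributeLayoutAttr resLayout =
    xegpu::getDistributeLayoutAttr(loadOp.getResult());
if (resLayout) {
  // ... use the layout for distribution/unrolling decisions ...
}
```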

@dchigarev dchigarev requested a review from Jianhui-Li November 3, 2025 14:37
Contributor

@Jianhui-Li Jianhui-Li left a comment


LGTM.
I think setDistributeLayoutAttr needs some refactoring so that the distribution and blocking passes use it the same way, but that will be in a future PR.

OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
OptionalAttr<DistributeLayoutAttr>:$layout);
Contributor


+1

@dchigarev
Contributor Author

@Jianhui-Li @charithaintc

may we merge this one if there are no other comments?

@charithaintc
Contributor

@Jianhui-Li @charithaintc

may we merge this one if there are no other comments?

I am ok with it.

@charithaintc charithaintc merged commit 6c563dc into llvm:main Nov 4, 2025
10 checks passed
charithaintc pushed a commit that referenced this pull request Nov 17, 2025
… gather/scatter ops (#167850)

The PR changes the layout attribute type for
`xegpu::LoadGatherOp/StoreScatterOp` from `LayoutAttr` to
`DistributeLayoutAttr` to also support `xegpu.slice` layouts.

Initially we [wanted to restrict slice
layouts](#163414 (comment))
from the attribute, but now it turns out there are actually valid use
cases for that:
```mlir
gpu.func @distribute_load_slice_attr() {
  %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
  %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex>
  %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1>

  %3 = xegpu.load %2[%offset], %mask <{chunk_size = 1, layout = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]>>} {
      layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> 
  } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>

  %4 = vector.broadcast %3 {layout_result_0 =
      #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
  gpu.return
}
```

Signed-off-by: dchigarev <[email protected]>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Nov 17, 2025
…tr for load gather/scatter ops (#167850)

nekoshirro pushed a commit to nekoshirro/Alchemist-LLVM that referenced this pull request Nov 24, 2025
… gather/scatter ops (#167850)

aadeshps-mcw pushed a commit to aadeshps-mcw/llvm-project that referenced this pull request Nov 26, 2025
… gather/scatter ops (llvm#167850)
