@akroviakov (Contributor) commented Oct 16, 2025

This PR is the first, minimally working, and somewhat crude application of the xegpu::uArch infrastructure in the uArch-sensitive parts of XeGPU. It removes XeGPUTargetInfo.h entirely and relies on the target attached to the GPU module instead. Because the design is still unsettled, I only consider pvc, and all of the new instructions provide a minimal interface.

This PR also adds default inst_data settings for store_nd, prefetch_nd, dpas, scatter, and gather. See propagate-layout-inst-data.mlir.

Some points to consider:

  • LLVM does not use C++ RTTI; its own dynamic polymorphism requires manually amending the types for dyn_cast to work. This becomes crucial if you want to check for or get a specific uArch. An example is requireTranspose() in XeGPUSubgroupDistribute.cpp, where we still have to compare hardcoded strings instead of using isa<>/dyn_cast<>.
  • Should uArch be exposed via a shared pointer to a constant structure or as a reference to a constant static structure? It depends on whether we allow for some fallback (the chip string is not present or no uArch is found for it, i.e., if(!uArch){...}) or strictly require a valid uArch (then an absent/invalid uArch is not even possible, llvm_unreachable in getUArch()). Generally, I lean towards trying a reference to a static constant for simplicity, but uArch has too few use cases so far to judge. The shared_ptr version remains for now as the most flexible option.
  • Type verification (XeGPUDialect.cpp, TensorDescType::verify) gets trickier, because there is no way to get the target attribute there. I'd be glad to hear some feedback on it. For now, I use a constexpr placeholder.
  • All anchor ops need uArch, at least to query the subgroup size for getDefaultSIMTLayoutInfo. inst_data is not part of getDefaultSIMTLayoutInfo, because it requires querying a specific operation in uArch.
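
To illustrate the first point, here is a minimal, self-contained sketch of the LLVM-style RTTI pattern (a kind tag plus a static classof hook). It is not code from this patch: UArchKind, UArchBase, isaSketch and dynCastSketch are hypothetical names, and the two templates are simplified stand-ins for llvm::isa<> and llvm::dyn_cast<> so that the example builds without LLVM headers.

```cpp
#include <cassert>

// Hypothetical kind tag; the uArch base class in this patch has no such field.
enum class UArchKind { PVC, BMG };

struct UArchBase {
  explicit UArchBase(UArchKind kind) : kind(kind) {}
  virtual ~UArchBase() = default;
  UArchKind getKind() const { return kind; }
  virtual int getSubgroupSize() const = 0;

private:
  const UArchKind kind;
};

struct PVC : UArchBase {
  PVC() : UArchBase(UArchKind::PVC) {}
  int getSubgroupSize() const override { return 16; }
  // The hook that LLVM-style casts call on the target type.
  static bool classof(const UArchBase *u) {
    return u->getKind() == UArchKind::PVC;
  }
};

struct BMG : UArchBase {
  BMG() : UArchBase(UArchKind::BMG) {}
  int getSubgroupSize() const override { return 16; }
  static bool classof(const UArchBase *u) {
    return u->getKind() == UArchKind::BMG;
  }
};

// Simplified stand-in for llvm::isa<> (pointer form only).
template <typename To, typename From> bool isaSketch(const From *from) {
  return To::classof(from);
}

// Simplified stand-in for llvm::dyn_cast<> (pointer form only).
template <typename To, typename From>
const To *dynCastSketch(const From *from) {
  return To::classof(from) ? static_cast<const To *>(from) : nullptr;
}
```

With something along these lines, a check such as the one in requireTranspose() could test the uArch type instead of comparing the chip string, assuming the real classes were given a kind tag and classof.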

I am eager to gather feedback both for the uArch API and usage in general, so feel free to ask and propose changes.

My impression is that uArch is a nice tool for our purposes, but from initial experience it currently appears a bit bloated (two std maps of shared pointers to instructions per instance of uArch for seemingly compile-time constant data) and tricky to use (e.g., no support for LLVM polymorphism). I may have missed some justification for this in the uArch PR, though, so it would be good to reiterate it here.
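
For comparison, the "static constant" alternative could look roughly like the sketch below. This is an assumption about one possible shape, not a proposal taken from the patch: UArchDesc, kKnownUArchs, findUArch and getUArchOrDie are hypothetical names, and only the numeric values (subgroup size 16, gather/scatter packing size 32) come from the Xe2Plus overrides in the diff.

```cpp
#include <cassert>
#include <string>

// Hypothetical flat descriptor: plain constants instead of per-instance maps
// of shared pointers.
struct UArchDesc {
  const char *name;
  int subgroupSize;
  int packedBitSizeGatherScatter;
};

// Static, immutable table. Values mirror the Xe2Plus overrides in this patch.
static const UArchDesc kKnownUArchs[] = {
    {"pvc", 16, 32},
    {"bmg", 16, 32},
};

// Fallback-friendly lookup: returns nullptr for an unknown chip string, so
// callers can write `if (!uArch) {...}`.
inline const UArchDesc *findUArch(const std::string &chip) {
  for (const UArchDesc &desc : kKnownUArchs)
    if (chip == desc.name)
      return &desc;
  return nullptr;
}

// Strict lookup: an unknown chip is treated as a programming error, so the
// caller always receives a valid reference.
inline const UArchDesc &getUArchOrDie(const std::string &chip) {
  const UArchDesc *desc = findUArch(chip);
  assert(desc && "no uArch registered for this chip");
  return *desc;
}
```

The two lookup functions correspond to the two policies discussed above: a nullable result when a fallback is allowed, and a reference when a valid uArch is strictly required.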

@llvmbot (Member) commented Oct 16, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Patch is 57.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163801.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+23-11)
  • (removed) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h (-30)
  • (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td (+6-1)
  • (modified) mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h (+74-4)
  • (modified) mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h (+15-2)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9-7)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+164-62)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+16-10)
  • (modified) mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir (+1-1)
  • (added) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+51)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+59-23)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..ec236d702de0d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,29 +379,41 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   );
 
   let builders = [
-    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
+                      "llvm::ArrayRef<int32_t>": $lane_layout,
                      "llvm::ArrayRef<int32_t>": $lane_data),
       [{
         auto sg_layout = DenseI32ArrayAttr();
         auto sg_data = DenseI32ArrayAttr();
-        auto inst_data = DenseI32ArrayAttr();
         auto order = DenseI32ArrayAttr();
-        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+        return $_get($_ctxt, sg_layout, sg_data,
+                     DenseI32ArrayAttr::get($_ctxt, inst_data),
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
       }]>,
     AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
-                     "llvm::ArrayRef<int32_t>": $lane_data,
-                     "llvm::ArrayRef<int32_t>": $order),
+                     "llvm::ArrayRef<int32_t>": $lane_data),
       [{
-        return $_get($_ctxt,
-                     /*sg_layout =*/ nullptr,
-                     /*sg_data   =*/ nullptr,
-                     /*inst_data =*/ nullptr,
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        auto order = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
-                     DenseI32ArrayAttr::get($_ctxt, lane_data),
-                     DenseI32ArrayAttr::get($_ctxt, order));
+                     DenseI32ArrayAttr::get($_ctxt, lane_data), order);
       }]>,
+    // AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+    //                  "llvm::ArrayRef<int32_t>": $lane_data,
+    //                  "llvm::ArrayRef<int32_t>": $order),
+    //   [{
+    //     return $_get($_ctxt,
+    //                  /*sg_layout =*/ nullptr,
+    //                  /*sg_data   =*/ nullptr,
+    //                  /*inst_data =*/ nullptr,
+    //                  DenseI32ArrayAttr::get($_ctxt, lane_layout),
+    //                  DenseI32ArrayAttr::get($_ctxt, lane_data),
+    //                  DenseI32ArrayAttr::get($_ctxt, order));
+    //   }]>,
     AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
                      "DenseI32ArrayAttr": $lane_data,
                      "DenseI32ArrayAttr": $order),
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
deleted file mode 100644
index 8aa9536cb67c1..0000000000000
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
+++ /dev/null
@@ -1,30 +0,0 @@
-//===- XeGPUTargetInfo.h - Target constants ---------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-#define MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-
-namespace mlir {
-namespace xegpu {
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-namespace targetinfo {
-constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes.
-constexpr unsigned packedSizeInBitsForDefault =
-    16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
-    32; // Minimum packing size per register for DPAS B.
-constexpr unsigned packedSizeInBitsForGatherScatter =
-    32; // Minimum packing size per register for Gather and Scatter ops.
-} // namespace targetinfo
-} // namespace xegpu
-} // namespace mlir
-
-#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..5ef1d499d618f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -43,7 +43,12 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
   let options = [Option<
     "printOnly", "print-analysis-only", "bool",
     /*default=*/"false",
-    "Print the result of layout propagation analysis and exit.">];
+    "Print the result of layout propagation analysis and exit.">,
+    Option<
+    "assumeUnrolled", "assume-unrolled", "bool",
+    /*default=*/"false",
+    "If the input IR has SG-sized tiles matching instruction sizes, omit `inst_data`.">
+  ];
 }
 
 def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b2e277d..5cb6d61336391 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -42,12 +42,59 @@ struct Xe2Plus : public uArch {
               &instrs = {})
       : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
         xeCore(xeCore) {}
+  int getSubgroupSize() const override { return 16; }
+  int getPackedFormatBitSizeGatherScatter() const override { return 32; }
+  int getPackedFormatBitSize() const override { return 16; }
+  std::optional<int> getPackedFormatBitSizeDpasB() const override { return 32; }
+};
+
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+struct StoreNdInstruction : public Instruction {
+  StoreNdInstruction()
+      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
+  // the specified pointer
+  llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct LoadNdInstruction : public Instruction {
+  LoadNdInstruction()
+      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
+  // the specified pointer.
+  llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct PrefetchNdInstruction : public Instruction {
+  PrefetchNdInstruction()
+      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+  llvm::SmallVector<int> getSortedLaneVectorLengths(int elementBitwidth) {
+    if (elementBitwidth == 8 || elementBitwidth == 16)
+      return {1, 2, 4, 8, 16};
+    else if (elementBitwidth == 32 || elementBitwidth == 64)
+      return {1, 2, 4, 8};
+    else
+      llvm_unreachable(
+          "Unsupported element bitwidth for PrefetchNdInstruction");
+  }
 };
 
-// struct to represent DPAS instruction
 struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   DPASInstruction()
       : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+  // Source:
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
 
   // Override all virtuals from MatrixOpInterface
   virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -72,6 +119,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
 };
 
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
 struct PVCuArch : public Xe2Plus {
   // Maintaines ownership of the instructions owned by PVUarch
   llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
@@ -101,9 +151,15 @@ struct PVCuArch : public Xe2Plus {
         CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
 
     // Add the instructions-
-    auto dpas = std::make_shared<DPASInstruction>();
-    instructions.emplace(dpas->getInstructionKind(), dpas);
-    owned_instructions.push_back(dpas);
+    llvm::SmallVector<std::shared_ptr<Instruction>> instructionsToAdd{
+        std::make_shared<DPASInstruction>(),
+        std::make_shared<StoreNdInstruction>(),
+        std::make_shared<LoadNdInstruction>(),
+        std::make_shared<PrefetchNdInstruction>()};
+    for (auto &inst : instructionsToAdd) {
+      instructions.emplace(inst->getInstructionKind(), inst);
+      owned_instructions.push_back(inst);
+    }
   }
 };
 
@@ -139,10 +195,24 @@ struct BMGuArch : public Xe2Plus {
     owned_instructions.push_back(dpas);
   }
 };
+
+inline std::shared_ptr<uArch> getUArch(const std::string &archName) {
+  if (archName == "pvc")
+    return std::make_shared<PVCuArch>();
+  else if (archName == "bmg")
+    return std::make_shared<BMGuArch>();
+  else
+    return nullptr;
+}
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
 inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
 DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
   auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994ea5ecf5..0f5b1282f0e24 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
 // An enum class to represent the scope of an instruction
 enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
-  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
-        // multiply-add operation
+  DPAS,       // Dot Product Accumulate Systolic (DPAS) is a matrix
+              // multiply-add operation
+  STORE_ND,   // Subgroup-level 2D block write instruction
+  LOAD_ND,    // Subgroup-level 2D block load instruction
+  PREFETCH_ND // Subgroup-level 2D block prefetch instruction
   // @TODO: Add more instructions as needed
 };
 
@@ -148,6 +151,16 @@ struct uArch {
 
   const std::string &getDescription() const { return description; }
 
+  virtual int getSubgroupSize() const = 0;
+  virtual int getPackedFormatBitSizeGatherScatter() const = 0;
+  virtual int getPackedFormatBitSize() const = 0;
+  virtual std::optional<int> getPackedFormatBitSizeDpasB() const = 0;
+
+  std::shared_ptr<Instruction> getInstruction(InstructionKind instKind) const {
+    assert(instructions.find(instKind) != instructions.end());
+    return instructions.at(instKind);
+  }
+
   const std::map<RegisterFileType, RegisterFileInfo> &
   getRegisterFileInfo() const {
     return registerFileInfo;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d517473..afda04fa71105 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -11,7 +11,7 @@
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -226,8 +226,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   }
 
   if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
-    return emitError()
-           << "expected inst_data and lane_layout to have the same rank";
+    return emitError() << "expected inst_data and lane_layout to have the same "
+                          "rank, got inst_data "
+                       << inst_data.size() << ", lane_layout "
+                       << lane_layout.size();
   }
 
   // sg_data is optional for Workgroup layout, but its presence requires
@@ -565,10 +567,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 
   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  int chunkAlignmentFactor =
-      bitWidth < targetinfo::packedSizeInBitsForGatherScatter
-          ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
-          : 1;
+  constexpr int packingBitSizeGatherScatter{32};
+  int chunkAlignmentFactor = bitWidth < packingBitSizeGatherScatter
+                                 ? packingBitSizeGatherScatter / bitWidth
+                                 : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
     int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8fab255d6347f..9c09908f3547d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -37,6 +36,8 @@
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+
 namespace mlir {
 namespace xegpu {
 #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
@@ -104,6 +105,8 @@ struct LayoutInfo {
 
   SmallVector<int> getLaneData() const;
 
+  SmallVector<int> getInstData() const;
+
   bool isSliceLayout() const {
     if (!isAssigned())
       return false;
@@ -137,6 +140,13 @@ SmallVector<int> LayoutInfo::getLaneData() const {
                              [](int64_t val) { return static_cast<int>(val); });
 }
 
+SmallVector<int> LayoutInfo::getInstData() const {
+  if (!isAssigned())
+    return {};
+  return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
+                             [](int64_t val) { return static_cast<int>(val); });
+}
+
 void LayoutInfo::print(raw_ostream &os) const {
   if (isAssigned()) {
     os << storage;
@@ -174,12 +184,14 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
 
   SmallVector<int32_t> laneLayout;
   SmallVector<int32_t> laneData;
+  SmallVector<int32_t> instData;
   for (int64_t idx : permutation) {
     laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
     laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+    instData.push_back(static_cast<int32_t>(getInstData()[idx]));
   }
-  return LayoutInfo(
-      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
+  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
+                                           laneLayout, laneData));
 }
 
 //===----------------------------------------------------------------------===//
@@ -199,20 +211,33 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// Helper Function to get the default layout for uniform values like constants.
 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
-                                           unsigned rank) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank,
+                         std::shared_ptr<xegpu::uArch::uArch> &uArch,
+                         ArrayRef<int> instData) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1) {
     return LayoutInfo(
-        xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+        xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
   }
   return LayoutInfo(xegpu::LayoutAttr::get(
-      ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
+      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+}
+
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+                                           unsigned rank, int subgroupSize) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank == 1) {
+    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
+  }
+  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           bool isScattered = false) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(VectorType vectorTy,
+                         std::shared_ptr<xegpu::uArch::uArch> &uArch,
+                         ArrayRef<int> instData, bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -221,29 +246,31 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   if (isScattered) {
     packingFactor =
-        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
-            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+        bitwidth < uArch->getPackedFormatBitSizeGatherScatter()
+            ? uArch->getPackedFormatBitSizeGatherScatter() / bitwidth
             : 1;
-    return LayoutInfo(xegpu::LayoutAttr::get(
-        vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
-        {1, packingFactor}));
+    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+                                             {uArch->getSubgroupSize(), 1},
+                                             {1, packingFactor}));
   }
-  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
-    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {1, xegpu::targetinfo::subgroupSize},
+  if (bitwidth < uArch->getPackedFormatBitSize())
+    packingFactor = uArch->getPackedFormatBitSize() / bitwidth;
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+                                           {1, uArch->getSubgroupSize()},
                                            {1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo get...
[truncated]

@akroviakov akroviakov requested a review from mshahneo October 16, 2025 14:56
@akroviakov akroviakov force-pushed the akroviak/xegpu-enhance-layout-prop branch from fa08396 to 63d11b8 Compare October 16, 2025 16:23
@akroviakov akroviakov force-pushed the akroviak/xegpu-enhance-layout-prop branch from 63d11b8 to 4b99cdd Compare October 16, 2025 17:00
DenseI32ArrayAttr::get($_ctxt, order));
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
}]>,
// AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,

clean up?

Contributor Author

I think so. The constructor was unused and its signature conflicted with the new inst_data one, so it can be removed altogether until we need order somewhere.

auto dpas = std::make_shared<DPASInstruction>();
instructions.emplace(dpas->getInstructionKind(), dpas);
owned_instructions.push_back(dpas);
llvm::SmallVector<std::shared_ptr<Instruction>> instructionsToAdd{

nit - formatting

@@ -42,12 +40,61 @@ struct Xe2Plus : public uArch {
&instrs = {})
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
xeCore(xeCore) {}
int getSubgroupSize() const override { return 16; }
unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
unsigned getPackedFormatBitSize() const override { return 16; }

Contributor

is getPackedFormatBitSize really getPackedFormatBitSizeDpasA?

Contributor Author

Yes, it will be renamed. I also think it should be a member of the dpas instruction, per uArch instance. We might want to split this PR into two parts to have a substantial discussion in each: (1) the uArch modification and (2) the propagation option and the uArch application in passes.

Contributor Author

For C, is it the same as B (32)?

Contributor

For C, the result is f32 so no packing needed.

Contributor Author

For a generic lane_data calculation for dpas operands, wouldn't the following form be desirable in the dpas propagation?
packingFactor = dpasInst->getOperand{A/B/C}PackingBitSize() / dataElemBitwidth

It is not so much about whether we actually consider "packing" C.
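
A minimal sketch of that uniform calculation follows. The helper name packingFactor is hypothetical; the packing sizes used in the usage note for A (16) and B (32) come from the removed XeGPUTargetInfo.h, while treating C as a 32-bit unit is only an assumption made here so that an f32 result yields a factor of 1.

```cpp
#include <cassert>

// Number of elements that fit into one packing unit of `packedBitSize` bits.
// Returns 1 when the element already fills (or exceeds) the unit, which is the
// same rule the patch applies for gather/scatter and for the default case.
inline int packingFactor(int packedBitSize, int elemBitWidth) {
  assert(packedBitSize > 0 && elemBitWidth > 0 && "bit sizes must be positive");
  return elemBitWidth < packedBitSize ? packedBitSize / elemBitWidth : 1;
}
```

Under these assumptions, an 8-bit A operand gives packingFactor(16, 8) = 2, a 16-bit B operand gives packingFactor(32, 16) = 2, and an f32 C operand gives packingFactor(32, 32) = 1, so all three operands go through the same code path even though C is never really packed.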

"Print the result of layout propagation analysis and exit.">];
"Print the result of layout propagation analysis and exit.">,
Option<
"assumeUnrolled", "assume-unrolled", "bool",

Contributor

Can this option be an enumeration, so that the propagation can be applied to the "lane", "inst", and "subgroup" parameters? A higher level implies that the lower levels are propagated as well, so "assumeUnrolled = true" can be replaced with "level = lane" here, and the option becomes more extensible.

Contributor Author

Not an enum, but a string.
For subgroup, don't we need a user layout on the anchor ops in order to propagate? It's not like the lane/inst fields, which are tightly coupled to the hw subgroup size and/or instruction size.
Anyway, this is a topic for a different PR. For now, we can do lane and inst.

Contributor

At this point, yes, we expect the user to set sg_layout/sg_data.
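
One way to read the "level" idea from this thread is sketched below. It is only an illustration under stated assumptions: PropagationLevel, parsePropagationLevel and the three predicates are hypothetical names, and the string values "lane", "inst" and "subgroup" are the ones mentioned in the comments above, not an option that exists in the patch.

```cpp
#include <cassert>
#include <string>

// Hypothetical propagation levels; none of these names exist in the patch.
enum class PropagationLevel { Lane, Inst, Subgroup, Invalid };

// Maps the string value of a pass option to a level.
inline PropagationLevel parsePropagationLevel(const std::string &value) {
  if (value == "lane")
    return PropagationLevel::Lane;
  if (value == "inst")
    return PropagationLevel::Inst;
  if (value == "subgroup")
    return PropagationLevel::Subgroup;
  return PropagationLevel::Invalid;
}

// A higher level implies that every lower level is propagated as well.
inline bool propagatesLaneFields(PropagationLevel level) {
  return level != PropagationLevel::Invalid;
}

inline bool propagatesInstData(PropagationLevel level) {
  return level == PropagationLevel::Inst ||
         level == PropagationLevel::Subgroup;
}

inline bool propagatesSubgroupFields(PropagationLevel level) {
  return level == PropagationLevel::Subgroup;
}
```

In this reading, "level = lane" plays the role of "assumeUnrolled = true" (no inst_data is attached), and "level = inst" corresponds to the default behaviour of this PR.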

uArchInstruction->getSupportedM(aTy.getElementType()).back();
const int maxBLen =
uArchInstruction->getSupportedK(bTy.getElementType()).back();
SmallVector<int> instDataA = {maxALen, subgroupSize};

Contributor

Should maxALen and maxBLen be compared against the input operands' sizes, since those sizes need to be multiples of inst_data? The same applies in other places.

Contributor Author

Will add a check.

For the future, should this happen as a verification of the user shape (effectively sg_data), or as part of the max*Len selection?
A user-supplied shape can have a dimension of 12 for an instruction that supports sizes [1,2,4,8]. Using the maximum size 8 fails, but using 4 succeeds (4+4+4). We might also do 8+4, but then I suppose the inst_data step would need to be fused with blocking and not surface in the pre-blocking IR. Were there any plans in this direction, or do we only work with multiples of the max size in the foreseeable future?

Contributor

I don't see a need to support 8+4. For a user input of 12, we can't use 8, the maximum inst_data value, as the current logic does. Instead, we should use 4.

Regarding "do we only work with multiples of the max size in the foreseeable future": I don't think it is worth supporting a complex scheme like 8+4. We can stick to the largest supported size among all divisors of the user-provided shape.
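
The selection rule agreed on here (largest supported size that divides the user-provided dimension) can be sketched as follows. The function name largestDividingSize and the -1 sentinel are assumptions for illustration; the input list is expected to be sorted in ascending order, like the result of getSortedLaneVectorLengths() in the patch.

```cpp
#include <cassert>
#include <vector>

// Returns the largest entry of `sortedSizes` (ascending order) that evenly
// divides `dim`, or -1 if no supported size divides it.
inline int largestDividingSize(int dim, const std::vector<int> &sortedSizes) {
  assert(dim > 0 && "dimension must be positive");
  // Walk from the largest supported size down to the smallest.
  for (auto it = sortedSizes.rbegin(); it != sortedSizes.rend(); ++it)
    if (*it > 0 && dim % *it == 0)
      return *it;
  return -1;
}
```

For the example from this thread, a dimension of 12 with supported sizes [1,2,4,8] selects 4, whereas a dimension of 16 still selects the maximum, 8.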

4 participants