Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9d0341d
add pass
charithaintc Oct 21, 2025
76f7323
save work
charithaintc Oct 22, 2025
43c35be
add some tests
charithaintc Oct 22, 2025
cf23eaf
Merge branch 'main' into optimize_transpose
charithaintc Oct 22, 2025
f79d2a2
add some tests
charithaintc Oct 23, 2025
e9211c8
Merge branch 'main' into optimize_transpose
charithaintc Oct 23, 2025
ca5d902
save work
charithaintc Oct 24, 2025
35ca92b
working version
charithaintc Oct 25, 2025
9fcbe03
Merge branch 'main' into optimize_transpose
charithaintc Oct 27, 2025
44e6ac4
Merge branch 'main' into optimize_transpose
charithaintc Oct 28, 2025
17fd7c8
add tests
charithaintc Oct 28, 2025
cbcccf6
add comments
charithaintc Oct 29, 2025
b55f6b0
add comments
charithaintc Oct 29, 2025
f424297
Merge branch 'main' into optimize_transpose
charithaintc Oct 29, 2025
0508bde
Merge branch 'main' into optimize_transpose
charithaintc Oct 31, 2025
3d829c9
Merge branch 'main' into optimize_transpose
charithaintc Oct 31, 2025
51e84ab
use uArch
charithaintc Nov 3, 2025
bd92296
Merge branch 'main' into optimize_transpose
charithaintc Nov 3, 2025
9b694ad
change pass name
charithaintc Nov 3, 2025
4f00ec4
address comments
charithaintc Nov 3, 2025
60ec9f5
address comments
charithaintc Nov 3, 2025
c8590bb
address comments
charithaintc Nov 3, 2025
1af68c7
address comments
charithaintc Nov 3, 2025
22e25a9
remove unused headers
charithaintc Nov 3, 2025
0c96d3e
Merge branch 'main' into optimize_transpose
charithaintc Nov 3, 2025
f70c07c
fix comment
charithaintc Nov 3, 2025
d66e04e
Merge branch 'main' into optimize_transpose
charithaintc Nov 4, 2025
51f4c4b
fix comment
charithaintc Nov 4, 2025
25277bb
Merge branch 'main' into optimize_transpose
charithaintc Nov 4, 2025
0c9cee9
Merge branch 'main' into optimize_transpose
charithaintc Nov 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,16 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
"scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
}

def XeGPUOptimizeBlockLoads : Pass<"xegpu-optimize-block-loads"> {
  let summary = "Optimize XeGPU block load operations";
  let description = [{
    This pass rewrites XeGPU loadNd operations into more optimal forms
    to improve performance. This includes:
    - Rewriting transpose B loads into more optimal forms to use HW block
      transpose instructions for better performance.
  }];
  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                           "vector::VectorDialect"];
}

#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ struct UnrollOptions {

/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);

/// Appends patterns for optimizing block load operations into `patterns`.
void populateXeGPUOptimizeBlockLoadsPatterns(RewritePatternSet &patterns);
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
/// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> lhs,
ArrayRef<OpFoldResult> rhs);

/// Helper Function to find a proper instruction multiple for the user-supplied
/// sg-level data shape (given by `dim`). `candidates` are uArch allowed shapes.
/// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
/// array length).
template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {});

} // namespace xegpu

} // namespace mlir
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUWgToSgDistribute.cpp
XeGPUPropagateLayout.cpp
XeGPUVectorLinearize.cpp
XeGPUOptimizeBlockLoads.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
Expand Down
490 changes: 490 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp

Large diffs are not rendered by default.

36 changes: 7 additions & 29 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,28 +204,6 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
using Lattice::Lattice;
};

/// Helper Function to find a proper instruction multiple for the user-supplied
/// sg-level data shape. `candidates` are uArch allowed shapes.
/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
template <typename T>
int getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples = {}) {
static_assert(std::is_integral<T>::value, "T must be an integer type");
int largest = -1;
SmallVector<T> multiples = {1};
if (!candidateMultiples.empty())
multiples =
SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
for (T candidate : candidates) {
for (T multiple : multiples) {
int value = static_cast<int>(candidate * multiple);
if (value != 0 && dim % value == 0 && value > largest)
largest = value;
}
}
return largest;
}

/// Helper Functions to get default layouts. A `default layout` is a layout that
/// is assigned to a value when the layout is not fixed by some anchor operation
/// (like DPAS).
Expand Down Expand Up @@ -505,7 +483,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
prefetch.emitWarning("No known block params found for the element type.");
auto [bWidth, bHeight, bCount] = blockWHC.value();
SmallVector<int> instData;
int instWidth = getLargestDivisor(
int instWidth = xegpu::getLargestDivisor(
static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
bCount);
if (instWidth == -1)
Expand All @@ -514,7 +492,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
if (tdescTy.getRank() == 1)
instData = {instWidth};
else {
int instHeight = getLargestDivisor(
int instHeight = xegpu::getLargestDivisor(
static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
if (instHeight == -1)
prefetch.emitWarning(
Expand Down Expand Up @@ -634,15 +612,15 @@ void LayoutInfoPropagation::visitDpasOp(
const unsigned dataALen = aTy.getShape().front();
auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
const int maxALen =
getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
if (maxALen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");

const unsigned dataBLen = bTy.getShape().back();
auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
const int maxBLen =
getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
if (maxBLen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
Expand All @@ -662,7 +640,7 @@ void LayoutInfoPropagation::visitDpasOp(
const unsigned dataCLen = bTy.getShape().back();
auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
const int maxCLen =
getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
if (maxCLen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
Expand Down Expand Up @@ -691,7 +669,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
store.emitWarning("No known block params found for the element type.");
auto [bWidth, bHeight, bCount] = blockWHC.value();
SmallVector<int> instData;
int instWidth = getLargestDivisor(
int instWidth = xegpu::getLargestDivisor(
static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
bCount);
if (instWidth == -1)
Expand All @@ -700,7 +678,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
if (dataTy.getRank() == 1)
instData = {instWidth};
else {
int instHeight = getLargestDivisor(
int instHeight = xegpu::getLargestDivisor(
static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
if (instHeight == -1)
store.emitWarning(
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
results.append(addElementwise(builder, loc, a, b));
return results;
}

template <typename T>
int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
ArrayRef<T> candidateMultiples) {
static_assert(std::is_integral<T>::value, "T must be an integer type");
int largest = -1;
SmallVector<T> multiples = {1};
if (!candidateMultiples.empty())
multiples =
SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
for (T candidate : candidates) {
for (T multiple : multiples) {
int value = static_cast<int>(candidate * multiple);
if (value != 0 && dim % value == 0 && value > largest)
largest = value;
}
}
return largest;
}

/// Explicit instantiations
template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
ArrayRef<int> candidateMultiples);
template int
xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
ArrayRef<unsigned> candidateMultiples);
Loading
Loading