Commits
dab6841
[MLIR] Create GPU utils library & move distribution utils
kurapov-peter Dec 9, 2024
fe745c6
Merge remote-tracking branch 'petr_llvm/distribution-utils' into xegp…
charithaintc Dec 10, 2024
f6cd50a
pass added
charithaintc Dec 12, 2024
1c06920
fix
charithaintc Dec 12, 2024
9888c84
fix
charithaintc Dec 12, 2024
491625d
fix
charithaintc Dec 12, 2024
07f9f9f
fix
charithaintc Dec 12, 2024
b842f33
fix
charithaintc Dec 12, 2024
69cbc3b
fix
charithaintc Dec 12, 2024
b7cb16f
Merge branch 'main' into xegpu-distribution-charitha
charithaintc Dec 13, 2024
e7ca3cd
fix
charithaintc Dec 13, 2024
8234edd
Merge branch 'main' into xegpu-distribution-charitha
charithaintc Dec 13, 2024
2f4b748
Merge branch 'main' into xegpu-distribution-charitha
charithaintc Jan 30, 2025
b443c71
sync
charithaintc Jan 30, 2025
6f11f3c
fix comments
charithaintc Jan 31, 2025
36c5b46
Merge branch 'main' into xegpu-distribution-charitha
charithaintc Jan 31, 2025
7d1c7a6
add mem side effects interface
charithaintc Jan 31, 2025
eb7ee36
Merge branch 'xegpu-mem-effects' into xegpu-distribution-charitha
charithaintc Jan 31, 2025
263d72d
add mem side effects interface
charithaintc Jan 31, 2025
1b0bba7
add mem side effects interface
charithaintc Feb 3, 2025
91fa249
Merge branch 'main' into xegpu-mem-effects
charithaintc Feb 3, 2025
38ee43c
Merge branch 'xegpu-mem-effects' into xegpu-distribution-charitha
charithaintc Feb 3, 2025
ae2a3fe
Merge remote-tracking branch 'origin/main' into xegpu-distribution-ch…
charithaintc Feb 3, 2025
2d664e8
fix issues
charithaintc Feb 4, 2025
615f22d
Merge branch 'main' into xegpu-distribution-charitha
charithaintc Feb 4, 2025
983dd4d
Merge branch 'main' into xegpu-distribution-charitha
charithaintc Feb 5, 2025
4afbff9
fix comments
charithaintc Feb 5, 2025
48fc6d5
fix
charithaintc Feb 5, 2025
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,7 @@ namespace xegpu {

/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
void populateXeGPUDistributePatterns(RewritePatternSet &patterns);

} // namespace xegpu
} // namespace mlir
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
XeGPUDistribute.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
@@ -12,6 +13,10 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
MLIRIR
MLIRMemRefDialect
MLIRXeGPUDialect
MLIRVectorDialect
MLIRVectorUtils
MLIRArithDialect
MLIRFuncDialect
MLIRPass
MLIRTransforms
)
364 changes: 364 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUDistribute.cpp
@@ -0,0 +1,364 @@
//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu-distribute"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;

namespace {
bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
Jianhui-Li (Contributor), Feb 5, 2025:
How about changing the file name to XeGPUSubgroupDistribute.cpp, to be more explicit? We also have a notion of "workgroup distribute".

charithaintc (Author):
Everything under our control is changed to Subgroup, along with class names, pass names, and test cases.


/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later with dce). The yield op will
/// bypass the create_nd_tdesc's arguments.
/// The rewrite will create a subview of the size used by a single work item
/// and an appropriate offset. The distributed create_nd_tdesc points into
/// the subview without offset. The tensor descriptor type is distributed
/// according to the sg_map attribute.
///
/// Example:
///
/// ```
/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
/// (!xegpu.tensor_desc<4x8xf32>) {
/// ...
/// %td = xegpu.create_nd_tdesc %arg0[0, 0]
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
/// vector.yield %td
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// %dead = xegpu.create_nd_tdesc %arg0[0, 0]
/// : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
/// vector.yield %arg0, %dead
/// }
/// %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
/// : memref<4x8xf32> to memref<4x1xf32>
/// %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
/// -> !xegpu.tensor_desc<4x1xf32>
Contributor:
The comments need to be changed as well, as we don't need memref.subview.

charithaintc (Author):
Fixed now.

///
/// ```
struct WarpOpTensorDescOp final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override;
};

/// Sink a store_nd feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0`. In case arguments for the store are passed
/// through the warp op interface, they would be propagated as returned
/// values. Both the stored vector type and the tensor descriptor type are
/// distributed according to the sg_map attribute.
///
/// Example:
///
/// ```
/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
/// gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
/// !xegpu.tensor_desc<4x8xf32>
/// vector.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// vector.yield
/// }
/// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>,
/// !xegpu.tensor_desc<4x1xf32>
///
/// ```
struct WarpOpStoreNd final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override;
};

/// Clone a load_nd feeding into vector.yield op for the enclosing
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
/// The warp op will still contain the original op that will not be used by
/// the yield op (and should be cleaned up later with dce). The yield op will
/// bypass the load's arguments. Both the loaded vector type and the tensor
/// descriptor type are distributed according to the sg_map attribute.
///
/// Example:
///
/// ```
/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
/// (!xegpu.tensor_desc<4x8xf32>) {
/// ...
/// %ld = xegpu.load_nd %arg0, %arg1: !xegpu.tensor_desc<4x8xf32>,
/// vector<4x8xf32>
/// vector.yield %ld
/// }
/// ```
/// To
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> () {
/// ...
/// %dead = xegpu.load_nd %arg0, %arg1:
/// !xegpu.tensor_desc<4x8xf32>, vector<4x8xf32>
/// vector.yield %arg0, %arg1
/// }
/// xegpu.store_nd %r#0, %r#1: vector<4x1xf32>,
Contributor:
This should be load_nd?

charithaintc (Author):
Yes, fixed it. Sorry I missed it.

/// !xegpu.tensor_desc<4x1xf32>
///
/// ```
struct WarpOpLoadNd final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override;
};

FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
xegpu::SGMapAttr sgMap) {
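// For example, with #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]> a
// vector<4x8xf32> value distributes to vector<4x1xf32>; the match fails if a
// dimension is not evenly divisible by its wi_layout entry.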
llvm::SmallVector<int64_t, 2> distributedShape;
auto layout = sgMap.getWiLayout();
auto shape = originalT.getShape();
for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
if (!divisible(APInt(64, o), APInt(64, l)))
return failure();
distributedShape.push_back(o / l);
}
auto newVectorType =
VectorType::get(distributedShape, originalT.getElementType(),
originalT.getScalableDims());
return newVectorType;
}

FailureOr<xegpu::TensorDescType>
getDistributedTensorDescType(xegpu::TensorDescType originalT,
xegpu::SGMapAttr sgMap,
xegpu::MemorySpace memSpace) {
llvm::SmallVector<int64_t, 2> distributedShape;
auto layout = sgMap.getWiLayout();
auto shape = originalT.getShape();
for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
if (!divisible(APInt(64, o), APInt(64, l)))
return failure();
// Tensor descriptor is distributed only for the scattered case.
if (originalT.isScattered())
distributedShape.push_back(o / l);
else
distributedShape.push_back(o);
}

return xegpu::TensorDescType::get(
originalT.getContext(), distributedShape, originalT.getElementType(),
originalT.getEncoding(), originalT.getSGMapAttr());
}
} // namespace

LogicalResult WarpOpStoreNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const {
auto yield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
return failure();

auto origType = storeOp.getTensorDescType();
xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
if (!sgMap)
return rewriter.notifyMatchFailure(
storeOp, "the source tensor descriptor lacks sg_map attribute");

if (storeOp.getTensorDescType().getShape().size() != 2)
return rewriter.notifyMatchFailure(storeOp, "unsupported shape");

auto distributedTypeOrFailure =
getDistributedVectorType(storeOp.getValueType(), sgMap);
if (failed(distributedTypeOrFailure))
return rewriter.notifyMatchFailure(storeOp,
"Failed to distribute the type");
VectorType newVectorType = distributedTypeOrFailure.value();

auto distributedDescTypeOrFailure = getDistributedTensorDescType(
storeOp.getTensorDescType(), sgMap,
storeOp.getTensorDescType().getMemorySpace());
if (failed(distributedDescTypeOrFailure))
return rewriter.notifyMatchFailure(storeOp,
"Failed to distribute the desc type");
xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();

SmallVector<size_t> newRetIndices;
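// Rebuild the warp op so that it additionally yields the store's tensor
// descriptor and value with their distributed types; newRetIndices records
// the positions of these extra results.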
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
TypeRange{newTDescType, newVectorType}, newRetIndices);

rewriter.setInsertionPointAfter(newWarpOp);
auto newStoreOp =
cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
rewriter.eraseOp(storeOp);
newStoreOp.getTensorDescMutable().assign(
newWarpOp.getResult(newRetIndices[0]));
newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));

return success();
}

LogicalResult WarpOpLoadNd::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const {
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
});

if (!operand)
return rewriter.notifyMatchFailure(warpOp,
"warp result is not a xegpu::LoadNd op");

auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();

if (loadOp.getPacked())
return rewriter.notifyMatchFailure(
loadOp, "Packed load distribution not supported");

xegpu::TensorDescType origType = loadOp.getTensorDescType();
xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
if (!sgMap)
return rewriter.notifyMatchFailure(
loadOp, "the source tensor descriptor lacks sg_map attribute");

auto origShape = origType.getShape();
if (origShape.size() != 2)
return rewriter.notifyMatchFailure(loadOp, "unsupported shape");

auto distributedTypeOrFailure =
getDistributedVectorType(loadOp.getType(), sgMap);
if (failed(distributedTypeOrFailure))
return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
VectorType newVectorType = distributedTypeOrFailure.value();

auto distributedDescTypeOrFailure =
getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
loadOp.getTensorDescType().getMemorySpace());
if (failed(distributedDescTypeOrFailure))
return rewriter.notifyMatchFailure(loadOp,
"Failed to distribute the desc type");
xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();

unsigned operandIdx = operand->getOperandNumber();

SmallVector<size_t> newRetIndices;
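// Rebuild the warp op so that it also yields the load's tensor descriptor
// with its distributed type; the load itself is recreated after the warp op.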
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
newRetIndices);

rewriter.setInsertionPointAfter(newWarpOp);

auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
loadOp.getPackedAttr(), loadOp.getTransposeAttr(), loadOp.getL1HintAttr(),
loadOp.getL2HintAttr(), loadOp.getL3HintAttr());

newLoadOp.getTensorDescMutable().assign(
newWarpOp.getResult(newRetIndices[0]));
Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newLoadOp);

return success();
}

LogicalResult
WarpOpTensorDescOp::matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
Contributor:
Better to stick to the "subgroup" prefix, since XeGPU uses "subgroup" terminology, which is the counterpart of warp.

charithaintc (Author):
Fixed.

PatternRewriter &rewriter) const {
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
});

if (!operand)
return rewriter.notifyMatchFailure(
warpOp, "warp result is not a xegpu::CreateNdDesc op");
auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
assert(descOp && "desc op must not be null");
unsigned operandIdx = operand->getOperandNumber();

// TODO: check whether the memref is uniform in the region.
rewriter.setInsertionPoint(warpOp);
auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
assert(srcTypedVal && "source value must not be null");

auto descOffsets = descOp.getMixedOffsets();
if (descOffsets.size() != 2)
return rewriter.notifyMatchFailure(descOp,
"offsets size is expected to be 2");

xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
if (!sgMap)
return rewriter.notifyMatchFailure(
descOp, "the tensor descriptor lacks sg_map attribute");

auto distributedDescTypeOrFailure = getDistributedTensorDescType(
descOp.getType(), sgMap, descOp.getType().getMemorySpace());
if (failed(distributedDescTypeOrFailure))
return rewriter.notifyMatchFailure(descOp,
"Failed to distribute the desc type");
xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
auto distributedShape = newTDescType.getShape();
// use the base memref strides
SmallVector<OpFoldResult> overwriteStrides =
getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
SmallVector<OpFoldResult> overwriteSizes =
getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);

SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
newRetIndices);

rewriter.setInsertionPointAfter(newWarpOp);
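// Recreate create_nd_tdesc after the warp op, consuming the source memref
// forwarded as a new warp op result and reusing the original offsets; only
// the descriptor type changes to its distributed form.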
auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
newWarpOp.getLoc(), newTDescType,
dyn_cast<TypedValue<MemRefType>>(newWarpOp.getResult(newRetIndices[0])),
descOffsets);

Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newDescOp);

return success();
}

void xegpu::populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WarpOpTensorDescOp>(patterns.getContext());
patterns.add<WarpOpStoreNd>(patterns.getContext());
patterns.add<WarpOpLoadNd>(patterns.getContext());
}
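
For reference, a minimal sketch of how a client could drive these patterns with the greedy rewriter; the helper name below is hypothetical, and depending on the MLIR revision the driver entry point is spelled applyPatternsAndFoldGreedily or applyPatternsGreedily.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper (not part of this patch): applies the XeGPU
// distribution patterns to a function until a fixed point is reached.
static mlir::LogicalResult distributeXeGPUToWorkItems(mlir::func::FuncOp funcOp) {
  mlir::RewritePatternSet patterns(funcOp.getContext());
  mlir::xegpu::populateXeGPUDistributePatterns(patterns);
  // Runs WarpOpTensorDescOp, WarpOpLoadNd, and WarpOpStoreNd greedily.
  return mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```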