Merged

Changes from 3 commits
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -30,9 +30,11 @@ class SliceAttr;
} // namespace xegpu
} // namespace mlir

// clang-format off
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
// clang-format on

#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
47 changes: 29 additions & 18 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -167,6 +167,16 @@ def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
let cppNamespace = "::mlir::xegpu";
}

def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
"The enumeration for the scope of fence operation.",
[XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::xegpu";
}

def XeGPU_FenceScopeAttr:
EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
let summary = [{Describes the scope of fence.
@@ -223,18 +233,18 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout by dropping InstData",
"xegpu::DistributeLayoutAttr",
"dropInstData">,
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
indices based on the effective subgroup layout.}],
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
indices based on the effective `level` layout.}],
"FailureOr<SmallVector<Value>>",
"delinearizeSubgroupId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
assigned to a subgroup identified by linearId. The shape parameter
represents the workgroup-level problem size. Each subgroup may access
"delinearizeId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "xegpu::DistributionLevel": $level)>,
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
assigned to a `level` identified by linearId. The shape parameter
represents the higher-level problem size. Each `level` may access
multiple blocks according to round-robin distribution rules.}],
"FailureOr<SmallVector<SmallVector<Value>>>",
"getOffsets",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
"computeDistributedCoords",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
to some other layout according to given permutation of (0...n-1).}],
/*retTy=*/"bool",
@@ -476,17 +486,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return {};
}

/// Delinearizes a linear subgroup ID into its multidimensional indices
/// based on the effective subgroup layout.
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective `level` layout.
FailureOr<SmallVector<Value>>
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);

/// Generates instructions to compute multidimensional offsets for blocks
/// assigned to a subgroup identified by linearId. The shape parameter
/// represents the workgroup-level problem size. Each subgroup may access
/// Generates instructions to compute multidimensional offsets for dist units
/// assigned to a `level` identified by linearId. The shape parameter
/// represents the higher-level problem size. Each `level` may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);

/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -643,14 +653,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Delinearizes a linear ID into its multidimensional indices
/// based on the effective `level` layout.
FailureOr<SmallVector<Value>>
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);

/// Generates instructions to compute multidimensional offsets for dist units
/// assigned to a `level` identified by linearId. The shape parameter
/// represents the higher-level problem size. Each `level` may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);

/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
The pass distributes subgroup level (SIMD) XeGPU ops to work items.
}];
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
"vector::VectorDialect"];
"vector::VectorDialect", "index::IndexDialect"];
}

def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
2 changes: 2 additions & 0 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
valOrResVecTy = VectorType::get(1, data.getType());
if (valOrResVecTy.getShape().size() != 1)
Contributor:

Can we add a verification here for load/store_matrix with the @subgroup_block_io attribute: the payload must be contiguous in memory.

Both of the IRs in the tests added in this PR are actually incorrect, since the payload data are not contiguous between lanes. They would be correct if you changed vector<2x16xf32> to vector<16x2xf32> (lane_layout/lane_data would need to change accordingly, but that is out of the IR verifier's scope).

    %1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>

    xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
      vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index

Contributor (Author):

Added verification. However,

change the vector<2x16xf32> to <16x2xf32>

I understand the logical reasoning for this in the matrix ops case, but the current distribution does not allow it, considering the "correct" lane layout the block load requires.

We have

  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
    if (i < distributionStart)
      continue;
    // Check if the dimension can be distributed evenly.
    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
      return failure();
    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
  }

Meaning that given lane_layout = [1, 16], lane_data = [1, 1] and a 16x2 data shape, we get

shape[0] % layout[0] = 16 % 1 = 0 // good
shape[1] % layout[1] = 2 % 16 = 2 // fail

We can change the layout to be [16, 1], which would allow the pattern to complete and the distributed code to still be correct, since the lane layout is not used in further coordinate calculations. But [16, 1] may be harder for users to reason about by simply looking at the xevm block load description and the sg-level subgroup_block_io matrix op.

Contributor:

If the user uses stride=[1, 32] in the memory layout, then they should be able to reason about sg_layout = [16, 1].

Contributor:

If the user uses lane_layout = [1, 16], they should not use a strided memory layout; the example above should just use a block layout. The matrix op with subgroup_block_io is a subgroup operation, and all lanes collectively access a contiguous memory buffer.

Contributor (Author):

There is now a test with a 16x2xf32 result using the proper stride.
Short snippet:

    %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
     !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>

It distributes to 1x2xf32 per lane.
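
For reference, a constant trace of the distribution loop quoted earlier under these assumptions (data shape 16x2, lane_layout = [16, 1], lane_data = [1, 1]):

    shape[0] % layout[0] = 16 % 16 = 0  ->  distributedShape[0] = 16 / 16 = 1
    shape[1] % layout[1] =  2 %  1 = 0  ->  distributedShape[1] =  2 /  1 = 2

so each lane holds a vector<1x2xf32> slice.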

return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");

int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
154 changes: 86 additions & 68 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -38,47 +38,47 @@ void XeGPUDialect::initialize() {
>();
}

/// Generates instructions to compute offsets for a subgroup identified by
/// its multidimensional indices (sgId), using the specified subgroup layout
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
/// dimensions (sizePerWg).
// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
// within each distribution unit.
static SmallVector<SmallVector<Value>>
genOffsetsComputingInsts(OpBuilder &builder, Location loc,
SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
ArrayRef<int64_t> sizePerSg,
ArrayRef<int64_t> sizePerWg) {

genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
Contributor:

I think the function should be genCoords(). Inside the function, when we compute each individual value, it is fine to still use offset.

A coordinate (coord) in an n-d tensor is a vector of logical offsets.

Contributor:

Actually, reading the code below, almost all of the "offsets" variables (vectors of Values) could be renamed to "coord".

Contributor (Author):

Renamed.

ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
ArrayRef<int64_t> srcShape) {
SmallVector<SmallVector<Value>> offsets;

// nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
SmallVector<Value> localOffsets = llvm::map_to_vector(
llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
// A distribution unit must be less than or equal to `srcShape`
SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
llvm::zip_equal(srcShape,
computeElementwiseMul(subShapesLayout, subShape)),
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

// Get the offset of `subShape` within a distribution unit.
SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::MulOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});

// distUnit[i] is the minimum value between sizePerWg[i] and
// sgLayout[i] * sizePerSg[i]
SmallVector<int64_t> distUnit = llvm::map_to_vector(
llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

// For each dist unit
for (SmallVector<int64_t> unitOffs :
StaticTileOffsetRange(sizePerWg, distUnit)) {
StaticTileOffsetRange(srcShape, distUnitShape)) {
// Get dist unit offset within `srcShape`.
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return arith::ConstantIndexOp::create(builder, loc, d);
});

SmallVector<Value> adds = llvm::map_to_vector(
llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
std::get<1>(t));
});

// Calculate `subShape` offset within `srcShape`.
SmallVector<Value> adds =
llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
[&](const auto &t) -> Value {
return builder.createOrFold<arith::AddIOp>(
loc, std::get<0>(t), std::get<1>(t));
});
// Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::RemUOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
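
To illustrate the round-robin rule implemented above, here is a constant-folded 2-D sketch in plain C++ (illustrative names and fixed rank, rather than the IR-building code in this hunk):

    #include <algorithm>
    #include <array>
    #include <cstdint>
    #include <vector>

    using Coord2D = std::array<int64_t, 2>;

    // Mirrors genOffsets: each distribution unit covers layout * subShape
    // elements (clamped to srcShape); the given id owns one subShape tile
    // per unit, wrapped into srcShape bounds.
    std::vector<Coord2D> computeCoords2D(Coord2D id, Coord2D layout,
                                         Coord2D subShape, Coord2D srcShape) {
      Coord2D distUnit;
      for (int i = 0; i < 2; ++i)
        distUnit[i] = std::min(srcShape[i], layout[i] * subShape[i]);

      // Offset of this id's subShape tile inside one distribution unit.
      Coord2D local = {id[0] * subShape[0], id[1] * subShape[1]};

      std::vector<Coord2D> coords;
      for (int64_t r = 0; r < srcShape[0]; r += distUnit[0])
        for (int64_t c = 0; c < srcShape[1]; c += distUnit[1])
          coords.push_back({(r + local[0]) % srcShape[0],
                            (c + local[1]) % srcShape[1]});
      return coords;
    }

For example, with srcShape = {32, 64}, layout = {2, 2}, and subShape = {8, 16}, id (1, 1) receives the four tile origins (8, 16), (8, 48), (24, 16), and (24, 48).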
@@ -268,12 +268,8 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}

FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
Value linearId) {
// delinearizeSubgroupId is only available for
// workgroup-level layout attribute
if (!isForWorkgroup())
return failure();
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
xegpu::DistributionLevel idLevel) {

// TODO: handle order attribute
auto hasDefaultOrder = [&]() {
@@ -283,41 +279,53 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
};
if (!hasDefaultOrder())
return mlir::emitError(loc, "order attribute is currently not supported.");

auto dims =
llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
});
SmallVector<int64_t> layout;
if (idLevel == xegpu::DistributionLevel::SG) {
layout = getEffectiveSgLayoutAsInt();
} else if (idLevel == xegpu::DistributionLevel::WI) {
layout = getEffectiveLaneLayoutAsInt();
} else {
return failure();
}
auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
});

return affine::delinearizeIndex(builder, loc, linearId, dims);
}

/// Implements DistributeLayoutAttr::getOffsets to generate
/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
if (!isForWorkgroup())
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
Value linearId, ArrayRef<int64_t> shape,
xegpu::DistributionLevel targetLevel) {
SmallVector<int64_t> layout;
SmallVector<int64_t> subShape;
if (targetLevel == DistributionLevel::SG) {
layout = getEffectiveSgLayoutAsInt();
subShape = getEffectiveSgDataAsInt();
} else if (targetLevel == DistributionLevel::WI) {
layout = getEffectiveLaneLayoutAsInt();
subShape = getEffectiveLaneDataAsInt();
} else {
return failure();

SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
if (sgShape.empty()) {
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
sgShape = derivedShape.value();
}
if (subShape.empty()) {
if (auto derivedShape = computeShapeRatio(shape, layout))
subShape = derivedShape.value();
else
return failure();
}

// delinearize Ids
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
if (failed(maybeIds))
return failure();
SmallVector<Value> sgIds = *maybeIds;
SmallVector<Value> ids = *maybeIds;

return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
shape);
return genOffsets(builder, loc, ids, layout, subShape, shape);
}
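
A quick worked example of the delinearization shared by these implementations (assuming the default row-major order, the only order currently supported):

    layout = [2, 4] (8 subgroups), linearId = 5  ->  ids = (5 / 4, 5 % 4) = (1, 1)
    layout = [16, 1] (16 lanes),   linearId = 7  ->  ids = (7 / 1, 7 % 1) = (7, 0)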

//===----------------------------------------------------------------------===//
@@ -371,34 +379,45 @@ SliceAttr SliceAttr::flatten() const {
}

FailureOr<SmallVector<Value>>
SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
Value linearId) {
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
xegpu::DistributionLevel level) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
return parent.delinearizeSubgroupId(builder, loc, linearId);
return parent.delinearizeId(builder, loc, linearId, level);
}

/// Implements DistributeLayoutAttr::getOffsets to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// SliceAttr.
// Implements DistributeLayoutAttr::computeDistributedCoords to generate
// instructions for computing multi-dimensional offsets when distributed by
// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
Value linearId, ArrayRef<int64_t> shape,
xegpu::DistributionLevel targetLevel) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();

SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
if (sgShape.empty()) {
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
sgShape = derivedShape.value();
SmallVector<int64_t> layout;
SmallVector<int64_t> subShape;
if (targetLevel == DistributionLevel::SG) {
layout = getEffectiveSgLayoutAsInt();
subShape = getEffectiveSgDataAsInt();
} else if (targetLevel == DistributionLevel::WI) {
layout = getEffectiveLaneLayoutAsInt();
subShape = getEffectiveLaneDataAsInt();
} else {
return failure();
}

if (subShape.empty()) {
if (auto derivedShape = computeShapeRatio(shape, layout))
subShape = derivedShape.value();
else
return failure();
}

// delinearize Ids
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
if (failed(maybeIds))
return failure();

@@ -408,8 +427,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
SmallVector<Value> sgIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);

return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
shape);
return genOffsets(builder, loc, sgIds, layout, subShape, shape);
}

bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -181,7 +181,7 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
"are only allowed when result is a 1D VectorType.";
"are only allowed when result is a VectorType.";
else
return success();
}
@@ -193,9 +193,6 @@
ArrayRef<int64_t> mdescShape = mdescTy.getShape();

if (dataShape.size() == 2) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
"are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";