[mlir][xegpu] add support for structured control flow ops in workgroup to subgroup distribution #142618


Merged (15 commits, Jun 13, 2025)
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -26,6 +26,9 @@ class TensorDescType;

namespace xegpu {

/// Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);

/// If tensor descriptor has a layout attribute it is used in SIMT mode.
/// In this mode, the distributed vector shape is determined as follows:
/// Definitions:
187 changes: 163 additions & 24 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -13,9 +13,11 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
@@ -29,6 +31,29 @@ using namespace mlir;

namespace {

static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {

> Reviewer: can this go to XeGPUUtils so XeGPU blocking could use it?
>
> Author: No, they are different logic: blocking uses inst_data, while this uses sg_layout and sg_data.

int count = 1;
SmallVector<int64_t> sgShape(shape);

if (layout && layout.isWgLayout()) {
> Reviewer: isWgLayout seems confusing; I think it should be called isSgLayout, since it describes how the subgroups are laid out.
>
> Author: This is an interface defined in a previous PR. We can create a small follow-up PR if we plan to change it.

DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
else
sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
// Clamp distUnit to the original shape to handle cases where data is
// shared among subgroups, which may cause distUnit to exceed the original
// shape.
for (size_t i = 0; i < distUnit.size(); ++i)
distUnit[i] = std::min(shape[i], distUnit[i]);
count = computeProduct(shape) / computeProduct(distUnit);
}
return std::make_pair(sgShape, count);
}
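For intuition, a worked example with invented numbers: given shape = [256, 128], sg_layout = [8, 4], and sg_data = [16, 16], sgShape is [16, 16], the distribution unit is distUnit = [8*16, 4*16] = [128, 64], and count = (256*128) / (128*64) = 4, i.e. each workgroup-level value maps to four subgroup-level pieces. With sg_data = [32, 32] instead, distUnit = [256, 128] after clamping, and count = 1.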

/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
/// from a workgroup descriptor. It replaces the offsets and sizes with
/// appropriate values for the subgroup.
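A hedged before/after sketch of the rewrite described above (shapes, layout values, and the offset computation are invented; the actual offsets are derived from the subgroup id and sg_layout):

// Workgroup-level descriptor covering the whole 32x32 tile.
%t = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32>
    -> !xegpu.tensor_desc<32x32xf32,
         #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16],
                       lane_layout = [1, 16], lane_data = [1, 1]>>

// After distribution: each subgroup builds a 16x16 descriptor at an
// offset derived from its subgroup id (offset math elided here).
%sgid = gpu.subgroup_id : index
// ... compute %offx, %offy from %sgid and sg_layout = [2, 2] ...
%t_sg = xegpu.create_nd_tdesc %src[%offx, %offy] : memref<32x32xf32>
    -> !xegpu.tensor_desc<16x16xf32,
         #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>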
@@ -129,18 +154,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
return rewriter.notifyMatchFailure(
op, "sgLayout attribute is required in layout");

- SmallVector<int64_t> sgShape;
- if (auto sgDataAttr = layout.getSgData()) {
- sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
- } else {
- assert(wgShape.size() == sgLayout.size() &&
- "sgLayout and wgShape must have the same rank");
- sgShape.reserve(wgShape.size());
- for (size_t i = 0; i < wgShape.size(); ++i) {
- assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
- sgShape.push_back(wgShape[i] / sgLayout[i]);
- }
- }
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

// TODO : Handle order attribute
// Get the subgroup ID
@@ -266,15 +280,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
if (resultTy.getRank() != 2)
return failure();

- auto originalLayout =
-     llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+ auto originalLayout = xegpu::getLayoutAttr(op.getResult());
if (!originalLayout)
return failure();

- SmallVector<Value> newDpasOps;
- size_t i = 0;
+ SmallVector<Value> newDpasOps;
for (auto aVec : adaptor.getLhs()) {
for (auto bVec : adaptor.getRhs()) {

llvm::SmallVector<Value> operands({aVec, bVec});
Value tmpC;
if (op.getAcc()) {
@@ -288,10 +302,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
llvm::cast<VectorType>(bVec.getType()).getShape();
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
resultTy.getElementType());
- tmpC = rewriter.create<xegpu::DpasOp>(
-     loc, resTy, operands,
-     llvm::ArrayRef<NamedAttribute>(
-         {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+ tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+ xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                      originalLayout.dropSgLayoutAndData());

newDpasOps.push_back(tmpC);
}
}
@@ -314,14 +328,64 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
// 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
// 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
// It can be either a 1:N or an N:1 cast. In both cases, the pattern
// simply forwards the inputs to the outputs via the 1:1 or 1:N interface.
// TODO: remove it when context-aware type converter is ready.
struct UnrealizedConversionCastOpPattern
: public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
using OpConversionPattern<
mlir::UnrealizedConversionCastOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
!llvm::all_equal(ValueRange(inputs).getTypes()))
return failure();

// Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
// the input values provided by the adaptor should already be distributed,
// and their types should correspond exactly to the result types of the
// operation.

> Reviewer: Can you clarify who distributed the input? Is it done by SCFStructuralTypeConversions? That would mean the input may be coming from some structural op.
>
> Author: I added an example to the pattern comment; hope it helps. Arguments and results (the N:1 case) are generated by SCFStructuralTypeConversions; the 1:N casts are generated by the op patterns, e.g. the one for create_nd.

if (op.getNumOperands() == 1 &&
llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
rewriter.replaceOp(op, inputs);
return success();
}

// Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
// All input values must have the same vector type, and their shape must be
// evenly divisible by the output vector's shape.
// TODO: this forwarding is not entirely safe, since such an N:1 cast
// could also originate from other sources.
if (op.getNumResults() == 1 &&
computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
rewriter.replaceOpWithMultiple(op, {inputs});
return success();
}

return mlir::failure();
}
};
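For concreteness, a hedged IR sketch of the two cast directions this pattern erases (the 2-way distribution factor and all SSA names are invented; the real factor comes from sg_layout and sg_data):

// N:1 cast produced by SCFStructuralTypeConversions (step 1): distributed
// values flowing out of an SCF region are cast back to the workgroup type.
%wg = builtin.unrealized_conversion_cast %sg0, %sg1
    : vector<16xf32>, vector<16xf32> to vector<32xf32>

// 1:N cast produced by the op-distribution patterns (e.g. WgToSgCreateNdOp):
// a workgroup-level value is cast into its distributed pieces.
%p0, %p1 = builtin.unrealized_conversion_cast %wg
    : vector<32xf32> to vector<16xf32>, vector<16xf32>

In both cases the pattern forwards its operands to the uses of its results, so once producers and consumers are both converted the casts cancel out.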

} // namespace

namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-              WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-     patterns.getContext());
+              WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+              UnrealizedConversionCastOpPattern>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -334,10 +398,60 @@ struct XeGPUWgToSgDistributePass
} // namespace

void XeGPUWgToSgDistributePass::runOnOperation() {
// Track existing UnrealizedConversionCastOps
SmallVector<Operation *> existingCastOps;
getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
existingCastOps.push_back(castOp.getOperation());
});

TypeConverter converter;
converter.addConversion([&](Type type) -> Type { return type; });
converter.addConversion(
[&](RankedTensorType type,
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
Type elemTy = type.getElementType();
ArrayRef<int64_t> shape = type.getShape();

int count;
SmallVector<int64_t> subShape;
std::tie(subShape, count) = getSgShapeAndCount(
shape, dyn_cast<xegpu::LayoutAttr>(type.getEncoding()));

auto newTy = VectorType::get(subShape, elemTy);
result.append(count, newTy);
return success();
});

// Step 1: Apply SCFStructuralTypeConversions to SCF operations with
// VectorType operands. This first converts such operands to RankedTensorType,
// propagates the layout attribute into the encoding attribute, and finally
// converts the RankedTensorType to VectorType based on the encoding.
xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
> Reviewer: nit: do we need the "do"?
>
> Author: I feel it needs a verb prefix; I am open to changing it to something other than "do".
>
> Reviewer: After this point, what does the IR look like? All SCF operations are distributed and there is no RankedTensorType left in the IR; is that correct?
>
> Author: Yes. All VectorType operands/arguments/results of SCF::If, SCF::For, SCF::While, and SCF::Condition ops will be converted.
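As a hedged sketch of that answer (shapes, bounds, and the 2-way distribution factor are invented for illustration), a loop-carried workgroup-level vector becomes one loop-carried value per distributed piece:

// Before: the scf.for carries a workgroup-level vector<32xf32>.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%init = arith.constant dense<0.0> : vector<32xf32>
%r = scf.for %iv = %c0 to %c4 step %c1
    iter_args(%acc = %init) -> (vector<32xf32>) {
  %next = arith.addf %acc, %acc : vector<32xf32>
  scf.yield %next : vector<32xf32>
}

// After step 1 and cast folding: the loop carries the distributed values,
// and no RankedTensorType remains in the IR.
%init0 = arith.constant dense<0.0> : vector<16xf32>
%init1 = arith.constant dense<0.0> : vector<16xf32>
%r0, %r1 = scf.for %iv = %c0 to %c4 step %c1
    iter_args(%acc0 = %init0, %acc1 = %init1)
    -> (vector<16xf32>, vector<16xf32>) {
  %n0 = arith.addf %acc0, %acc0 : vector<16xf32>
  %n1 = arith.addf %acc1, %acc1 : vector<16xf32>
  scf.yield %n0, %n1 : vector<16xf32>, vector<16xf32>
}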


MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
ConversionTarget target(*ctx);

converter.addConversion(
[&](xegpu::TensorDescType type,
SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
Type elemTy = type.getElementType();
ArrayRef<int64_t> shape = type.getShape();

int count;
SmallVector<int64_t> subShape;
xegpu::LayoutAttr layout = type.getLayoutAttr();
std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

if (layout)
layout = layout.dropSgLayoutAndData();

auto newTy = xegpu::TensorDescType::get(
type.getContext(), subShape, elemTy, type.getEncoding(), layout);
result.append(count, newTy);
return success();
});
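As a hedged illustration of this converter (layout values invented), a workgroup-level descriptor type such as

!xegpu.tensor_desc<256x128xf32,
  #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
                lane_layout = [1, 16], lane_data = [1, 1]>>

converts, per getSgShapeAndCount (distUnit = [256, 128], count = 1), to a single subgroup-level type with sg_layout and sg_data dropped:

!xegpu.tensor_desc<32x32xf32,
  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>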

auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
return createOp.getType();
@@ -353,26 +467,51 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
};

auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
- return !layout || layout.getSgLayout() == nullptr;
+ return !layout || !layout.isWgLayout();
};

target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
auto tdescTy = getTensorDescType(op);
- auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+ auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
return isLegal(layout);
});

target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
- auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+ auto layout = xegpu::getLayoutAttr(op.getResult());
return isLegal(layout);
});

target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
});

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

// Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
// as well as XeGPU, Arith, and Vector operations.
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();

// Remove the sg_layout and sg_data attributes from the Layout
// attribute of each VectorType result of the operation.
// For structured control flow ops, the layout is simply removed,
// since in the 1:N case the layouts for the new results are missing;
// the layout propagation pass will be activated to recover them.
getOperation()->walk([](Operation *op) {
for (OpResult result : op->getOpResults()) {
std::string name = xegpu::getLayoutName(result);
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
op->setAttr(name, layout.dropInstData());
> Reviewer: why drop inst data?
>
> Author (@chencha3, Jun 10, 2025): It is a bug. Fixed.

}
}
});
}
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -27,7 +27,7 @@
using namespace mlir;

/// convert ArrayRef<ValueRange> into SmallVector<Value>
- static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
SmallVector<Value> result;
for (const auto &vals : values)
llvm::append_range(result, vals);
@@ -271,7 +271,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
auto resultTy = dyn_cast<RankedTensorType>(result.getType());

// Only look at ops casting from VectorType to RankedTensorType
- if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
+ if (!inputTy || !resultTy)
return WalkResult::skip();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
@@ -342,7 +342,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
}

if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
- SmallVector<Value> values = flattenValues(adaptor.getInputs());
+ SmallVector<Value> values = xegpu::flattenValues(adaptor.getInputs());
auto newOp = rewriter.create<UnrealizedConversionCastOp>(
op.getLoc(), outputTy, values);
rewriter.replaceOp(op, newOp);