Skip to content
163 changes: 163 additions & 0 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

namespace mlir {
namespace xegpu {
Expand Down Expand Up @@ -314,6 +317,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

/// Distributes a workgroup-level elementwise op (unary/binary, from the
/// math/arith dialects) into subgroup-level ops of the same kind.
///
/// Preconditions checked here (each reported through notifyMatchFailure):
///   * the result is a 1D or 2D vector and every operand is a vector with the
///     same rank and shape as the result;
///   * the op carries a "layout" attribute that is an xegpu::LayoutAttr with
///     an sg_layout;
///   * the subgroup shape (sg_data when present, otherwise the result shape
///     divided by sg_layout) is well formed for the result rank;
///   * every operand was distributed into the same, non-zero number of values.
///
/// For each distributed operand position `i` one new `Op` is created from the
/// i-th value of every operand. All attributes except "layout" are copied to
/// the new op, and the layout with sg_layout/sg_data dropped is attached as
/// "layout_result_0".
template <typename Op>
struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;
  using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;

  LogicalResult
  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // All operands/results must be 1D or 2D vectors.
    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
      return rewriter.notifyMatchFailure(
          op, "Result type is not a 1D or 2D vector");

    ArrayRef<int64_t> shape = resultType.getShape();
    for (Value operand : op->getOperands()) {
      auto operandType = dyn_cast<VectorType>(operand.getType());
      if (!operandType || operandType.getRank() != resultType.getRank() ||
          operandType.getShape() != shape) {
        return rewriter.notifyMatchFailure(
            op, "Operand type is not a 1D or 2D vector with the same shape as "
                "result type");
      }
    }

    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
    if (!layout || !layout.getSgLayout())
      return rewriter.notifyMatchFailure(
          op, "Operation does not have a valid layout attribute for subgroup "
              "distribution");

    // Derive the per-subgroup shape: sg_data wins when present, otherwise the
    // workgroup shape is split evenly across sg_layout.
    SmallVector<int64_t> sgShape;
    if (auto sgDataAttr = layout.getSgData()) {
      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
      if (sgShape.size() != shape.size())
        return rewriter.notifyMatchFailure(
            op, "sg_data rank does not match the result vector rank");
    } else {
      auto sgLayout = layout.getSgLayout().asArrayRef();
      if (sgLayout.size() != shape.size())
        return rewriter.notifyMatchFailure(
            op, "sg_layout rank does not match the result vector rank");
      sgShape.reserve(shape.size());
      for (size_t dim = 0; dim < shape.size(); ++dim) {
        // Reject instead of asserting: an assert vanishes in release builds
        // and would leave a division by zero / silently truncated shape.
        if (sgLayout[dim] <= 0 || shape[dim] % sgLayout[dim] != 0)
          return rewriter.notifyMatchFailure(
              op, "Result shape is not evenly divisible by sg_layout");
        sgShape.push_back(shape[dim] / sgLayout[dim]);
      }
    }

    // Every operand must have been distributed into the same number of
    // values; that number is the number of subgroup-level ops to create.
    auto operandRanges = adaptor.getOperands();
    const size_t numVariants =
        operandRanges.empty() ? 0 : operandRanges.front().size();
    if (llvm::any_of(operandRanges, [numVariants](ValueRange range) {
          return range.size() != numVariants;
        }))
      return rewriter.notifyMatchFailure(
          op, "Operand lists have mismatched sizes");
    // With nothing to build from, the result would be replaced by an empty
    // value list; refuse rather than drop the result.
    if (numVariants == 0)
      return rewriter.notifyMatchFailure(
          op, "No distributed operands to build subgroup operations from");

    // `resultType` is already the (checked) vector type of the only result,
    // so it provides the element type of the subgroup-level result directly.
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    // Loop-invariant: the layout with sg_layout/sg_data dropped.
    Attribute subgroupLayout = layout.dropSgLayoutAndData();

    SmallVector<Value> newResults;
    newResults.reserve(numVariants);
    for (size_t variant = 0; variant < numVariants; ++variant) {
      SmallVector<Value> operands;
      operands.reserve(operandRanges.size());
      for (const auto &range : operandRanges)
        operands.push_back(range[variant]);

      auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);

      // Copy all attributes except "layout", and add "layout_result_0" with
      // sgLayout/data dropped.
      for (auto attr : op->getAttrs()) {
        if (attr.getName() != "layout")
          newOp->setAttr(attr.getName(), attr.getValue());
      }
      // Defensive: never hand a null attribute to setAttr, in case dropping
      // the subgroup fields leaves no layout to attach.
      if (subgroupLayout)
        newOp->setAttr("layout_result_0", subgroupLayout);

      newResults.push_back(newOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};

} // namespace

namespace mlir {
Expand All @@ -322,6 +409,57 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
// Register the non-elementwise workgroup-to-subgroup patterns (e.g.
// WgToSgPrefetchNdOp, which converts xegpu::PrefetchNdOp).
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
patterns.getContext());
// Add elementwise operations that can be distributed to subgroups.
// WgToSgElementwiseOp is a template over the op type, so one instantiation is
// registered per supported arith/math op; an op type that is not listed here
// gets no elementwise distribution pattern from this function.
patterns.add<
WgToSgElementwiseOp<arith::AddFOp>, WgToSgElementwiseOp<arith::SubFOp>,
WgToSgElementwiseOp<math::ExpOp>, WgToSgElementwiseOp<math::SqrtOp>,
WgToSgElementwiseOp<math::AbsFOp>, WgToSgElementwiseOp<math::CosOp>,
WgToSgElementwiseOp<math::CoshOp>, WgToSgElementwiseOp<math::AcosOp>,
WgToSgElementwiseOp<math::AcoshOp>, WgToSgElementwiseOp<math::SinOp>,
WgToSgElementwiseOp<math::SinhOp>, WgToSgElementwiseOp<math::AsinOp>,
WgToSgElementwiseOp<math::AsinhOp>, WgToSgElementwiseOp<math::TanOp>,
WgToSgElementwiseOp<math::TanhOp>, WgToSgElementwiseOp<math::AtanOp>,
WgToSgElementwiseOp<math::Atan2Op>, WgToSgElementwiseOp<math::AtanhOp>,
WgToSgElementwiseOp<math::ErfOp>, WgToSgElementwiseOp<math::LogOp>,
WgToSgElementwiseOp<math::Log2Op>, WgToSgElementwiseOp<math::FloorOp>,
WgToSgElementwiseOp<math::CeilOp>, WgToSgElementwiseOp<math::PowFOp>,
WgToSgElementwiseOp<math::RsqrtOp>, WgToSgElementwiseOp<arith::NegFOp>,
WgToSgElementwiseOp<arith::AddIOp>, WgToSgElementwiseOp<arith::SubIOp>,
WgToSgElementwiseOp<arith::MulFOp>, WgToSgElementwiseOp<arith::MulIOp>,
WgToSgElementwiseOp<arith::ShLIOp>, WgToSgElementwiseOp<arith::ShRSIOp>,
WgToSgElementwiseOp<arith::ShRUIOp>, WgToSgElementwiseOp<arith::DivFOp>,
WgToSgElementwiseOp<arith::DivSIOp>, WgToSgElementwiseOp<arith::DivUIOp>,
WgToSgElementwiseOp<arith::MaximumFOp>,
WgToSgElementwiseOp<arith::MinimumFOp>,
WgToSgElementwiseOp<arith::RemSIOp>, WgToSgElementwiseOp<arith::RemUIOp>,
WgToSgElementwiseOp<arith::TruncFOp>,
WgToSgElementwiseOp<arith::TruncIOp>, WgToSgElementwiseOp<arith::ExtFOp>,
WgToSgElementwiseOp<arith::ExtSIOp>, WgToSgElementwiseOp<arith::ExtUIOp>,
WgToSgElementwiseOp<arith::SIToFPOp>,
WgToSgElementwiseOp<arith::UIToFPOp>,
WgToSgElementwiseOp<arith::FPToSIOp>,
WgToSgElementwiseOp<arith::FPToUIOp>,
WgToSgElementwiseOp<arith::IndexCastUIOp>,
WgToSgElementwiseOp<arith::IndexCastOp>,
WgToSgElementwiseOp<arith::BitcastOp>, WgToSgElementwiseOp<arith::CmpIOp>,
WgToSgElementwiseOp<arith::CmpFOp>, WgToSgElementwiseOp<arith::AndIOp>,
WgToSgElementwiseOp<arith::CeilDivSIOp>,
WgToSgElementwiseOp<arith::CeilDivUIOp>,
WgToSgElementwiseOp<arith::FloorDivSIOp>,
WgToSgElementwiseOp<arith::MaxNumFOp>,
WgToSgElementwiseOp<arith::MaxSIOp>, WgToSgElementwiseOp<arith::MaxUIOp>,
WgToSgElementwiseOp<arith::MinNumFOp>,
WgToSgElementwiseOp<arith::MinSIOp>, WgToSgElementwiseOp<arith::MinUIOp>,
WgToSgElementwiseOp<arith::OrIOp>, WgToSgElementwiseOp<arith::RemFOp>,
WgToSgElementwiseOp<arith::XOrIOp>, WgToSgElementwiseOp<math::AbsIOp>,
WgToSgElementwiseOp<math::CbrtOp>, WgToSgElementwiseOp<math::CopySignOp>,
WgToSgElementwiseOp<math::CtPopOp>, WgToSgElementwiseOp<math::ErfcOp>,
WgToSgElementwiseOp<math::Exp2Op>, WgToSgElementwiseOp<math::ExpM1Op>,
WgToSgElementwiseOp<math::FPowIOp>, WgToSgElementwiseOp<math::IPowIOp>,
WgToSgElementwiseOp<math::Log10Op>, WgToSgElementwiseOp<math::Log1pOp>,
WgToSgElementwiseOp<math::RoundOp>,
WgToSgElementwiseOp<math::RoundEvenOp>,
WgToSgElementwiseOp<math::TruncOp>>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
Expand Down Expand Up @@ -368,6 +506,31 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
return isLegal(layout);
});
target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
[=](Operation *op) -> std::optional<bool> {
// Handle unary and binary operations
if (op->getNumOperands() < 1 || op->getNumOperands() > 2)
return true;

// check if input and output are vectors
VectorType resultType =
dyn_cast<VectorType>(op->getResult(0).getType());
if (!resultType || resultType.getRank() != 2)
return true;

// Check if all operands are vectors
for (Value operand : op->getOperands()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: consider using `llvm::all_equal` on `op->getOperandTypes()`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop is equivalent to:

    if (llvm::any_of(op->getOperandTypes(), [&](Type type) { return type != resultType; }))
      return true;

VectorType operandType = dyn_cast<VectorType>(operand.getType());
if (!operandType || operandType.getRank() != 2 ||
operandType.getShape() != resultType.getShape()) {
return true;
}
}

auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
op->getAttrOfType<xegpu::LayoutAttr>("layout"));
return isLegal(layout);
});

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

Expand Down
Loading