Skip to content
89 changes: 88 additions & 1 deletion mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

namespace mlir {
namespace xegpu {
Expand Down Expand Up @@ -328,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
WgToSgElementwiseOp(MLIRContext *ctx)
: ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const override {
// Only match ops with elementwise trait and single result.
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
return failure();

auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
assert(resultType && "Expected result to be a VectorType");

ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
if (!layout || !layout.getSgLayout())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

size_t numVariants = operands.empty() ? 0 : operands.front().size();

if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
return operandVec.size() != numVariants;
}))
return failure();

SmallVector<Value> newResults;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());

for (size_t i = 0; i < numVariants; ++i) {
SmallVector<Value> opOperands;
for (auto &operandVec : operands)
opOperands.push_back(operandVec[i]);

OperationState state(op->getLoc(), op->getName());
state.addOperands(opOperands);
state.addTypes(newResultType);
// Copy all attributes, but update "layout_result_0" to drop
// sgLayout/sgData
for (auto attr : op->getAttrs()) {
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
else
state.addAttribute(attr.getName(), attr.getValue());
}
Operation *newOp = rewriter.create(state);
newResults.push_back(newOp->getResult(0));
}

rewriter.replaceOpWithMultiple(op, {newResults});
return success();
}
};

// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
Expand Down Expand Up @@ -411,7 +473,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
UnrealizedConversionCastOpPattern>(patterns.getContext());
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
Expand Down Expand Up @@ -518,6 +581,30 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
[=](Operation *op) -> std::optional<bool> {
// Only handle elementwise mappable ops
if (!OpTrait::hasElementwiseMappableTraits(op))
return true;

VectorType resultType =
dyn_cast<VectorType>(op->getResult(0).getType());
if (!resultType)
return true;

// Check if all operands are vectors of the same shape
// TODO: Support other types.
for (Value operand : op->getOperands()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider the use of llvm::all_equal on op->getOperandTypes()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this loop is equivalent to

if (llvm::any_of(op->getOperandTypes(), [&](Type type) { return type != resultType; }))
  return true; 

VectorType operandType = dyn_cast<VectorType>(operand.getType());
if (!operandType || operandType.getShape() != resultType.getShape()) {
return true;
}
}

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
return isLegal(layout);
});

target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
Expand Down
164 changes: 164 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s

gpu.module @test_elementwise_ops {
// CHECK-LABEL: unary_ops
gpu.func @unary_ops(%a: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
// CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
%exp = math.exp %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: binary_ops
gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
// CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%addf = arith.addf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%powf = math.powf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: ternary_ops
gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1>
-> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_c = xegpu.load_nd %tdesc_c
: !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi1>
// CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
%select = arith.select %load_c, %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xi1>, vector<24x32xf32>
// CHECK: math.fma {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%fma = math.fma %load_a, %load_b, %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: type_conversion_ops
gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
-> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi32>
// CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
%truncf = arith.truncf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32> to vector<24x32xf16>
// CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
%bitcast = arith.bitcast %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xi32> to vector<24x32xf32>
gpu.return
}

// CHECK-LABEL: comparison_ops
gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
-> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
-> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_c = xegpu.load_nd %tdesc_c
: !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi32>
%load_d = xegpu.load_nd %tdesc_d
: !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-> vector<24x32xi32>
// CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xf32>
%cmpf = arith.cmpf ult, %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x8xi32>
%cmpi = arith.cmpi eq, %load_c, %load_d
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
: vector<24x32xi32>
gpu.return
}

// 1 to N decomposition of elementwise operations
// CHECK-LABEL: elementwise_ops_rr_assignment
gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-> vector<24x32xf32>
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-> vector<24x32xf32>
// CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-12: : vector<2x2xf32>
// CHECK-NOT: arith.negf
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-12: : vector<2x2xf32>
// CHECK-NOT: math.powf
%powf = math.powf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
gpu.return
}
}