Skip to content
Open
85 changes: 70 additions & 15 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {

// TODO: handle order attribute
auto hasDefaultOrder = [&]() {
DenseI32ArrayAttr order = getOrder();
return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
llvm::reverse(order.asArrayRef())));
};
if (!hasDefaultOrder())
return mlir::emitError(loc, "order attribute is currently not supported.");
SmallVector<int64_t> layout;
SmallVector<int64_t> sgLayoutInt;
if (isForWorkgroup()) {
layout = getEffectiveSgLayoutAsInt();
sgLayoutInt = getEffectiveSgLayoutAsInt();
} else if (isForSubgroup()) {
layout = getEffectiveLaneLayoutAsInt();
sgLayoutInt = getEffectiveLaneLayoutAsInt();
} else {
return failure();
}
auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
});

return affine::delinearizeIndex(builder, loc, linearId, dims);
DenseI32ArrayAttr orderAttr = getOrder();

// Handle order attribute
SmallVector<int64_t> order;
if (orderAttr && !orderAttr.empty()) {
order = llvm::to_vector(
llvm::map_range(orderAttr.asArrayRef(),
[](int32_t idx) { return static_cast<int64_t>(idx); }));
} else {
// Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
order = llvm::to_vector(
llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
}

if (order.size() != sgLayoutInt.size()) {
return failure();
}

SmallVector<Value> result(sgLayoutInt.size());
Value remaining = linearId;

/// Process dimensions in the order they appear in the order array.
/// The first dimension in the order array is the fastest-changing one.
///
/// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
///
/// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
/// result=[?,?,?]
///
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider adding a comment: dimIdx = order[i], dimSize = sgLayout[dimIdx]

/// i=0 (process columns, dimIdx=2, dimSize=4):
/// result[2] = 22 % 4 = 2 (column coordinate)
/// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
///
/// i=1 (process rows, dimIdx=1, dimSize=4):
/// result[1] = 5 % 4 = 1 (row coordinate)
/// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
///
/// i=2 (process layers, dimIdx=0, dimSize=2):
/// result[0] = 1 % 2 = 1 (layer coordinate)
/// (no remaining update - last iteration)
///
/// Final result: [1,1,2] = Layer 1, Row 1, Column 2
for (size_t i = 0; i < order.size(); ++i) {
int64_t dimIdx = order[i];
int64_t dimSize = sgLayoutInt[dimIdx];

Value dimSizeVal =
builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);

/// Extract the coordinate for this dimension using a modulo operation.
/// This gives us "how far within this dimension" we are,
/// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
/// this dimension).
result[dimIdx] =
builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);

/// Update remaining for the next dimension by removing what we've already
/// processed. Division tells us "how many complete groups of this dimension
/// we've gone through", e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
/// completed 5 groups of 4). Skip this for the last iteration since there's
/// no next dimension to process.
if (i < order.size() - 1) {
remaining =
builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
}
}
return result;
}

/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
Expand Down
81 changes: 69 additions & 12 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,70 @@ struct WgToSgMultiDimReductionOp
}
};

/// Conversion pattern that rewrites a workgroup-level vector.transpose into
/// subgroup-level vector.transpose ops, creating one new transpose for each
/// value the adaptor supplies for the source vector.
///
/// The match fails unless:
///   * both the result and the source vector carry a workgroup-level
///     distribute layout,
///   * both layouts have an order attribute,
///   * the subgroup layouts of source and result have the same rank as the
///     transpose permutation, and
///   * the result layout reports itself as the transpose of the source layout
///     under that permutation.
struct WgToSgVectorTransposeOp
    : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Both the result and the source must be annotated with a workgroup-level
    // layout; otherwise this pattern does not apply.
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!(layout && layout.isForWorkgroup()))
      return failure();

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getVector());
    if (!(sourceLayout && sourceLayout.isForWorkgroup()))
      return failure();

    // Explicit order attributes are required on both sides.
    DenseI32ArrayAttr srcOrderAttr = sourceLayout.getOrder();
    DenseI32ArrayAttr dstOrderAttr = layout.getOrder();
    if (!(srcOrderAttr && dstOrderAttr))
      return rewriter.notifyMatchFailure(
          op, "Both source and result must have order attributes");

    // The permutation rank must agree with the rank of both subgroup layouts.
    ArrayRef<int64_t> permutation = op.getPermutation();
    const size_t rank = permutation.size();
    SmallVector<int64_t> srcSgLayout =
        sourceLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> dstSgLayout = layout.getEffectiveSgLayoutAsInt();
    const bool ranksAgree =
        srcSgLayout.size() == rank && dstSgLayout.size() == rank;
    if (!ranksAgree)
      return rewriter.notifyMatchFailure(
          op, "Layouts and permutation must have the same rank");

    // Check that sgLayout, sgData & order are properly transposed between
    // the source and the result.
    if (!layout.isTransposeOf(sourceLayout, permutation))
      return rewriter.notifyMatchFailure(
          op, "Result layout is not a valid transpose of source layout "
              "according to permutation");

    // Build the per-subgroup result type from the workgroup-level result.
    VectorType wgResultType = op.getResultVectorType();
    SmallVector<int64_t> sgShape =
        getSgShapeAndCount(wgResultType.getShape(), layout).first;
    VectorType sgResultType =
        VectorType::get(sgShape, wgResultType.getElementType());

    // Emit one subgroup-level transpose per source value and tag its result
    // with the layout stripped of its subgroup layout/data fields.
    SmallVector<Value> sgResults;
    for (auto sgSource : adaptor.getVector()) {
      auto sgTranspose = vector::TransposeOp::create(
          rewriter, op.getLoc(), sgResultType, sgSource, permutation);
      xegpu::setDistributeLayoutAttr(sgTranspose->getResult(0),
                                     layout.dropSgLayoutAndData());
      sgResults.push_back(sgTranspose.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {sgResults});
    return success();
  }
};

} // namespace

namespace mlir {
Expand All @@ -1233,7 +1297,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
WgToSgMultiDimReductionOp>(patterns.getContext());
WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
Expand Down Expand Up @@ -1360,7 +1425,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
vector::TransposeOp, vector::BroadcastOp,
vector::MultiDimReductionOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
Expand All @@ -1379,16 +1446,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
[=](vector::MultiDimReductionOp op) -> bool {
return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
Expand Down
42 changes: 22 additions & 20 deletions mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,16 @@ gpu.module @xevm_module{

// -----
// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C8]]
// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C8]]
// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C2]]
// CHECK: %[[REMU3:.*]] = index.remu %[[REMU2]], %[[C2]]
// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C8]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[REMU4]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[REMU4]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
Expand All @@ -288,19 +289,20 @@ gpu.module @xevm_module{

// -----
// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
// CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
// CHECK: %[[USER_OFFSET_X:.*]] = arith.constant 1 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
// CHECK: %[[LANE_X_OFFSET_1:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
// CHECK: %[[LANE_X_OFFSET:.*]] = index.add %[[LANE_X_OFFSET_1]], %[[USER_OFFSET_X]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C4]]
// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C4]]
// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C4]]
// CHECK: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2]]
// CHECK: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C8]]
// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C4]]
// CHECK: %[[ADD:.*]] = index.add %[[REMU4]], %[[C1]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[ADD]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
Expand Down
39 changes: 19 additions & 20 deletions mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
// RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s

//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test {
gpu.func @slice_attr() -> vector<128xindex> {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
//CHECK: [[c128:%.+]] = arith.constant 128 : index
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]]
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]]
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
// CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
// CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
// CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
gpu.return %step : vector<128xindex>
}

gpu.func @nested_slice_attr() -> vector<128xindex> {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
//CHECK: [[c128:%.+]] = arith.constant 128 : index
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]]
// CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]]
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]]
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
// CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
// CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
// CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
%0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
gpu.return %0 : vector<128xindex>
}

}
}

6 changes: 2 additions & 4 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,12 @@ gpu.module @test_elementwise_ops {
%load_b = xegpu.load_nd %tdesc_b
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-> vector<24x32xf32>
// CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-12: : vector<2x2xf32>
// CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
// CHECK-NOT: arith.negf
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
// CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-12: : vector<2x2xf32>
// CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
// CHECK-NOT: math.powf
%powf = math.powf %load_a, %load_b
{layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
Expand Down
Loading