Skip to content

Commit 6425961

Browse files
committed
add support for nD
1 parent b3ba670 commit 6425961

File tree

2 files changed

+16
-25
lines changed

2 files changed

+16
-25
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
7777
const SmallVector<OpFoldResult> &originalOffsets,
7878
const SmallVector<Value> &localOffset,
7979
const SmallVector<int64_t> &distUnitBaseAddr) const {
80-
81-
Value constOffsetX =
82-
rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[0]);
83-
Value constOffsetY =
84-
rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[1]);
85-
86-
Value offsetX =
87-
rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
88-
Value offsetY =
89-
rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
90-
91-
size_t lastDimIndex = originalOffsets.size() - 1;
92-
size_t secondLastDimIndex = lastDimIndex - 1;
80+
assert(localOffset.size() == distUnitBaseAddr.size() &&
81+
"localOffset and distUnitBaseAddr must have the same rank");
9382

9483
// Convert originalOffsets to Value
9584
auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
@@ -102,18 +91,20 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
10291
llvm_unreachable("Unsupported OpFoldResult kind");
10392
};
10493

105-
Value origOffsetX =
106-
getValueFromOpFoldResult(originalOffsets[secondLastDimIndex]);
107-
Value origOffsetY = getValueFromOpFoldResult(originalOffsets[lastDimIndex]);
108-
Value globalOffsetX =
109-
rewriter.createOrFold<index::AddOp>(loc, origOffsetX, offsetX);
110-
Value globalOffsetY =
111-
rewriter.createOrFold<index::AddOp>(loc, origOffsetY, offsetY);
112-
11394
SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
11495
originalOffsets.end());
115-
globalOffsets[secondLastDimIndex] = globalOffsetX;
116-
globalOffsets[lastDimIndex] = globalOffsetY;
96+
size_t rank = localOffset.size();
97+
for (size_t i = 0; i < rank; ++i) {
98+
size_t dimIdx = originalOffsets.size() - rank + i;
99+
Value constOffset =
100+
rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
101+
Value offset =
102+
rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
103+
Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
104+
Value globalOffset =
105+
rewriter.createOrFold<index::AddOp>(loc, origOffset, offset);
106+
globalOffsets[dimIdx] = globalOffset;
107+
}
117108

118109
return globalOffsets;
119110
}
@@ -283,7 +274,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
283274
tmpC = rewriter.create<xegpu::DpasOp>(
284275
loc, resTy, operands,
285276
llvm::ArrayRef<NamedAttribute>(
286-
{"layout", originalLayout.dropSgLayoutAndData()}));
277+
{"layout_result_0", originalLayout.dropSgLayoutAndData()}));
287278
newDpasOps.push_back(tmpC);
288279
}
289280
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
9090
// CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
9191
// CHECK-SAME: -> vector<8x12xf32>
9292
// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
93-
// CHECK-SAME: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
93+
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
9494
// CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
9595
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
9696
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>

0 commit comments

Comments
 (0)