Commit 06884d0

[mlir][xegpu] Bug fix in UpdateNdOffset distribution. (#150545)
The reason is that the UpdateNdOffset source operand does not retain its layout when it is yielded by the warp op. The `warp_execute_on_lane0` op expects the TensorDesc type to remain unchanged during distribution out of its region, so we use UnrealizedConversionCasts to reconcile this mismatch outside the warp op (via `resolveDistributedTy`).
1 parent d08c297 commit 06884d0
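
To illustrate the reconciliation mechanism the message refers to, here is a minimal C++ sketch of what a resolveDistributedTy-style helper can look like: it bridges the layout-carrying type produced inside the warp region with the layout-free type expected outside by inserting a builtin.unrealized_conversion_cast tagged with resolve_simt_type_mismatch (the attribute the new test checks for). This is an illustrative sketch, not the actual in-tree implementation; the helper name is made up.

// Illustrative sketch only (not the actual resolveDistributedTy from
// XeGPUSubgroupDistribute.cpp): reconcile a value whose type still carries
// the xegpu layout with the layout-free type expected outside the warp op by
// inserting a builtin.unrealized_conversion_cast and marking it.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

static mlir::Value reconcileDistributedType(mlir::Value produced,
                                            mlir::Type expectedTy,
                                            mlir::PatternRewriter &rewriter) {
  if (produced.getType() == expectedTy)
    return produced; // Nothing to reconcile.
  auto cast = mlir::UnrealizedConversionCastOp::create(
      rewriter, produced.getLoc(), expectedTy, produced);
  // Marker attribute matched by the new test case below.
  cast->setAttr("resolve_simt_type_mismatch", rewriter.getUnitAttr());
  return cast.getResult(0);
}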

File tree

2 files changed (+45, -41 lines)


mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 22 additions & 40 deletions
@@ -277,22 +277,13 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
           descOp, "the tensor descriptor lacks layout attribute");
 
     SmallVector<size_t> newRetIndices;
-    SmallVector<Value> newYieldValues;
-    SmallVector<Type> newYieldTypes;
-
-    for (Value operand : descOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      newYieldTypes.push_back(operand.getType());
-    }
     rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, /* new yieled values = */ newYieldValues,
-        /* new yielded types = */ newYieldTypes, newRetIndices);
+        rewriter, warpOp, /* new yieled values = */ descOp->getOperands(),
+        /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
 
-    SmallVector<Value> newDescOperands;
-    for (size_t i : newRetIndices) {
-      newDescOperands.push_back(newWarpOp.getResult(i));
-    }
+    SmallVector<Value> newDescOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::TensorDescType distributedTensorDescTy =
         descOp.getType().dropLayouts(); // Distributed tensor descriptor type
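
Besides yielding the operands directly, this hunk replaces a declare-then-push_back loop with llvm::map_to_vector (from llvm/ADT/SmallVectorExtras.h). A small standalone illustration of that idiom, with made-up container contents:

// Standalone illustration of the llvm::map_to_vector idiom used above; the
// values here are made up for the example.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::SmallVector<int> results = {10, 20, 30, 40};
  llvm::SmallVector<size_t> retIndices = {1, 3};

  // Build the result vector by mapping each index through a lambda, instead
  // of declaring an empty vector and push_back-ing inside a loop.
  auto picked =
      llvm::map_to_vector(retIndices, [&](size_t i) { return results[i]; });

  for (int v : picked)
    llvm::outs() << v << "\n"; // prints 20, then 40
  return 0;
}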
@@ -696,39 +687,30 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
           warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    // new update op does not have layout attribute.
-    xegpu::TensorDescType newTensorDescTy =
-        updateOp.getTensorDescType().dropLayouts();
 
-    SmallVector<Value, 3> newYieldValues;
-    SmallVector<Type, 3> newYieldTypes;
-    for (Value operand : updateOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      if (isa<xegpu::TensorDescType>(operand.getType())) {
-        newYieldTypes.push_back(newTensorDescTy);
-      } else {
-        newYieldTypes.push_back(operand.getType());
-      }
-    }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
+        newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
-    SmallVector<Value> newUpdateOperands;
-    for (size_t i : newRetIndices) {
-      // For the tensor descriptor operand, the layout attribute is dropped
-      // after distribution. Types needs to be resolved in this case.
-      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
-        newUpdateOperands.push_back(resolveDistributedTy(
-            newWarpOp.getResult(i), newTensorDescTy, rewriter));
-      } else {
-        newUpdateOperands.push_back(newWarpOp.getResult(i));
-      }
-    }
+    // new update op does not have layout attribute.
+    xegpu::TensorDescType distributedTensorDescTy =
+        updateOp.getTensorDescType().dropLayouts();
+    SmallVector<Value> newUpdateOperands =
+        llvm::map_to_vector(newRetIndices, [&](size_t i) {
+          // For the tensor descriptor operand, the layout attribute is
+          // dropped after distribution. Types needs to be resolved in this
+          // case.
+          if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+            return resolveDistributedTy(newWarpOp.getResult(i),
+                                        distributedTensorDescTy, rewriter);
+          }
+          return newWarpOp.getResult(i);
+        });
     // Create a new update op outside the warp op.
     auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
-        rewriter, newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        updateOp->getAttrs());
+        rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
+        newUpdateOperands, updateOp->getAttrs());
     xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
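
The trailing context line ("Resolve the distributed type with the original type.") points at the opposite-direction fix-up: the op rebuilt outside the warp op produces the layout-dropped distributed type, while existing users of the warp op's result still expect its original type, so the rebuilt value is cast back before replacing uses. A hedged sketch of that step, reusing the hypothetical reconcileDistributedType helper from the sketch above (names are assumptions, not the in-tree API):

// Hypothetical tail of such a pattern: cast the rebuilt value back to the
// type the original warp-op result had, then redirect all uses to it.
#include "mlir/IR/PatternMatch.h"

static void replaceWithReconciled(mlir::Value originalWarpResult,
                                  mlir::Value rebuiltValue,
                                  mlir::PatternRewriter &rewriter) {
  mlir::Value typed = reconcileDistributedType(
      rebuiltValue, originalWarpResult.getType(), rewriter);
  rewriter.replaceAllUsesWith(originalWarpResult, typed);
}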

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 23 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xegpu-subgroup-distribute -canonicalize -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xegpu-subgroup-distribute -allow-unregistered-dialect -canonicalize -cse -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
@@ -265,6 +265,28 @@ gpu.module @test {
   }
 }
 
+// -----
+// Explicitly check that update_nd_offset op's source retain layout when yielded from the warp op (PR150545)
+// CHECK-LABEL: gpu.func @check_update_nd_offset_distributed_tensor_desc
+// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] ->
+// CHECK-SAME: (!xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[T0:.*]] = "some_op"() : () -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: gpu.yield %[[T0]] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: }
+// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]] :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf32> {resolve_simt_type_mismatch}
+// CHECK: xegpu.update_nd_offset %[[T1]], [%{{.*}}] : !xegpu.tensor_desc<16x16xf32>
+gpu.module @test {
+  gpu.func @check_update_nd_offset_distributed_tensor_desc() {
+    %c32 = arith.constant 32 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<16x16xf32>
+    %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %cst, %1 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
 // -----
 // CHECK-LABEL: gpu.func @prefetch_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
