Skip to content

Commit 073bd22

Browse files
committed
fix bug
1 parent 581ba1c commit 073bd22

File tree

1 file changed

+22
-40
lines changed

1 file changed

+22
-40
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -277,22 +277,13 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
277277
descOp, "the tensor descriptor lacks layout attribute");
278278

279279
SmallVector<size_t> newRetIndices;
280-
SmallVector<Value> newYieldValues;
281-
SmallVector<Type> newYieldTypes;
282-
283-
for (Value operand : descOp->getOperands()) {
284-
newYieldValues.push_back(operand);
285-
newYieldTypes.push_back(operand.getType());
286-
}
287280
rewriter.setInsertionPoint(warpOp);
288281
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
289-
rewriter, warpOp, /* new yieled values = */ newYieldValues,
290-
/* new yielded types = */ newYieldTypes, newRetIndices);
282+
rewriter, warpOp, /* new yieled values = */ descOp->getOperands(),
283+
/* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
291284

292-
SmallVector<Value> newDescOperands;
293-
for (size_t i : newRetIndices) {
294-
newDescOperands.push_back(newWarpOp.getResult(i));
295-
}
285+
SmallVector<Value> newDescOperands = llvm::map_to_vector(
286+
newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
296287
rewriter.setInsertionPointAfter(newWarpOp);
297288
xegpu::TensorDescType distributedTensorDescTy =
298289
descOp.getType().dropLayouts(); // Distributed tensor descriptor type
@@ -696,39 +687,30 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
696687
warpOp, "warp result is not a xegpu::UpdateNdOffset op");
697688
auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
698689
unsigned operandIdx = operand->getOperandNumber();
699-
// new update op does not have layout attribute.
700-
xegpu::TensorDescType newTensorDescTy =
701-
updateOp.getTensorDescType().dropLayouts();
702690

703-
SmallVector<Value, 3> newYieldValues;
704-
SmallVector<Type, 3> newYieldTypes;
705-
for (Value operand : updateOp->getOperands()) {
706-
newYieldValues.push_back(operand);
707-
if (isa<xegpu::TensorDescType>(operand.getType())) {
708-
newYieldTypes.push_back(newTensorDescTy);
709-
} else {
710-
newYieldTypes.push_back(operand.getType());
711-
}
712-
}
713691
SmallVector<size_t> newRetIndices;
714692
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
715-
rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
693+
rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
694+
newRetIndices);
716695
rewriter.setInsertionPointAfter(newWarpOp);
717-
SmallVector<Value> newUpdateOperands;
718-
for (size_t i : newRetIndices) {
719-
// For the tensor descriptor operand, the layout attribute is dropped
720-
// after distribution. Types needs to be resolved in this case.
721-
if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
722-
newUpdateOperands.push_back(resolveDistributedTy(
723-
newWarpOp.getResult(i), newTensorDescTy, rewriter));
724-
} else {
725-
newUpdateOperands.push_back(newWarpOp.getResult(i));
726-
}
727-
}
696+
// new update op does not have layout attribute.
697+
xegpu::TensorDescType distributedTensorDescTy =
698+
updateOp.getTensorDescType().dropLayouts();
699+
SmallVector<Value> newUpdateOperands =
700+
llvm::map_to_vector(newRetIndices, [&](size_t i) {
701+
// For the tensor descriptor operand, the layout attribute is
702+
// dropped after distribution. Types needs to be resolved in this
703+
// case.
704+
if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
705+
return resolveDistributedTy(newWarpOp.getResult(i),
706+
distributedTensorDescTy, rewriter);
707+
}
708+
return newWarpOp.getResult(i);
709+
});
728710
// Create a new update op outside the warp op.
729711
auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
730-
rewriter, newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
731-
updateOp->getAttrs());
712+
rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
713+
newUpdateOperands, updateOp->getAttrs());
732714
xegpu::removeLayoutAttrs(newUpdateOp);
733715
Value distributedVal = newWarpOp.getResult(operandIdx);
734716
// Resolve the distributed type with the original type.

0 commit comments

Comments
 (0)