@@ -277,22 +277,13 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
           descOp, "the tensor descriptor lacks layout attribute");
 
     SmallVector<size_t> newRetIndices;
-    SmallVector<Value> newYieldValues;
-    SmallVector<Type> newYieldTypes;
-
-    for (Value operand : descOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      newYieldTypes.push_back(operand.getType());
-    }
     rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, /* new yielded values = */ newYieldValues,
-        /* new yielded types = */ newYieldTypes, newRetIndices);
+        rewriter, warpOp, /* new yielded values = */ descOp->getOperands(),
+        /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);
 
-    SmallVector<Value> newDescOperands;
-    for (size_t i : newRetIndices) {
-      newDescOperands.push_back(newWarpOp.getResult(i));
-    }
+    SmallVector<Value> newDescOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::TensorDescType distributedTensorDescTy =
         descOp.getType().dropLayouts(); // Distributed tensor descriptor type
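Both hunks in this commit lean on llvm::map_to_vector (from llvm/ADT/SmallVectorExtras.h) to replace an explicit loop-plus-push_back with a single expression that maps each returned index to the corresponding warp-op result. The following is a minimal standalone sketch of that helper using plain integers instead of MLIR values; the names and numbers are illustrative only and are not taken from the patch.

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <cstddef>
#include <cstdio>

int main() {
  // Stand-in for newRetIndices: indices of the values appended to the warp op.
  llvm::SmallVector<size_t> retIndices = {0, 2, 4};
  // map_to_vector applies the lambda to every element of the range and
  // collects the results into a SmallVector, mirroring the
  // "newWarpOp.getResult(i)" mapping in the patterns above.
  auto results =
      llvm::map_to_vector(retIndices, [](size_t i) { return i * 10; });
  for (size_t v : results)
    std::printf("%zu\n", v); // prints 0, 20, 40
  return 0;
}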
@@ -696,39 +687,30 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
           warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    // new update op does not have layout attribute.
-    xegpu::TensorDescType newTensorDescTy =
-        updateOp.getTensorDescType().dropLayouts();
 
-    SmallVector<Value, 3> newYieldValues;
-    SmallVector<Type, 3> newYieldTypes;
-    for (Value operand : updateOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      if (isa<xegpu::TensorDescType>(operand.getType())) {
-        newYieldTypes.push_back(newTensorDescTy);
-      } else {
-        newYieldTypes.push_back(operand.getType());
-      }
-    }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
+        newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
-    SmallVector<Value> newUpdateOperands;
-    for (size_t i : newRetIndices) {
-      // For the tensor descriptor operand, the layout attribute is dropped
-      // after distribution. Types need to be resolved in this case.
-      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
-        newUpdateOperands.push_back(resolveDistributedTy(
-            newWarpOp.getResult(i), newTensorDescTy, rewriter));
-      } else {
-        newUpdateOperands.push_back(newWarpOp.getResult(i));
-      }
-    }
+    // new update op does not have layout attribute.
+    xegpu::TensorDescType distributedTensorDescTy =
+        updateOp.getTensorDescType().dropLayouts();
+    SmallVector<Value> newUpdateOperands =
+        llvm::map_to_vector(newRetIndices, [&](size_t i) {
+          // For the tensor descriptor operand, the layout attribute is
+          // dropped after distribution. Types need to be resolved in this
+          // case.
+          if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+            return resolveDistributedTy(newWarpOp.getResult(i),
+                                        distributedTensorDescTy, rewriter);
+          }
+          return newWarpOp.getResult(i);
+        });
     // Create a new update op outside the warp op.
     auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
-        rewriter, newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        updateOp->getAttrs());
+        rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
+        newUpdateOperands, updateOp->getAttrs());
     xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
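The lambda in the hunk above falls back to resolveDistributedTy(...) when the warp op still yields the layouted tensor descriptor type, because the distributed descriptor type has dropped its layout attribute and the new UpdateNdOffsetOp expects the layout-free type. As a rough sketch of that reconciliation step only: the helper below is hypothetical (named castToDistributedType here; it is not the file's resolveDistributedTy, whose implementation is not shown in this diff) and assumes an unrealized_conversion_cast is an acceptable bridge between the two types.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical sketch, not the patch's resolveDistributedTy: reconcile a value
// whose type still carries extra information (e.g. a layout attribute) with
// the type the consumer expects, assuming an unrealized_conversion_cast is
// used as the bridge and folded away by later patterns.
static mlir::Value castToDistributedType(mlir::Value orig, mlir::Type wantTy,
                                         mlir::PatternRewriter &rewriter) {
  if (orig.getType() == wantTy)
    return orig; // Types already agree; nothing to resolve.
  auto cast = mlir::UnrealizedConversionCastOp::create(rewriter, orig.getLoc(),
                                                       wantTy, orig);
  return cast.getResult(0);
}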