@@ -277,22 +277,13 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
           descOp, "the tensor descriptor lacks layout attribute");

     SmallVector<size_t> newRetIndices;
-    SmallVector<Value> newYieldValues;
-    SmallVector<Type> newYieldTypes;
-
-    for (Value operand : descOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      newYieldTypes.push_back(operand.getType());
-    }
     rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, /* new yieled values = */ newYieldValues,
-        /* new yielded types = */ newYieldTypes, newRetIndices);
+        rewriter, warpOp, /* new yieled values = */ descOp->getOperands(),
+        /* new yielded types = */ descOp.getOperandTypes(), newRetIndices);

-    SmallVector<Value> newDescOperands;
-    for (size_t i : newRetIndices) {
-      newDescOperands.push_back(newWarpOp.getResult(i));
-    }
+    SmallVector<Value> newDescOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t i) { return newWarpOp.getResult(i); });
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::TensorDescType distributedTensorDescTy =
         descOp.getType().dropLayouts(); // Distributed tensor descriptor type
@@ -696,39 +687,30 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
           warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    // new update op does not have layout attribute.
-    xegpu::TensorDescType newTensorDescTy =
-        updateOp.getTensorDescType().dropLayouts();

-    SmallVector<Value, 3> newYieldValues;
-    SmallVector<Type, 3> newYieldTypes;
-    for (Value operand : updateOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      if (isa<xegpu::TensorDescType>(operand.getType())) {
-        newYieldTypes.push_back(newTensorDescTy);
-      } else {
-        newYieldTypes.push_back(operand.getType());
-      }
-    }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
+        newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
-    SmallVector<Value> newUpdateOperands;
-    for (size_t i : newRetIndices) {
-      // For the tensor descriptor operand, the layout attribute is dropped
-      // after distribution. Types needs to be resolved in this case.
-      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
-        newUpdateOperands.push_back(resolveDistributedTy(
-            newWarpOp.getResult(i), newTensorDescTy, rewriter));
-      } else {
-        newUpdateOperands.push_back(newWarpOp.getResult(i));
-      }
-    }
+    // new update op does not have layout attribute.
+    xegpu::TensorDescType distributedTensorDescTy =
+        updateOp.getTensorDescType().dropLayouts();
+    SmallVector<Value> newUpdateOperands =
+        llvm::map_to_vector(newRetIndices, [&](size_t i) {
+          // For the tensor descriptor operand, the layout attribute is
+          // dropped after distribution. Types needs to be resolved in this
+          // case.
+          if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+            return resolveDistributedTy(newWarpOp.getResult(i),
+                                        distributedTensorDescTy, rewriter);
+          }
+          return newWarpOp.getResult(i);
+        });
     // Create a new update op outside the warp op.
     auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
-        rewriter, newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        updateOp->getAttrs());
+        rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
+        newUpdateOperands, updateOp->getAttrs());
     xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
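Aside (not part of the patch): both hunks replace explicit push_back loops with llvm::map_to_vector. A minimal standalone sketch of that idiom, assuming the usual LLVM ADT headers (SmallVectorExtras.h provides map_to_vector):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h" // llvm::map_to_vector

// Maps each element through the lambda and collects the results into a
// SmallVector, i.e. the same effect as an explicit push_back loop.
llvm::SmallVector<int> doubled(llvm::ArrayRef<int> xs) {
  return llvm::map_to_vector(xs, [](int x) { return x * 2; });
}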