@@ -135,19 +135,6 @@ static Value resolveDistributedTy(Value orig, T expected,
   return orig;
 }
 
-/// Helper function to filter out the temporary layout attributes attached
-/// during the layout assignment process. These are not needed after going to
-/// SIMT.
-static SmallVector<NamedAttribute>
-removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
-  SmallVector<NamedAttribute> newAttrs;
-  for (NamedAttribute attr : attrs) {
-    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
-      newAttrs.push_back(attr);
-  }
-  return newAttrs;
-}
-
 /// Helper function to check if the layout is packed. Layout is packed if it is
 /// 2D and lane_data[0] != 1 (data packed from col dimension).
 static bool hasPackedLayout(xegpu::LayoutAttr layout) {
@@ -197,9 +184,17 @@ struct MoveFuncBodyToWarpExecuteOnLane0
           return isa<gpu::WarpExecuteOnLane0Op>(op);
         }))
       return failure();
-    // Create a new function with the same signature.
+    // Create a new function with the same signature and same attributes.
+    SmallVector<Type> workgroupAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getWorkgroupAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
+    SmallVector<Type> privateAttributionsTypes =
+        llvm::map_to_vector(gpuFuncOp.getPrivateAttributions(),
+                            [](BlockArgument arg) { return arg.getType(); });
     auto newGpuFunc = rewriter.create<gpu::GPUFuncOp>(
-        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType());
+        gpuFuncOp.getLoc(), gpuFuncOp.getName(), gpuFuncOp.getFunctionType(),
+        workgroupAttributionsTypes, privateAttributionsTypes);
+    newGpuFunc->setAttrs(gpuFuncOp->getAttrs());
     // Create a WarpExecuteOnLane0Op with same arguments and results as the
     // original gpuFuncOp.
     rewriter.setInsertionPointToEnd(&newGpuFunc.getFunctionBody().front());
@@ -265,13 +260,13 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// ```
 struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateNdDescOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::CreateNdDesc op");
+          warpOp, "warp result is not a xegpu::CreateNdDesc op");
     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
     unsigned operandIdx = operand->getOperandNumber();
@@ -288,9 +283,9 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
       newYieldValues.push_back(operand);
       newYieldTypes.push_back(operand.getType());
     }
-    rewriter.setInsertionPoint(subgroupOp);
+    rewriter.setInsertionPoint(warpOp);
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, /* new yielded values = */ newYieldValues,
+        rewriter, warpOp, /* new yielded values = */ newYieldValues,
         /* new yielded types = */ newYieldTypes, newRetIndices);
 
     SmallVector<Value> newDescOperands;
@@ -347,10 +342,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
     if (!storeOp)
@@ -372,7 +367,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */
         ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
         /* new yielded types = */
@@ -403,9 +398,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
         resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                              distributedTensorDescTy, rewriter));
 
-    rewriter.create<xegpu::StoreNdOp>(
-        newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
-        removeTemporaryLayoutAttributes(storeOp->getAttrs()));
+    auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
+    xegpu::removeLayoutAttrs(newStoreOp);
     rewriter.eraseOp(storeOp);
     return success();
   }
@@ -449,21 +444,22 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::LoadNdOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadNdOp>(op))
+        return false;
+      // Make sure the same load op is the last operation in the warp op body.
+      // This ensures that the load op is not sunk earlier, violating any
+      // barrier synchronizations.
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::LoadNd op");
-    // Make sure the load op is the last operation in the warp op body. This
-    // ensure that load op is not sinked earlier violating any barrier
-    // synchronizations.
-    auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
-    if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
-      return failure();
+          warpOp, "warp result is not a xegpu::LoadNd op");
 
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
@@ -474,11 +470,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 
     unsigned operandIdx = operand->getOperandNumber();
     VectorType distributedTypeByWarpOp =
-        cast<VectorType>(subgroupOp.getResult(operandIdx).getType());
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp,
+        rewriter, warpOp,
         /* new yielded values = */ loadOp.getTensorDesc(),
         /* new yielded types = */ tensorDescTy, newRetIndices);
 
@@ -498,7 +494,8 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
         newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
         resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
                              distributedTensorDescTy, rewriter),
-        removeTemporaryLayoutAttributes(loadOp->getAttrs()));
+        loadOp->getAttrs());
+    xegpu::removeLayoutAttrs(newLoadOp);
     // Set the packed attribute if the layout requires it.
     newLoadOp.setPacked(hasPackedLayout(layout));
     Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -548,12 +545,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::DpasOp>);
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<xegpu::DpasOp>);
     if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
+      return rewriter.notifyMatchFailure(warpOp,
                                          "warp result is not a xegpu::Dpas op");
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
@@ -599,7 +595,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
     // Create a new warp op without the dpas.
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
 
     FailureOr<VectorType> expectedDistLhsTyOrFailure =
         xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
@@ -630,14 +626,16 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
           resolveDistributedTy(newWarpOp.getResult(newRetIndices[i]),
                                newDpasOperandExpectedTypes[i], rewriter));
     }
-    Value newDpasOp = rewriter.create<xegpu::DpasOp>(
-        newWarpOp->getLoc(), distributedResultTy, newDpasOperands,
-        removeTemporaryLayoutAttributes(dpasOp->getAttrs()));
+    auto newDpasOp =
+        rewriter.create<xegpu::DpasOp>(newWarpOp->getLoc(), distributedResultTy,
+                                       newDpasOperands, dpasOp->getAttrs());
+    xegpu::removeLayoutAttrs(newDpasOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the output type.
-    newDpasOp = resolveDistributedTy(
-        newDpasOp, distResultTypeByWarpOpOrFailure.value(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
+    Value typeResolved =
+        resolveDistributedTy(newDpasOp.getResult(),
+                             distResultTypeByWarpOpOrFailure.value(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -678,13 +676,13 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
-        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(
-          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
     // new update op does not have layout attribute.
@@ -703,7 +701,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
     }
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+        rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
     for (size_t i : newRetIndices) {
@@ -717,14 +715,15 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
       }
     }
     // Create a new update op outside the warp op.
-    Value newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
-        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+        updateOp->getAttrs());
+    xegpu::removeLayoutAttrs(newUpdateOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the distributed type with the original type.
-    newUpdateOp =
-        resolveDistributedTy(newUpdateOp, distributedVal.getType(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    Value typeResolved = resolveDistributedTy(
+        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
+    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
     return success();
   }
 };
@@ -758,10 +757,10 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
 /// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
775
774
SmallVector<Type, 1 > newYieldTypes = {prefetchOp.getTensorDescType ()};
776
775
SmallVector<size_t > newRetIndices;
777
776
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
778
- rewriter, subgroupOp , newYieldValues, newYieldTypes, newRetIndices);
777
+ rewriter, warpOp , newYieldValues, newYieldTypes, newRetIndices);
779
778
// Create a new prefetch op outside the warp op with updated tensor
780
779
// descriptor type. Source tensor descriptor require type resolution.
781
780
xegpu::TensorDescType newTensorDescTy =
782
781
prefetchOp.getTensorDescType ().dropLayouts ();
783
782
rewriter.setInsertionPointAfter (newWarpOp);
784
783
SmallVector<Value> newPrefetchOperands = {resolveDistributedTy (
785
784
newWarpOp.getResult (newRetIndices[0 ]), newTensorDescTy, rewriter)};
786
- rewriter.create <xegpu::PrefetchNdOp>(
787
- newWarpOp.getLoc (), TypeRange{}, newPrefetchOperands,
788
- removeTemporaryLayoutAttributes (prefetchOp->getAttrs ()));
785
+ rewriter.create <xegpu::PrefetchNdOp>(newWarpOp.getLoc (), TypeRange{},
786
+ newPrefetchOperands,
787
+ prefetchOp->getAttrs ());
788
+ xegpu::removeLayoutAttrs (prefetchOp);
789
789
rewriter.eraseOp (prefetchOp);
790
790
return success ();
791
791
}
@@ -795,17 +795,17 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
 /// region. This will simply move the barrier op outside of the warp op.
 struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
-        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     Operation *lastNode = yield->getPrevNode();
     // The last node must be a gpu::BarrierOp.
    auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
     if (!barrierOp)
       return failure();
     // Move the barrier op outside of the warp op.
-    rewriter.setInsertionPointAfter(subgroupOp);
+    rewriter.setInsertionPointAfter(warpOp);
     rewriter.create<gpu::BarrierOp>(
         barrierOp.getLoc(), barrierOp->getResultTypes(),
         barrierOp->getOperands(), barrierOp->getAttrs());
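Taken together, the hunks above repeat one refactor: the file-local removeTemporaryLayoutAttributes helper (deleted in the first hunk) is replaced by the shared xegpu::removeLayoutAttrs utility, so each distributed op is now created with the original op's attributes copied verbatim and the layout attributes are stripped from the new op afterwards. Below is a minimal sketch of that before/after shape, using the StoreNd case; it only reuses names and calls that appear in this diff and is illustrative rather than a verbatim copy of the surrounding file.

// Before: filter xegpu::LayoutAttr entries out of the attribute list up front,
// then create the distributed op with the filtered attributes.
rewriter.create<xegpu::StoreNdOp>(
    newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
    removeTemporaryLayoutAttributes(storeOp->getAttrs()));

// After: copy all attributes onto the new op, then strip the temporary layout
// attributes from the new op in place with the shared utility.
auto newStoreOp = rewriter.create<xegpu::StoreNdOp>(
    newWarpOp.getLoc(), TypeRange{}, newStoreOperands, storeOp->getAttrs());
xegpu::removeLayoutAttrs(newStoreOp);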