31
31
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
32
32
#include " llvm/ADT/STLExtras.h"
33
33
#include " llvm/ADT/TypeSwitch.h"
34
+ #include " llvm/Support/DebugLog.h"
34
35
35
36
#define DEBUG_TYPE " vector-to-gpu"
36
- #define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
37
- #define DBGSNL () (llvm::dbgs() << " \n " )
38
37
39
38
namespace mlir {
40
39
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
366
365
// by all operations.
367
366
if (llvm::any_of (dependentOps, [useNvGpu](Operation *op) {
368
367
if (!supportsMMaMatrixType (op, useNvGpu)) {
369
- LLVM_DEBUG ( DBGS ( ) << " cannot convert op: " << *op << " \n " ) ;
368
+ LDBG ( ) << " cannot convert op: " << *op;
370
369
return true ;
371
370
}
372
371
return false ;
@@ -548,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
548
547
std::optional<int64_t > stride =
549
548
getStaticallyKnownRowStride (op.getShapedType ());
550
549
if (!stride.has_value ()) {
551
- LLVM_DEBUG ( DBGS ( ) << " no stride\n " ) ;
550
+ LDBG ( ) << " no stride" ;
552
551
return rewriter.notifyMatchFailure (op, " no stride" );
553
552
}
554
553
@@ -583,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
583
582
isTranspose ? rewriter.getUnitAttr () : UnitAttr ());
584
583
valueMapping[mappingResult] = load;
585
584
586
- LLVM_DEBUG ( DBGS ( ) << " transfer read to: " << load << " \n " ) ;
585
+ LDBG ( ) << " transfer read to: " << load;
587
586
return success ();
588
587
}
589
588
@@ -597,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
597
596
std::optional<int64_t > stride =
598
597
getStaticallyKnownRowStride (op.getShapedType ());
599
598
if (!stride.has_value ()) {
600
- LLVM_DEBUG ( DBGS ( ) << " no stride\n " ) ;
599
+ LDBG ( ) << " no stride" ;
601
600
return rewriter.notifyMatchFailure (op, " no stride" );
602
601
}
603
602
604
603
auto it = valueMapping.find (op.getVector ());
605
604
if (it == valueMapping.end ()) {
606
- LLVM_DEBUG ( DBGS ( ) << " no mapping\n " ) ;
605
+ LDBG ( ) << " no mapping" ;
607
606
return rewriter.notifyMatchFailure (op, " no mapping" );
608
607
}
609
608
@@ -613,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
613
612
rewriter.getIndexAttr (*stride), /* transpose=*/ UnitAttr ());
614
613
(void )store;
615
614
616
- LLVM_DEBUG ( DBGS ( ) << " transfer write to: " << store << " \n " ) ;
615
+ LDBG ( ) << " transfer write to: " << store;
617
616
618
- LLVM_DEBUG ( DBGS ( ) << " erase: " << op << " \n " ) ;
617
+ LDBG ( ) << " erase: " << op;
619
618
rewriter.eraseOp (op);
620
619
return success ();
621
620
}
@@ -641,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
641
640
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
642
641
nvgpu::getWarpMatrixInfo (op);
643
642
if (failed (warpMatrixInfo)) {
644
- LLVM_DEBUG ( DBGS ( ) << " no warpMatrixInfo\n " ) ;
643
+ LDBG ( ) << " no warpMatrixInfo" ;
645
644
return rewriter.notifyMatchFailure (op, " no warpMatrixInfo" );
646
645
}
647
646
648
647
FailureOr<nvgpu::FragmentElementInfo> regInfo =
649
648
nvgpu::getMmaSyncRegisterType (*warpMatrixInfo);
650
649
if (failed (regInfo)) {
651
- LLVM_DEBUG ( DBGS ( ) << " not mma sync reg info\n " ) ;
650
+ LDBG ( ) << " not mma sync reg info" ;
652
651
return rewriter.notifyMatchFailure (op, " not mma sync reg info" );
653
652
}
654
653
655
654
VectorType vectorType = getMmaSyncVectorOperandType (*regInfo);
656
655
auto dense = dyn_cast<SplatElementsAttr>(op.getValue ());
657
656
if (!dense) {
658
- LLVM_DEBUG ( DBGS ( ) << " not a splat\n " ) ;
657
+ LDBG ( ) << " not a splat" ;
659
658
return rewriter.notifyMatchFailure (op, " not a splat" );
660
659
}
661
660
@@ -677,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
677
676
mlir::AffineMap map = op.getPermutationMap ();
678
677
679
678
if (map.getNumResults () != 2 ) {
680
- LLVM_DEBUG ( DBGS () << " Failed because the result of `vector.transfer_read` "
681
- " is not a 2d operand\n " ) ;
679
+ LDBG () << " Failed because the result of `vector.transfer_read` "
680
+ " is not a 2d operand" ;
682
681
return failure ();
683
682
}
684
683
@@ -691,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
691
690
auto exprN = dyn_cast<AffineDimExpr>(dN);
692
691
693
692
if (!exprM || !exprN) {
694
- LLVM_DEBUG ( DBGS () << " Failed because expressions are not affine dim "
695
- " expressions, then transpose cannot be determined.\n " ) ;
693
+ LDBG () << " Failed because expressions are not affine dim "
694
+ " expressions, then transpose cannot be determined." ;
696
695
return failure ();
697
696
}
698
697
@@ -709,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
709
708
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
710
709
nvgpu::getWarpMatrixInfo (op);
711
710
if (failed (warpMatrixInfo)) {
712
- LLVM_DEBUG ( DBGS ( ) << " no warpMatrixInfo\n " ) ;
711
+ LDBG ( ) << " no warpMatrixInfo" ;
713
712
return rewriter.notifyMatchFailure (op, " no warpMatrixInfo" );
714
713
}
715
714
716
715
FailureOr<nvgpu::FragmentElementInfo> regInfo =
717
716
nvgpu::getMmaSyncRegisterType (*warpMatrixInfo);
718
717
if (failed (regInfo)) {
719
- LLVM_DEBUG ( DBGS ( ) << " not mma sync reg info\n " ) ;
718
+ LDBG ( ) << " not mma sync reg info" ;
720
719
return rewriter.notifyMatchFailure (op, " not mma sync reg info" );
721
720
}
722
721
723
722
FailureOr<bool > transpose = isTransposed (op);
724
723
if (failed (transpose)) {
725
- LLVM_DEBUG ( DBGS ( ) << " failed to determine the transpose\n " ) ;
724
+ LDBG ( ) << " failed to determine the transpose" ;
726
725
return rewriter.notifyMatchFailure (
727
726
op, " Op should likely not be converted to a nvgpu.ldmatrix call." );
728
727
}
@@ -731,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
731
730
nvgpu::getLdMatrixParams (*warpMatrixInfo, *transpose);
732
731
733
732
if (failed (params)) {
734
- LLVM_DEBUG (
735
- DBGS ()
736
- << " failed to convert vector.transfer_read to ldmatrix. "
737
- << " Op should likely not be converted to a nvgpu.ldmatrix call.\n " );
733
+ LDBG () << " failed to convert vector.transfer_read to ldmatrix. "
734
+ << " Op should likely not be converted to a nvgpu.ldmatrix call." ;
738
735
return rewriter.notifyMatchFailure (
739
736
op, " failed to convert vector.transfer_read to ldmatrix; this op "
740
737
" likely should not be converted to a nvgpu.ldmatrix call." );
@@ -745,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
745
742
FailureOr<AffineMap> offsets =
746
743
nvgpu::getLaneIdToLdMatrixMatrixCoord (rewriter, loc, *params);
747
744
if (failed (offsets)) {
748
- LLVM_DEBUG ( DBGS ( ) << " no offsets\n " ) ;
745
+ LDBG ( ) << " no offsets" ;
749
746
return rewriter.notifyMatchFailure (op, " no offsets" );
750
747
}
751
748
@@ -934,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
934
931
vector::StoreOp::create (rewriter, loc, el, op.getBase (), newIndices);
935
932
}
936
933
937
- LLVM_DEBUG ( DBGS ( ) << " erase: " << op << " \n " ) ;
934
+ LDBG ( ) << " erase: " << op;
938
935
rewriter.eraseOp (op);
939
936
return success ();
940
937
}
@@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1132
1129
loop.getNumResults ())))
1133
1130
rewriter.replaceAllUsesWith (std::get<0 >(it), std::get<1 >(it));
1134
1131
1135
- LLVM_DEBUG ( DBGS ( ) << " newLoop now: " << newLoop << " \n " ) ;
1136
- LLVM_DEBUG ( DBGS ( ) << " stripped scf.for: " << loop << " \n " ) ;
1137
- LLVM_DEBUG ( DBGS ( ) << " erase: " << loop) ;
1132
+ LDBG ( ) << " newLoop now: " << newLoop;
1133
+ LDBG ( ) << " stripped scf.for: " << loop;
1134
+ LDBG ( ) << " erase: " << loop;
1138
1135
1139
1136
rewriter.eraseOp (loop);
1140
1137
return newLoop;
@@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1150
1147
for (const auto &operand : llvm::enumerate (op.getInitArgs ())) {
1151
1148
auto it = valueMapping.find (operand.value ());
1152
1149
if (it == valueMapping.end ()) {
1153
- LLVM_DEBUG ( DBGS ( ) << " no value mapping for: " << operand.value () << " \n " );
1150
+ LDBG ( ) << " no value mapping for: " << operand.value ();
1154
1151
continue ;
1155
1152
}
1156
1153
argMapping.push_back (std::make_pair (
@@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1168
1165
loopBody.getArgument (mapping.second + newForOp.getNumInductionVars ());
1169
1166
}
1170
1167
1171
- LLVM_DEBUG ( DBGS ( ) << " scf.for to: " << newForOp << " \n " ) ;
1168
+ LDBG ( ) << " scf.for to: " << newForOp;
1172
1169
return success ();
1173
1170
}
1174
1171
@@ -1191,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1191
1188
}
1192
1189
scf::YieldOp::create (rewriter, op.getLoc (), yieldOperands);
1193
1190
1194
- LLVM_DEBUG ( DBGS ( ) << " erase: " << op << " \n " ) ;
1191
+ LDBG ( ) << " erase: " << op;
1195
1192
rewriter.eraseOp (op);
1196
1193
return success ();
1197
1194
}
@@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1244
1241
1245
1242
auto globalRes = LogicalResult::success ();
1246
1243
for (Operation *op : ops) {
1247
- LLVM_DEBUG ( DBGS ( ) << " Process op: " << *op << " \n " ) ;
1244
+ LDBG ( ) << " Process op: " << *op;
1248
1245
// Apparently callers do not want to early exit on failure here.
1249
1246
auto res = LogicalResult::success ();
1250
1247
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
0 commit comments