2626#include " mlir/IR/Value.h"
2727#include " mlir/Pass/Pass.h"
2828#include " llvm/Support/Debug.h"
29+ #include " llvm/Support/DebugLog.h"
2930#include " llvm/Support/ErrorHandling.h"
3031#include " llvm/Support/raw_ostream.h"
3132#include < optional>
3233
3334#define DEBUG_TYPE " nvgpu-to-nvvm"
34- #define DBGS () (llvm::dbgs() << ' [' << DEBUG_TYPE << " ] " )
35- #define DBGSE () (llvm::dbgs())
3635
3736namespace mlir {
3837#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
11051104 // // [0,14) start_address
11061105 dsc = insertBit (dsc, basePtr14bit, startBaseAddrBit);
11071106
1108- LLVM_DEBUG ( DBGS () << " Generating warpgroup.descriptor: "
1109- << " leading_off:" << leadDimVal << " \t "
1110- << " stride_off :" << strideDimVal << " \t "
1111- << " base_offset:" << offsetVal << " \t "
1112- << " layout_type:" << swizzle << " ("
1113- << nvgpu::stringifyTensorMapSwizzleKind (swizzleKind)
1114- << " )\n start_addr : " << baseAddr << " \n " ) ;
1107+ LDBG () << " Generating warpgroup.descriptor: "
1108+ << " leading_off:" << leadDimVal << " \t "
1109+ << " stride_off :" << strideDimVal << " \t "
1110+ << " base_offset:" << offsetVal << " \t "
1111+ << " layout_type:" << swizzle << " ("
1112+ << nvgpu::stringifyTensorMapSwizzleKind (swizzleKind)
1113+ << " )\n start_addr : " << baseAddr;
11151114
11161115 rewriter.replaceOp (op, dsc);
11171116 return success ();
@@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
12811280 } else {
12821281 llvm_unreachable (" msg: not supported K shape" );
12831282 }
1284- LLVM_DEBUG ( DBGS () << " Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1285- << " , n = " << wgmmaN << " , k = " << wgmmaK << " ]\n " ) ;
1283+ LDBG () << " Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1284+ << " , n = " << wgmmaN << " , k = " << wgmmaK << " ]" ;
12861285 }
12871286
12881287 // / Generates WGMMATypesAttr from MLIR Type
@@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
13661365 int tileShapeA = matrixTypeA.getDimSize (1 );
13671366 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
13681367 incrementVal = incrementVal >> exclude4LSB;
1369- LLVM_DEBUG ( DBGS () << " \t\t [m: " << i << " n: " << j << " k: " << k
1370- << " ] [wgmma descriptors] Descriptor A + "
1371- << incrementVal << " | \t " ) ;
1368+ LDBG () << " \t\t [m: " << i << " n: " << j << " k: " << k
1369+ << " ] [wgmma descriptors] Descriptor A + " << incrementVal
1370+ << " | \t " ;
13721371 if (!incrementVal)
13731372 return desc;
13741373 return makeAdd (desc, makeI64Const (b, incrementVal));
@@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
13911390 int byte = elemB.getIntOrFloatBitWidth () / 8 ;
13921391 int incrementVal = matrixTypeB.getDimSize (0 ) * wgmmaK * k * byte;
13931392 incrementVal = incrementVal >> exclude4LSB;
1394- LLVM_DEBUG ( DBGSE ( ) << " Descriptor B + " << incrementVal << " \n " ) ;
1393+ LDBG ( ) << " Descriptor B + " << incrementVal;
13951394 if (!incrementVal)
13961395 return desc;
13971396 return makeAdd (desc, makeI64Const (b, incrementVal));
@@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
14001399 // / This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
14011400 // / descriptors and arranges them based on induction variables: i, j, and k.
14021401 Value generateWgmma (int i, int j, int k, Value matrixC) {
1403- LLVM_DEBUG (DBGS () << " \t wgmma."
1404- << " m" << wgmmaM << " n" << wgmmaN << " k" << wgmmaK
1405- << " (A[" << (iterationM * wgmmaM) << " :"
1406- << (iterationM * wgmmaM) + wgmmaM << " ]["
1407- << (iterationK * wgmmaK) << " :"
1408- << (iterationK * wgmmaK + wgmmaK) << " ] * "
1409- << " B[" << (iterationK * wgmmaK) << " :"
1410- << (iterationK * wgmmaK + wgmmaK) << " ][" << 0 << " :"
1411- << wgmmaN << " ])\n " );
1402+ LDBG () << " \t wgmma."
1403+ << " m" << wgmmaM << " n" << wgmmaN << " k" << wgmmaK << " (A["
1404+ << (iterationM * wgmmaM) << " :" << (iterationM * wgmmaM) + wgmmaM
1405+ << " ][" << (iterationK * wgmmaK) << " :"
1406+ << (iterationK * wgmmaK + wgmmaK) << " ] * "
1407+ << " B[" << (iterationK * wgmmaK) << " :"
1408+ << (iterationK * wgmmaK + wgmmaK) << " ][" << 0 << " :" << wgmmaN
1409+ << " ])" ;
14121410
14131411 Value descriptorA = iterateDescriptorA (adaptor.getDescriptorA (), i, j, k);
14141412 Value descriptorB = iterateDescriptorB (adaptor.getDescriptorB (), i, j, k);
@@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
14671465 totalM = op.getDescriptorA ().getType ().getTensor ().getDimSize (0 );
14681466 totalN = op.getDescriptorB ().getType ().getTensor ().getDimSize (1 );
14691467 totalK = op.getDescriptorA ().getType ().getTensor ().getDimSize (1 );
1470- LLVM_DEBUG ( DBGS ( ) << " ===--- GEMM D[" << totalM << " ][" << totalN
1471- << " ] += A [" << totalM << " ][" << totalK << " ] * B[ "
1472- << totalK << " ][ " << totalN << " ] ---===\n " ) ;
1468+ LDBG ( ) << " ===--- GEMM D[" << totalM << " ][" << totalN << " ] += A[ "
1469+ << totalM << " ][" << totalK << " ] * B [" << totalK << " ][ " << totalN
1470+ << " ] ---===" ;
14731471
14741472 // Find the shape for one wgmma instruction
14751473 findWgmmaShape (
0 commit comments