3333#include " llvm/AsmParser/Parser.h"
3434#include " llvm/IR/Attributes.h"
3535#include " llvm/IR/Function.h"
36+ #include " llvm/IR/IntrinsicsNVPTX.h"
3637#include " llvm/IR/Type.h"
3738#include " llvm/Support/Casting.h"
3839#include " llvm/Support/FormatVariadic.h"
@@ -56,7 +57,7 @@ using namespace NVVM;
5657// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
5758// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
5859// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
59- static LogicalResult CpAsyncBulkTensorCommonVerifier (size_t tensorDims,
60+ static LogicalResult cpAsyncBulkTensorCommonVerifier (size_t tensorDims,
6061 bool isIm2Col,
6162 size_t numIm2ColOffsets,
6263 Location loc) {
@@ -81,7 +82,7 @@ static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
8182LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify () {
8283 size_t numIm2ColOffsets = getIm2colOffsets ().size ();
8384 bool isIm2Col = numIm2ColOffsets > 0 ;
84- return CpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
85+ return cpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
8586 numIm2ColOffsets, getLoc ());
8687}
8788
@@ -105,13 +106,13 @@ LogicalResult CpAsyncOp::verify() {
105106LogicalResult CpAsyncBulkTensorPrefetchOp::verify () {
106107 size_t numIm2ColOffsets = getIm2colOffsets ().size ();
107108 bool isIm2Col = numIm2ColOffsets > 0 ;
108- return CpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
109+ return cpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col,
109110 numIm2ColOffsets, getLoc ());
110111}
111112
112113LogicalResult CpAsyncBulkTensorReduceOp::verify () {
113114 bool isIm2Col = (getMode () == TMAStoreMode::IM2COL);
114- return CpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col, 0 ,
115+ return cpAsyncBulkTensorCommonVerifier (getCoordinates ().size (), isIm2Col, 0 ,
115116 getLoc ());
116117}
117118
@@ -183,14 +184,14 @@ static bool isIntegerPtxType(MMATypes type) {
183184
184185MMATypes MmaOp::accumPtxType () {
185186 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType (
186- getODSOperands (2 ).getTypes ().front (), /* isAccum =*/ true );
187+ getODSOperands (2 ).getTypes ().front (), /* isAccumulator =*/ true );
187188 assert (val.has_value () && " accumulator PTX type should always be inferrable" );
188189 return val.value ();
189190}
190191
191192MMATypes MmaOp::resultPtxType () {
192193 std::optional<mlir::NVVM::MMATypes> val =
193- inferOperandMMAType (getResult ().getType (), /* isAccum =*/ true );
194+ inferOperandMMAType (getResult ().getType (), /* isAccumulator =*/ true );
194195 assert (val.has_value () && " result PTX type should always be inferrable" );
195196 return val.value ();
196197}
@@ -224,7 +225,7 @@ void MmaOp::print(OpAsmPrinter &p) {
224225 }
225226 }
226227 std::optional<MMATypes> inferredType =
227- inferOperandMMAType (regTypes.back (), /* isAccum =*/ fragIdx >= 2 );
228+ inferOperandMMAType (regTypes.back (), /* isAccumulator =*/ fragIdx >= 2 );
228229 if (inferredType)
229230 ignoreAttrNames.push_back (frag.ptxTypeAttr );
230231 }
@@ -364,14 +365,14 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
364365 if (failed (parser.resolveOperands (frag.regs , frag.regTypes ,
365366 parser.getNameLoc (), result.operands )))
366367 return failure ();
367- frag.elemtype =
368- inferOperandMMAType (frag. regTypes [ 0 ], /* isAccum= */ iter.index () < 2 );
368+ frag.elemtype = inferOperandMMAType (frag. regTypes [ 0 ],
369+ /* isAccumulator */ iter.index () < 2 );
369370 }
370371
371372 Type resultType;
372373 if (parser.parseArrow () || parser.parseType (resultType))
373374 return failure ();
374- frags[3 ].elemtype = inferOperandMMAType (resultType, /* isAccum= */ true );
375+ frags[3 ].elemtype = inferOperandMMAType (resultType, /* isAccumulator */ true );
375376
376377 std::array<StringRef, 2 > names{" multiplicandAPtxType" ,
377378 " multiplicandBPtxType" };
@@ -1121,9 +1122,9 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
11211122
11221123LogicalResult NVVM::MatchSyncOp::verify () {
11231124 if (getKind () == NVVM::MatchSyncKind::all) {
1124- auto Type = llvm::dyn_cast<LLVM::LLVMStructType>(getType ());
1125- if (!Type || Type .getBody ().size () != 2 ||
1126- !Type .getBody ()[0 ].isInteger (32 ) || !Type .getBody ()[1 ].isInteger (1 )) {
1125+ auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType ());
1126+ if (!type || type .getBody ().size () != 2 ||
1127+ !type .getBody ()[0 ].isInteger (32 ) || !type .getBody ()[1 ].isInteger (1 )) {
11271128 return emitOpError (" match.sync 'all' returns a two element struct with "
11281129 " first element as i32 and second element as i1" );
11291130 }
@@ -1164,7 +1165,7 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
11641165 llvm::Intrinsic::ID id;
11651166
11661167 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1167- bool hasCpSize = cpAsyncOp.getCpSize () ? true : false ;
1168+ bool hasCpSize = static_cast < bool >( cpAsyncOp.getCpSize ()) ;
11681169 switch (cpAsyncOp.getSize ()) {
11691170 case 4 :
11701171 id = GET_CP_ASYNC_ID (ca, 4 , hasCpSize);
@@ -1263,6 +1264,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
12631264 llvm_unreachable (" Invalid Reduction Op for CpAsyncBulkTensorReduceOp" );
12641265}
12651266
1267+ #define _none
1268+
12661269#define CVT_F2TF32_ID_IMPL (rnd, relu, sf ) \
12671270 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
12681271 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
@@ -1282,7 +1285,7 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12821285 case RndMode::RZ:
12831286 return GET_CVT_F2TF32_ID (rz, _relu, _satfinite);
12841287 case RndMode::RNA:
1285- return GET_CVT_F2TF32_ID (rna, , _satfinite);
1288+ return GET_CVT_F2TF32_ID (rna, _none , _satfinite);
12861289 default :
12871290 llvm_unreachable (" Invalid RoundingMode for CvtFloatToTF32Op" );
12881291 }
@@ -1293,9 +1296,9 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
12931296 LLVM::ModuleTranslation &mt,
12941297 llvm::SmallVector<llvm::Value *> &args) {
12951298 auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
1296- unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
1299+ unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
12971300 .getAddressSpace ();
1298- bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace ;
1301+ bool isShared = as == NVVMMemorySpace::kSharedMemorySpace ;
12991302 bool is2CTAMode = curOp.getGroup () == Tcgen05GroupKind::CTA_2;
13001303
13011304 llvm::Intrinsic::ID id;
@@ -1342,14 +1345,15 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
13421345 LLVM::ModuleTranslation &mt,
13431346 llvm::SmallVector<llvm::Value *> &args) {
13441347 auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
1345- unsigned AS = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
1348+ unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr ().getType ())
13461349 .getAddressSpace ();
1347- bool isShared = AS == NVVMMemorySpace::kSharedMemorySpace ;
1348- bool hasMulticast = curOp.getMulticastMask () ? true : false ;
1350+ bool isShared = as == NVVMMemorySpace::kSharedMemorySpace ;
1351+ bool hasMulticast = static_cast < bool >( curOp.getMulticastMask ()) ;
13491352 bool is2CTAMode = curOp.getGroup () == Tcgen05GroupKind::CTA_2;
13501353
1351- auto id = is2CTAMode ? GET_TCGEN05_COMMIT_ID (cg2, isShared, hasMulticast)
1352- : GET_TCGEN05_COMMIT_ID (cg1, isShared, hasMulticast);
1354+ llvm::Intrinsic::ID id =
1355+ is2CTAMode ? GET_TCGEN05_COMMIT_ID (cg2, isShared, hasMulticast)
1356+ : GET_TCGEN05_COMMIT_ID (cg1, isShared, hasMulticast);
13531357
13541358 // Fill the Intrinsic Args
13551359 args.push_back (mt.lookupValue (curOp.getAddr ()));
@@ -1368,9 +1372,9 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
13681372
13691373#define GET_TCGEN05_CP_ID (shape_mc, src_fmt, is_2cta ) \
13701374 [&]() -> auto { \
1371- if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
1375+ if (( src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
13721376 return TCGEN05_CP_2CTA (shape_mc, _b6x16_p32, is_2cta); \
1373- if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
1377+ if (( src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
13741378 return TCGEN05_CP_2CTA (shape_mc, _b4x16_p64, is_2cta); \
13751379 return TCGEN05_CP_2CTA (shape_mc, , is_2cta); \
13761380 }()
@@ -1400,47 +1404,47 @@ llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
14001404
14011405// Returns the valid vector length for a given shape and vector length, the
14021406// function models the table mentioned in the tcgen05.{ld, st} Op description
1403- static unsigned isValidVectorLength (NVVM::Tcgen05LdStShape Shape ,
1404- unsigned VecLen ) {
1405- if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1406- return VecLen >= 2 ;
1407- if (Shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1408- return VecLen >= 4 ;
1407+ static unsigned isValidVectorLength (NVVM::Tcgen05LdStShape shape ,
1408+ unsigned vecLen ) {
1409+ if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1410+ return vecLen >= 2 ;
1411+ if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1412+ return vecLen >= 4 ;
14091413 return true ;
14101414}
14111415
14121416LogicalResult Tcgen05LdOp::verify () {
1413- LogicalResult Result = success ();
1417+ LogicalResult result = success ();
14141418 if (getShape () == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset ())
1415- Result = emitError (" shape 16x32bx2 requires offset argument" );
1419+ result = emitError (" shape 16x32bx2 requires offset argument" );
14161420
1417- auto ResTy = getRes ().getType ();
1418- unsigned ResLen = isa<VectorType>(ResTy )
1419- ? llvm::cast<VectorType>(ResTy ).getNumElements ()
1421+ auto resTy = getRes ().getType ();
1422+ unsigned resLen = isa<VectorType>(resTy )
1423+ ? llvm::cast<VectorType>(resTy ).getNumElements ()
14201424 : 1 ;
1421- if (!isValidVectorLength (getShape (), ResLen ))
1422- Result = emitError (llvm::formatv (" invalid result type length {0} for shape "
1425+ if (!isValidVectorLength (getShape (), resLen ))
1426+ result = emitError (llvm::formatv (" invalid result type length {0} for shape "
14231427 " {1} in tcgen05.ld Op" ,
1424- ResLen , stringifyEnum (getShape ())));
1428+ resLen , stringifyEnum (getShape ())));
14251429
1426- return Result ;
1430+ return result ;
14271431}
14281432
14291433LogicalResult Tcgen05StOp::verify () {
1430- LogicalResult Result = success ();
1434+ LogicalResult result = success ();
14311435 if (getShape () == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset ())
1432- Result = emitError (" shape 16x32bx2 requires offset argument" );
1436+ result = emitError (" shape 16x32bx2 requires offset argument" );
14331437
1434- auto ValTy = getVal ().getType ();
1435- unsigned ValLen = isa<VectorType>(ValTy )
1436- ? llvm::cast<VectorType>(ValTy ).getNumElements ()
1438+ auto valTy = getVal ().getType ();
1439+ unsigned valLen = isa<VectorType>(valTy )
1440+ ? llvm::cast<VectorType>(valTy ).getNumElements ()
14371441 : 1 ;
1438- if (!isValidVectorLength (getShape (), ValLen ))
1439- Result = emitError (llvm::formatv (" invalid input length {0} for shape "
1442+ if (!isValidVectorLength (getShape (), valLen ))
1443+ result = emitError (llvm::formatv (" invalid input length {0} for shape "
14401444 " {1} in tcgen05.st Op" ,
1441- ValLen , stringifyEnum (getShape ())));
1445+ valLen , stringifyEnum (getShape ())));
14421446
1443- return Result ;
1447+ return result ;
14441448}
14451449
14461450// / Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
@@ -1560,7 +1564,7 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
15601564 return failure ();
15611565 }
15621566 if (files && !llvm::all_of (files, [](::mlir::Attribute attr) {
1563- return attr && mlir::isa <StringAttr>(attr);
1567+ return mlir::isa_and_nonnull <StringAttr>(attr);
15641568 })) {
15651569 emitError () << " All the elements in the `link` array must be strings." ;
15661570 return failure ();
0 commit comments