3434using namespace mlir ;
3535using namespace mlir ::sparse_tensor;
3636
37- #define RETURN_FAILURE_IF_FAILED (X ) \
38- if (failed(X)) { \
39- return failure (); \
40- }
41-
4237// ===----------------------------------------------------------------------===//
4338// Local convenience methods.
4439// ===----------------------------------------------------------------------===//
@@ -68,10 +63,6 @@ void StorageLayout::foreachField(
6863 llvm::function_ref<bool (FieldIndex, SparseTensorFieldKind, Level,
6964 DimLevelType)>
7065 callback) const {
71- #define RETURN_ON_FALSE (fidx, kind, lvl, dlt ) \
72- if (!(callback (fidx, kind, lvl, dlt))) \
73- return ;
74-
7566 const auto lvlTypes = enc.getLvlTypes ();
7667 const Level lvlRank = enc.getLvlRank ();
7768 const Level cooStart = getCOOStart (enc);
@@ -81,21 +72,22 @@ void StorageLayout::foreachField(
8172 for (Level l = 0 ; l < end; l++) {
8273 const auto dlt = lvlTypes[l];
8374 if (isDLTWithPos (dlt)) {
84- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
75+ if (!(callback (fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
76+ return ;
8577 }
8678 if (isDLTWithCrd (dlt)) {
87- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
79+ if (!(callback (fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
80+ return ;
8881 }
8982 }
9083 // The values array.
91- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel ,
92- DimLevelType::Undef);
93-
84+ if (!( callback (fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel ,
85+ DimLevelType::Undef)))
86+ return ;
9487 // Put metadata at the end.
95- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel ,
96- DimLevelType::Undef);
97-
98- #undef RETURN_ON_FALSE
88+ if (!(callback (fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel ,
89+ DimLevelType::Undef)))
90+ return ;
9991}
10092
10193void sparse_tensor::foreachFieldAndTypeInSparseTensor (
@@ -435,18 +427,11 @@ SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
435427}
436428
437429Attribute SparseTensorEncodingAttr::parse (AsmParser &parser, Type type) {
438- #define RETURN_ON_FAIL (stmt ) \
439- if (failed (stmt)) { \
440- return {}; \
441- }
442- #define ERROR_IF (COND, MSG ) \
443- if (COND) { \
444- parser.emitError (parser.getNameLoc (), MSG); \
445- return {}; \
446- }
447-
448- RETURN_ON_FAIL (parser.parseLess ())
449- RETURN_ON_FAIL (parser.parseLBrace ())
430+ // Open "<{" part.
431+ if (failed (parser.parseLess ()))
432+ return {};
433+ if (failed (parser.parseLBrace ()))
434+ return {};
450435
451436 // Process the data from the parsed dictionary value into struct-like data.
452437 SmallVector<DimLevelType> lvlTypes;
@@ -466,13 +451,15 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
466451 }
467452 unsigned keyWordIndex = it - keys.begin ();
468453 // Consume the `=` after keys
469- RETURN_ON_FAIL (parser.parseEqual ())
454+ if (failed (parser.parseEqual ()))
455+ return {};
470456 // Dispatch on keyword.
471457 switch (keyWordIndex) {
472458 case 0 : { // map
473459 ir_detail::DimLvlMapParser cParser (parser);
474460 auto res = cParser.parseDimLvlMap ();
475- RETURN_ON_FAIL (res);
461+ if (failed (res))
462+ return {};
476463 const auto &dlm = *res;
477464
478465 const Level lvlRank = dlm.getLvlRank ();
@@ -504,17 +491,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
504491 }
505492 case 1 : { // posWidth
506493 Attribute attr;
507- RETURN_ON_FAIL (parser.parseAttribute (attr))
494+ if (failed (parser.parseAttribute (attr)))
495+ return {};
508496 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
509- ERROR_IF (!intAttr, " expected an integral position bitwidth" )
497+ if (!intAttr) {
498+ parser.emitError (parser.getNameLoc (),
499+ " expected an integral position bitwidth" );
500+ return {};
501+ }
510502 posWidth = intAttr.getInt ();
511503 break ;
512504 }
513505 case 2 : { // crdWidth
514506 Attribute attr;
515- RETURN_ON_FAIL (parser.parseAttribute (attr))
507+ if (failed (parser.parseAttribute (attr)))
508+ return {};
516509 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
517- ERROR_IF (!intAttr, " expected an integral index bitwidth" )
510+ if (!intAttr) {
511+ parser.emitError (parser.getNameLoc (),
512+ " expected an integral index bitwidth" );
513+ return {};
514+ }
518515 crdWidth = intAttr.getInt ();
519516 break ;
520517 }
@@ -524,10 +521,11 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
524521 break ;
525522 }
526523
527- RETURN_ON_FAIL (parser.parseRBrace ())
528- RETURN_ON_FAIL (parser.parseGreater ())
529- #undef ERROR_IF
530- #undef RETURN_ON_FAIL
524+ // Close "}>" part.
525+ if (failed (parser.parseRBrace ()))
526+ return {};
527+ if (failed (parser.parseGreater ()))
528+ return {};
531529
532530 // Construct struct-like storage for attribute.
533531 if (!lvlToDim || lvlToDim.isEmpty ()) {
@@ -668,9 +666,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
668666 function_ref<InFlightDiagnostic()> emitError) const {
669667 // Check structural integrity. In particular, this ensures that the
670668 // level-rank is coherent across all the fields.
671- RETURN_FAILURE_IF_FAILED ( verify (emitError, getLvlTypes (), getDimToLvl (),
672- getLvlToDim (), getPosWidth (), getCrdWidth (),
673- getDimSlices ()))
669+ if ( failed ( verify (emitError, getLvlTypes (), getDimToLvl (), getLvlToDim (),
670+ getPosWidth (), getCrdWidth (), getDimSlices ())))
671+ return failure ();
674672 // Check integrity with tensor type specifics. In particular, we
675673 // need only check that the dimension-rank of the tensor agrees with
676674 // the dimension-rank of the encoding.
@@ -926,10 +924,6 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
926924 return toStoredDim (getSparseTensorEncoding (type), d);
927925}
928926
929- // ===----------------------------------------------------------------------===//
930- // SparseTensorDialect Types.
931- // ===----------------------------------------------------------------------===//
932-
933927// / We normalized sparse tensor encoding attribute by always using
934928// / ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
935929// / as other variants) lead to the same storage specifier type, and stripping
@@ -1340,9 +1334,8 @@ LogicalResult ToSliceStrideOp::verify() {
13401334}
13411335
13421336LogicalResult GetStorageSpecifierOp::verify () {
1343- RETURN_FAILURE_IF_FAILED (verifySparsifierGetterSetter (
1344- getSpecifierKind (), getLevel (), getSpecifier (), getOperation ()))
1345- return success ();
1337+ return verifySparsifierGetterSetter (getSpecifierKind (), getLevel (),
1338+ getSpecifier (), getOperation ());
13461339}
13471340
13481341template <typename SpecifierOp>
@@ -1360,9 +1353,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
13601353}
13611354
13621355LogicalResult SetStorageSpecifierOp::verify () {
1363- RETURN_FAILURE_IF_FAILED (verifySparsifierGetterSetter (
1364- getSpecifierKind (), getLevel (), getSpecifier (), getOperation ()))
1365- return success ();
1356+ return verifySparsifierGetterSetter (getSpecifierKind (), getLevel (),
1357+ getSpecifier (), getOperation ());
13661358}
13671359
13681360template <class T >
@@ -1404,20 +1396,23 @@ LogicalResult BinaryOp::verify() {
14041396 // Check correct number of block arguments and return type for each
14051397 // non-empty region.
14061398 if (!overlap.empty ()) {
1407- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1408- this , overlap, " overlap" , TypeRange{leftType, rightType}, outputType))
1399+ if (failed (verifyNumBlockArgs (this , overlap, " overlap" ,
1400+ TypeRange{leftType, rightType}, outputType)))
1401+ return failure ();
14091402 }
14101403 if (!left.empty ()) {
1411- RETURN_FAILURE_IF_FAILED (
1412- verifyNumBlockArgs (this , left, " left" , TypeRange{leftType}, outputType))
1404+ if (failed (verifyNumBlockArgs (this , left, " left" , TypeRange{leftType},
1405+ outputType)))
1406+ return failure ();
14131407 } else if (getLeftIdentity ()) {
14141408 if (leftType != outputType)
14151409 return emitError (" left=identity requires first argument to have the same "
14161410 " type as the output" );
14171411 }
14181412 if (!right.empty ()) {
1419- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1420- this , right, " right" , TypeRange{rightType}, outputType))
1413+ if (failed (verifyNumBlockArgs (this , right, " right" , TypeRange{rightType},
1414+ outputType)))
1415+ return failure ();
14211416 } else if (getRightIdentity ()) {
14221417 if (rightType != outputType)
14231418 return emitError (" right=identity requires second argument to have the "
@@ -1434,13 +1429,15 @@ LogicalResult UnaryOp::verify() {
14341429 // non-empty region.
14351430 Region &present = getPresentRegion ();
14361431 if (!present.empty ()) {
1437- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1438- this , present, " present" , TypeRange{inputType}, outputType))
1432+ if (failed (verifyNumBlockArgs (this , present, " present" ,
1433+ TypeRange{inputType}, outputType)))
1434+ return failure ();
14391435 }
14401436 Region &absent = getAbsentRegion ();
14411437 if (!absent.empty ()) {
1442- RETURN_FAILURE_IF_FAILED (
1443- verifyNumBlockArgs (this , absent, " absent" , TypeRange{}, outputType))
1438+ if (failed (verifyNumBlockArgs (this , absent, " absent" , TypeRange{},
1439+ outputType)))
1440+ return failure ();
14441441 // Absent branch can only yield invariant values.
14451442 Block *absentBlock = &absent.front ();
14461443 Block *parent = getOperation ()->getBlock ();
@@ -1655,22 +1652,18 @@ LogicalResult ReorderCOOOp::verify() {
16551652
16561653LogicalResult ReduceOp::verify () {
16571654 Type inputType = getX ().getType ();
1658- // Check correct number of block arguments and return type.
16591655 Region &formula = getRegion ();
1660- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1661- this , formula, " reduce" , TypeRange{inputType, inputType}, inputType))
1662- return success ();
1656+ return verifyNumBlockArgs (this , formula, " reduce" ,
1657+ TypeRange{inputType, inputType}, inputType);
16631658}
16641659
16651660LogicalResult SelectOp::verify () {
16661661 Builder b (getContext ());
16671662 Type inputType = getX ().getType ();
16681663 Type boolType = b.getI1Type ();
1669- // Check correct number of block arguments and return type.
16701664 Region &formula = getRegion ();
1671- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (this , formula, " select" ,
1672- TypeRange{inputType}, boolType))
1673- return success ();
1665+ return verifyNumBlockArgs (this , formula, " select" , TypeRange{inputType},
1666+ boolType);
16741667}
16751668
16761669LogicalResult SortOp::verify () {
@@ -1725,8 +1718,6 @@ LogicalResult YieldOp::verify() {
17251718 " reduce, select or foreach" );
17261719}
17271720
1728- #undef RETURN_FAILURE_IF_FAILED
1729-
17301721// / Materialize a single constant operation from a given attribute value with
17311722// / the desired resultant type.
17321723Operation *SparseTensorDialect::materializeConstant (OpBuilder &builder,
0 commit comments