@@ -541,12 +541,10 @@ static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
541541 mlir::OperandRange operands,
542542 mlir::TypeRange types,
543543 std::optional<mlir::ArrayAttr> attributes) {
544- for (unsigned i = 0 , e = attributes->size (); i < e; ++i) {
545- if (i != 0 )
546- p << " , " ;
547- p << (*attributes)[i] << " -> " << operands[i] << " : "
548- << operands[i].getType ();
549- }
544+ llvm::interleaveComma (llvm::zip (*attributes, operands), p, [&](auto it) {
545+ p << std::get<0 >(it) << " -> " << std::get<1 >(it) << " : "
546+ << std::get<1 >(it).getType ();
547+ });
550548}
551549
552550// ===----------------------------------------------------------------------===//
@@ -852,27 +850,27 @@ static ParseResult parseNumGangs(
852850 return success ();
853851}
854852
853+ static void printSingleDeviceType (mlir::OpAsmPrinter &p, mlir::Attribute attr) {
854+ auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
855+ if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
856+ p << " [" << attr << " ]" ;
857+ }
858+
855859static void printNumGangs (mlir::OpAsmPrinter &p, mlir::Operation *op,
856860 mlir::OperandRange operands, mlir::TypeRange types,
857861 std::optional<mlir::ArrayAttr> deviceTypes,
858862 std::optional<mlir::DenseI32ArrayAttr> segments) {
859863 unsigned opIdx = 0 ;
860- for (unsigned i = 0 ; i < deviceTypes->size (); ++i) {
861- if (i != 0 )
862- p << " , " ;
864+ llvm::interleaveComma (llvm::enumerate (*deviceTypes), p, [&](auto it) {
863865 p << " {" ;
864- for (int32_t j = 0 ; j < (*segments)[i]; ++j) {
865- if (j != 0 )
866- p << " , " ;
867- p << operands[opIdx] << " : " << operands[opIdx].getType ();
868- ++opIdx;
869- }
866+ llvm::interleaveComma (
867+ llvm::seq<int32_t >(0 , (*segments)[it.index ()]), p, [&](auto it) {
868+ p << operands[opIdx] << " : " << operands[opIdx].getType ();
869+ ++opIdx;
870+ });
870871 p << " }" ;
871- auto deviceTypeAttr =
872- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
873- if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
874- p << " [" << (*deviceTypes)[i] << " ]" ;
875- }
872+ printSingleDeviceType (p, it.value ());
873+ });
876874}
877875
878876static ParseResult parseDeviceTypeOperandsWithSegment (
@@ -921,30 +919,21 @@ static ParseResult parseDeviceTypeOperandsWithSegment(
921919 return success ();
922920}
923921
924- static void printSingleDeviceType (mlir::OpAsmPrinter &p, mlir::Attribute attr) {
925- auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
926- if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
927- p << " [" << attr << " ]" ;
928- }
929-
930922static void printDeviceTypeOperandsWithSegment (
931923 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
932924 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
933925 std::optional<mlir::DenseI32ArrayAttr> segments) {
934926 unsigned opIdx = 0 ;
935- for (unsigned i = 0 ; i < deviceTypes->size (); ++i) {
936- if (i != 0 )
937- p << " , " ;
927+ llvm::interleaveComma (llvm::enumerate (*deviceTypes), p, [&](auto it) {
938928 p << " {" ;
939- for (int32_t j = 0 ; j < (*segments)[i]; ++j) {
940- if (j != 0 )
941- p << " , " ;
942- p << operands[opIdx] << " : " << operands[opIdx].getType ();
943- ++opIdx;
944- }
929+ llvm::interleaveComma (
930+ llvm::seq<int32_t >(0 , (*segments)[it.index ()]), p, [&](auto it) {
931+ p << operands[opIdx] << " : " << operands[opIdx].getType ();
932+ ++opIdx;
933+ });
945934 p << " }" ;
946- printSingleDeviceType (p, (*deviceTypes)[i] );
947- }
935+ printSingleDeviceType (p, it. value () );
936+ });
948937}
949938
950939static ParseResult parseDeviceTypeOperands (
@@ -977,12 +966,10 @@ static void
977966printDeviceTypeOperands (mlir::OpAsmPrinter &p, mlir::Operation *op,
978967 mlir::OperandRange operands, mlir::TypeRange types,
979968 std::optional<mlir::ArrayAttr> deviceTypes) {
980- for (unsigned i = 0 , e = deviceTypes->size (); i < e; ++i) {
981- if (i != 0 )
982- p << " , " ;
983- p << operands[i] << " : " << operands[i].getType ();
984- printSingleDeviceType (p, (*deviceTypes)[i]);
985- }
969+ llvm::interleaveComma (llvm::zip (*deviceTypes, operands), p, [&](auto it) {
970+ p << std::get<1 >(it) << " : " << std::get<1 >(it).getType ();
971+ printSingleDeviceType (p, std::get<0 >(it));
972+ });
986973}
987974
988975static ParseResult parseDeviceTypeOperandsWithKeywordOnly (
@@ -1056,14 +1043,10 @@ static void printDeviceTypes(mlir::OpAsmPrinter &p,
10561043 std::optional<mlir::ArrayAttr> deviceTypes) {
10571044 if (!hasDeviceTypeValues (deviceTypes))
10581045 return ;
1046+
10591047 p << " [" ;
1060- for (unsigned i = 0 ; i < deviceTypes.value ().size (); ++i) {
1061- if (i != 0 )
1062- p << " , " ;
1063- auto deviceTypeAttr =
1064- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
1065- p << deviceTypeAttr;
1066- }
1048+ llvm::interleaveComma (*deviceTypes, p,
1049+ [&](mlir::Attribute attr) { p << attr; });
10671050 p << " ]" ;
10681051}
10691052
@@ -1081,19 +1064,11 @@ static void printDeviceTypeOperandsWithKeywordOnly(
10811064 }
10821065
10831066 p << " (" ;
1084-
10851067 printDeviceTypes (p, keywordOnlyDeviceTypes);
1086-
10871068 if (hasDeviceTypeValues (keywordOnlyDeviceTypes) &&
10881069 hasDeviceTypeValues (deviceTypes))
10891070 p << " , " ;
1090-
1091- for (unsigned i = 0 , e = deviceTypes->size (); i < e; ++i) {
1092- if (i != 0 )
1093- p << " , " ;
1094- p << operands[i] << " : " << operands[i].getType ();
1095- printSingleDeviceType (p, (*deviceTypes)[i]);
1096- }
1071+ printDeviceTypeOperands (p, op, operands, types, deviceTypes);
10971072 p << " )" ;
10981073}
10991074
@@ -1483,49 +1458,33 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
14831458 }
14841459
14851460 p << " (" ;
1486- if (hasDeviceTypeValues (gangOnlyDeviceTypes)) {
1487- p << " [" ;
1488- for (unsigned i = 0 ; i < gangOnlyDeviceTypes.value ().size (); ++i) {
1489- if (i != 0 )
1490- p << " , " ;
1491- auto deviceTypeAttr =
1492- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[i]);
1493- p << deviceTypeAttr;
1494- }
1495- p << " ]" ;
1496- }
1461+ printDeviceTypes (p, gangOnlyDeviceTypes);
14971462
14981463 if (hasDeviceTypeValues (gangOnlyDeviceTypes) &&
14991464 hasDeviceTypeValues (deviceTypes))
15001465 p << " , " ;
15011466
15021467 if (deviceTypes) {
15031468 unsigned opIdx = 0 ;
1504- for (unsigned i = 0 ; i < deviceTypes->size (); ++i) {
1505- if (i != 0 )
1506- p << " , " ;
1469+ llvm::interleaveComma (llvm::enumerate (*deviceTypes), p, [&](auto it) {
15071470 p << " {" ;
1508- for (int32_t j = 0 ; j < (*segments)[i]; ++j) {
1509- if (j != 0 )
1510- p << " , " ;
1511- auto gangArgTypeAttr =
1512- mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
1513- if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Num)
1514- p << LoopOp::getGangNumKeyword ();
1515- else if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Dim)
1516- p << LoopOp::getGangDimKeyword ();
1517- else if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Static)
1518- p << LoopOp::getGangStaticKeyword ();
1519- p << " =" << operands[opIdx] << " : " << operands[opIdx].getType ();
1520- ++opIdx;
1521- }
1522-
1471+ llvm::interleaveComma (
1472+ llvm::seq<int32_t >(0 , (*segments)[it.index ()]), p, [&](auto it) {
1473+ auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1474+ (*gangArgTypes)[opIdx]);
1475+ if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Num)
1476+ p << LoopOp::getGangNumKeyword ();
1477+ else if (gangArgTypeAttr.getValue () == mlir::acc::GangArgType::Dim)
1478+ p << LoopOp::getGangDimKeyword ();
1479+ else if (gangArgTypeAttr.getValue () ==
1480+ mlir::acc::GangArgType::Static)
1481+ p << LoopOp::getGangStaticKeyword ();
1482+ p << " =" << operands[opIdx] << " : " << operands[opIdx].getType ();
1483+ ++opIdx;
1484+ });
15231485 p << " }" ;
1524- auto deviceTypeAttr =
1525- mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
1526- if (deviceTypeAttr.getValue () != mlir::acc::DeviceType::None)
1527- p << " [" << (*deviceTypes)[i] << " ]" ;
1528- }
1486+ printSingleDeviceType (p, it.value ());
1487+ });
15291488 }
15301489 p << " )" ;
15311490}
0 commit comments