@@ -3028,19 +3028,21 @@ bool hasDuplicateDeviceTypes(
30283028}
30293029
30303030// / Check for duplicates in the DeviceType array attribute.
3031- LogicalResult checkDeviceTypes (mlir::ArrayAttr deviceTypes) {
3031+ // / Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3032+ static std::optional<mlir::acc::DeviceType>
3033+ checkDeviceTypes (mlir::ArrayAttr deviceTypes) {
30323034 llvm::SmallSet<mlir::acc::DeviceType, 3 > crtDeviceTypes;
30333035 if (!deviceTypes)
3034- return success () ;
3036+ return std:: nullopt ;
30353037 for (auto attr : deviceTypes) {
30363038 auto deviceTypeAttr =
30373039 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
30383040 if (!deviceTypeAttr)
3039- return failure () ;
3041+ return mlir::acc::DeviceType::None ;
30403042 if (!crtDeviceTypes.insert (deviceTypeAttr.getValue ()).second )
3041- return failure ();
3043+ return deviceTypeAttr. getValue ();
30423044 }
3043- return success () ;
3045+ return std:: nullopt ;
30443046}
30453047
30463048LogicalResult acc::LoopOp::verify () {
@@ -3067,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() {
30673069 getCollapseDeviceTypeAttr ().getValue ().size ())
30683070 return emitOpError () << " collapse attribute count must match collapse"
30693071 << " device_type count" ;
3070- if (failed (checkDeviceTypes (getCollapseDeviceTypeAttr ())))
3071- return emitOpError ()
3072- << " duplicate device_type found in collapseDeviceType attribute" ;
3072+ if (auto duplicateDeviceType = checkDeviceTypes (getCollapseDeviceTypeAttr ()))
3073+ return emitOpError () << " duplicate device_type `"
3074+ << acc::stringifyDeviceType (*duplicateDeviceType)
3075+ << " ` found in collapseDeviceType attribute" ;
30733076
30743077 // Check gang
30753078 if (!getGangOperands ().empty ()) {
@@ -3082,31 +3085,43 @@ LogicalResult acc::LoopOp::verify() {
30823085 return emitOpError () << " gangOperandsArgType attribute count must match"
30833086 << " gangOperands count" ;
30843087 }
3085- if (getGangAttr () && failed (checkDeviceTypes (getGangAttr ())))
3086- return emitOpError () << " duplicate device_type found in gang attribute" ;
3088+ if (getGangAttr ()) {
3089+ if (auto duplicateDeviceType = checkDeviceTypes (getGangAttr ()))
3090+ return emitOpError () << " duplicate device_type `"
3091+ << acc::stringifyDeviceType (*duplicateDeviceType)
3092+ << " ` found in gang attribute" ;
3093+ }
30873094
30883095 if (failed (verifyDeviceTypeAndSegmentCountMatch (
30893096 *this , getGangOperands (), getGangOperandsSegmentsAttr (),
30903097 getGangOperandsDeviceTypeAttr (), " gang" )))
30913098 return failure ();
30923099
30933100 // Check worker
3094- if (failed (checkDeviceTypes (getWorkerAttr ())))
3095- return emitOpError () << " duplicate device_type found in worker attribute" ;
3096- if (failed (checkDeviceTypes (getWorkerNumOperandsDeviceTypeAttr ())))
3097- return emitOpError () << " duplicate device_type found in "
3098- " workerNumOperandsDeviceType attribute" ;
3101+ if (auto duplicateDeviceType = checkDeviceTypes (getWorkerAttr ()))
3102+ return emitOpError () << " duplicate device_type `"
3103+ << acc::stringifyDeviceType (*duplicateDeviceType)
3104+ << " ` found in worker attribute" ;
3105+ if (auto duplicateDeviceType =
3106+ checkDeviceTypes (getWorkerNumOperandsDeviceTypeAttr ()))
3107+ return emitOpError () << " duplicate device_type `"
3108+ << acc::stringifyDeviceType (*duplicateDeviceType)
3109+ << " ` found in workerNumOperandsDeviceType attribute" ;
30993110 if (failed (verifyDeviceTypeCountMatch (*this , getWorkerNumOperands (),
31003111 getWorkerNumOperandsDeviceTypeAttr (),
31013112 " worker" )))
31023113 return failure ();
31033114
31043115 // Check vector
3105- if (failed (checkDeviceTypes (getVectorAttr ())))
3106- return emitOpError () << " duplicate device_type found in vector attribute" ;
3107- if (failed (checkDeviceTypes (getVectorOperandsDeviceTypeAttr ())))
3108- return emitOpError () << " duplicate device_type found in "
3109- " vectorOperandsDeviceType attribute" ;
3116+ if (auto duplicateDeviceType = checkDeviceTypes (getVectorAttr ()))
3117+ return emitOpError () << " duplicate device_type `"
3118+ << acc::stringifyDeviceType (*duplicateDeviceType)
3119+ << " ` found in vector attribute" ;
3120+ if (auto duplicateDeviceType =
3121+ checkDeviceTypes (getVectorOperandsDeviceTypeAttr ()))
3122+ return emitOpError () << " duplicate device_type `"
3123+ << acc::stringifyDeviceType (*duplicateDeviceType)
3124+ << " ` found in vectorOperandsDeviceType attribute" ;
31103125 if (failed (verifyDeviceTypeCountMatch (*this , getVectorOperands (),
31113126 getVectorOperandsDeviceTypeAttr (),
31123127 " vector" )))
@@ -4096,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() {
40964111
40974112 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1 ))
40984113 return emitError () << " only one of `gang`, `worker`, `vector`, `seq` can "
4099- " be present at the same time" ;
4114+ " be present at the same time for device_type `"
4115+ << acc::stringifyDeviceType (dtype) << " `" ;
41004116 }
41014117
41024118 return success ();
0 commit comments