Skip to content

Commit 5defa85

Browse files
[mlir][acc] Improve verifier messages for device_type duplicates (#170773)
This improves the acc dialect IR verifier messages when duplicate device_types are found by also noting which device_type is the one causing the error.
1 parent 834b8b7 commit 5defa85

File tree

2 files changed

+78
-24
lines changed

2 files changed

+78
-24
lines changed

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,19 +3028,21 @@ bool hasDuplicateDeviceTypes(
30283028
}
30293029

30303030
/// Check for duplicates in the DeviceType array attribute.
3031-
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3031+
/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3032+
static std::optional<mlir::acc::DeviceType>
3033+
checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
30323034
llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
30333035
if (!deviceTypes)
3034-
return success();
3036+
return std::nullopt;
30353037
for (auto attr : deviceTypes) {
30363038
auto deviceTypeAttr =
30373039
mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
30383040
if (!deviceTypeAttr)
3039-
return failure();
3041+
return mlir::acc::DeviceType::None;
30403042
if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3041-
return failure();
3043+
return deviceTypeAttr.getValue();
30423044
}
3043-
return success();
3045+
return std::nullopt;
30443046
}
30453047

30463048
LogicalResult acc::LoopOp::verify() {
@@ -3067,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() {
30673069
getCollapseDeviceTypeAttr().getValue().size())
30683070
return emitOpError() << "collapse attribute count must match collapse"
30693071
<< " device_type count";
3070-
if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
3071-
return emitOpError()
3072-
<< "duplicate device_type found in collapseDeviceType attribute";
3072+
if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3073+
return emitOpError() << "duplicate device_type `"
3074+
<< acc::stringifyDeviceType(*duplicateDeviceType)
3075+
<< "` found in collapseDeviceType attribute";
30733076

30743077
// Check gang
30753078
if (!getGangOperands().empty()) {
@@ -3082,31 +3085,43 @@ LogicalResult acc::LoopOp::verify() {
30823085
return emitOpError() << "gangOperandsArgType attribute count must match"
30833086
<< " gangOperands count";
30843087
}
3085-
if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
3086-
return emitOpError() << "duplicate device_type found in gang attribute";
3088+
if (getGangAttr()) {
3089+
if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3090+
return emitOpError() << "duplicate device_type `"
3091+
<< acc::stringifyDeviceType(*duplicateDeviceType)
3092+
<< "` found in gang attribute";
3093+
}
30873094

30883095
if (failed(verifyDeviceTypeAndSegmentCountMatch(
30893096
*this, getGangOperands(), getGangOperandsSegmentsAttr(),
30903097
getGangOperandsDeviceTypeAttr(), "gang")))
30913098
return failure();
30923099

30933100
// Check worker
3094-
if (failed(checkDeviceTypes(getWorkerAttr())))
3095-
return emitOpError() << "duplicate device_type found in worker attribute";
3096-
if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
3097-
return emitOpError() << "duplicate device_type found in "
3098-
"workerNumOperandsDeviceType attribute";
3101+
if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3102+
return emitOpError() << "duplicate device_type `"
3103+
<< acc::stringifyDeviceType(*duplicateDeviceType)
3104+
<< "` found in worker attribute";
3105+
if (auto duplicateDeviceType =
3106+
checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3107+
return emitOpError() << "duplicate device_type `"
3108+
<< acc::stringifyDeviceType(*duplicateDeviceType)
3109+
<< "` found in workerNumOperandsDeviceType attribute";
30993110
if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
31003111
getWorkerNumOperandsDeviceTypeAttr(),
31013112
"worker")))
31023113
return failure();
31033114

31043115
// Check vector
3105-
if (failed(checkDeviceTypes(getVectorAttr())))
3106-
return emitOpError() << "duplicate device_type found in vector attribute";
3107-
if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
3108-
return emitOpError() << "duplicate device_type found in "
3109-
"vectorOperandsDeviceType attribute";
3116+
if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3117+
return emitOpError() << "duplicate device_type `"
3118+
<< acc::stringifyDeviceType(*duplicateDeviceType)
3119+
<< "` found in vector attribute";
3120+
if (auto duplicateDeviceType =
3121+
checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3122+
return emitOpError() << "duplicate device_type `"
3123+
<< acc::stringifyDeviceType(*duplicateDeviceType)
3124+
<< "` found in vectorOperandsDeviceType attribute";
31103125
if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
31113126
getVectorOperandsDeviceTypeAttr(),
31123127
"vector")))
@@ -4096,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() {
40964111

40974112
if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
40984113
return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4099-
"be present at the same time";
4114+
"be present at the same time for device_type `"
4115+
<< acc::stringifyDeviceType(dtype) << "`";
41004116
}
41014117

41024118
return success();

mlir/test/Dialect/OpenACC/invalid.mlir

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,27 +76,65 @@ acc.loop {
7676

7777
// -----
7878

79-
// expected-error@+1 {{'acc.loop' op duplicate device_type found in gang attribute}}
79+
// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in gang attribute}}
8080
acc.loop {
8181
acc.yield
8282
} attributes {gang = [#acc.device_type<none>, #acc.device_type<none>]}
8383

8484
// -----
8585

86-
// expected-error@+1 {{'acc.loop' op duplicate device_type found in worker attribute}}
86+
// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in worker attribute}}
8787
acc.loop {
8888
acc.yield
8989
} attributes {worker = [#acc.device_type<none>, #acc.device_type<none>]}
9090

9191
// -----
9292

93-
// expected-error@+1 {{'acc.loop' op duplicate device_type found in vector attribute}}
93+
// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vector attribute}}
9494
acc.loop {
9595
acc.yield
9696
} attributes {vector = [#acc.device_type<none>, #acc.device_type<none>]}
9797

9898
// -----
9999

100+
// expected-error@+1 {{'acc.loop' op duplicate device_type `nvidia` found in gang attribute}}
101+
acc.loop {
102+
acc.yield
103+
} attributes {gang = [#acc.device_type<nvidia>, #acc.device_type<nvidia>]}
104+
105+
// -----
106+
107+
// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in collapseDeviceType attribute}}
108+
acc.loop {
109+
acc.yield
110+
} attributes {collapse = [1, 1], collapseDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]}
111+
112+
// -----
113+
114+
%i64value = arith.constant 1 : i64
115+
// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in workerNumOperandsDeviceType attribute}}
116+
acc.loop worker(%i64value: i64, %i64value: i64) {
117+
acc.yield
118+
} attributes {workerNumOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]}
119+
120+
// -----
121+
122+
%i64value = arith.constant 1 : i64
123+
// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vectorOperandsDeviceType attribute}}
124+
acc.loop vector(%i64value: i64, %i64value: i64) {
125+
acc.yield
126+
} attributes {vectorOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]}
127+
128+
// -----
129+
130+
func.func @acc_routine_parallelism() -> () {
131+
return
132+
}
133+
// expected-error@+1 {{only one of `gang`, `worker`, `vector`, `seq` can be present at the same time for device_type `nvidia`}}
134+
"acc.routine"() <{func_name = @acc_routine_parallelism, sym_name = "acc_routine_parallelism_rout", gang = [#acc.device_type<nvidia>], worker = [#acc.device_type<nvidia>]}> : () -> ()
135+
136+
// -----
137+
100138
%1 = arith.constant 1 : i32
101139
%2 = arith.constant 10 : i32
102140
// expected-error@+1 {{only one of auto, independent, seq can be present at the same time}}

0 commit comments

Comments
 (0)