Skip to content

Commit 54c5521

Browse files
amd-eochoalokuhar
andauthored
[mlir][spirv] Use verifySymbolUses for spirv.FunctionCall. (#159399)
`spirv.FunctionCall`'s verifier was being too aggressive. It included verification of non-local properties by looking at the callee's definition. This caused problems in cases where callee had verification errors and could lead to null pointer dereferencing. According to [MLIR's developers guide ](https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier) > TLDR: only verify local aspects of an operation, > in particular don’t follow def-use chains > (don’t look at the producer of any operand or the user > of any results). The fix includes adding the `SymbolUserOpInterface` to `FunctionCall` and moving most of the verification logic to `verifySymbolUses`. Fixes #159295 --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 714f032 commit 54c5521

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS
1616

1717
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
18+
include "mlir/IR/SymbolInterfaces.td"
1819
include "mlir/Interfaces/CallInterfaces.td"
1920
include "mlir/Interfaces/ControlFlowInterfaces.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -187,7 +188,8 @@ def SPIRV_BranchConditionalOp : SPIRV_Op<"BranchConditional", [
187188
// -----
188189

189190
def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
190-
InFunctionScope, DeclareOpInterfaceMethods<CallOpInterface>]> {
191+
InFunctionScope, DeclareOpInterfaceMethods<CallOpInterface>,
192+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
191193
let summary = "Call a function.";
192194

193195
let description = [{

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,23 +151,27 @@ LogicalResult BranchConditionalOp::verify() {
151151
//===----------------------------------------------------------------------===//
152152

153153
LogicalResult FunctionCallOp::verify() {
154+
if (getNumResults() > 1) {
155+
return emitOpError(
156+
"expected callee function to have 0 or 1 result, but provided ")
157+
<< getNumResults();
158+
}
159+
return success();
160+
}
161+
162+
LogicalResult
163+
FunctionCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
154164
auto fnName = getCalleeAttr();
155165

156-
auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
157-
SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
166+
auto funcOp =
167+
symbolTable.lookupNearestSymbolFrom<spirv::FuncOp>(*this, fnName);
158168
if (!funcOp) {
159169
return emitOpError("callee function '")
160170
<< fnName.getValue() << "' not found in nearest symbol table";
161171
}
162172

163173
auto functionType = funcOp.getFunctionType();
164174

165-
if (getNumResults() > 1) {
166-
return emitOpError(
167-
"expected callee function to have 0 or 1 result, but provided ")
168-
<< getNumResults();
169-
}
170-
171175
if (functionType.getNumInputs() != getNumOperands()) {
172176
return emitOpError("has incorrect number of operands for callee: expected ")
173177
<< functionType.getNumInputs() << ", but provided "

mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,35 @@ spirv.module Logical GLSL450 {
262262

263263
// -----
264264

265+
"builtin.module"() ({
266+
"spirv.module"() <{
267+
addressing_model = #spirv.addressing_model<Logical>,
268+
memory_model = #spirv.memory_model<GLSL450>
269+
}> ({
270+
"spirv.func"() <{
271+
function_control = #spirv.function_control<None>,
272+
function_type = (f32) -> f32,
273+
sym_name = "bar"
274+
}> ({
275+
^bb0(%arg0: f32):
276+
%0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32
277+
"spirv.ReturnValue"(%0) : (f32) -> ()
278+
}) : () -> ()
279+
// expected-error @+1 {{requires attribute 'function_type'}}
280+
"spirv.func"() <{
281+
function_control = #spirv.function_control<None>,
282+
message = "2nd parent",
283+
sym_name = "foo"
284+
// This is invalid MLIR because function_type is missing from spirv.func.
285+
}> ({
286+
^bb0(%arg0: f32):
287+
"spirv.ReturnValue"(%arg0) : (f32) -> ()
288+
}) : () -> ()
289+
}) : () -> ()
290+
}) : () -> ()
291+
292+
// -----
293+
265294
//===----------------------------------------------------------------------===//
266295
// spirv.mlir.loop
267296
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)