diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 1709654b90138..f82b20712b8c6 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -386,9 +386,7 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } Operation *ExpressionOp::getRootOp() { auto yieldOp = cast(getBody()->getTerminator()); Value yieldedValue = yieldOp.getResult(); - Operation *rootOp = yieldedValue.getDefiningOp(); - assert(rootOp && "Yielded value not defined within expression"); - return rootOp; + return yieldedValue.getDefiningOp(); } LogicalResult ExpressionOp::verify() { @@ -406,6 +404,14 @@ LogicalResult ExpressionOp::verify() { if (!yieldResult) return emitOpError("must yield a value at termination"); + Operation *rootOp = yieldResult.getDefiningOp(); + + if (!rootOp) + return emitOpError("yielded value has no defining op"); + + if (rootOp->getParentOp() != getOperation()) + return emitOpError("yielded value not defined within expression"); + Type yieldType = yieldResult.getType(); if (resultType != yieldType) diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 3793dfe3f173b..3946a36a83c6f 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -346,6 +346,28 @@ func.func @test_expression_multiple_results(%arg0: i32) -> i32 { // ----- +emitc.func @test_expression_no_defining_op(%a : i32) { + // expected-error @+1 {{'emitc.expression' op yielded value has no defining op}} + %res = emitc.expression : i32 { + emitc.yield %a : i32 + } + + return +} + +// ----- + +emitc.func @test_expression_op_outside_expression() { + %cond = literal "true" : i1 + // expected-error @+1 {{'emitc.expression' op yielded value not defined within expression}} + %res = emitc.expression : i1 { + emitc.yield %cond : i1 + } + return +} + +// ----- + // expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}} emitc.func @multiple_results(%0: i32) -> (i32, i32) { emitc.return %0 : i32