From 5d21a189ba7dd1e585d26a717b62c90e6c136a62 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Wed, 23 Oct 2024 20:22:07 +0800 Subject: [PATCH] [mlir][GPU] Add FunctionOpInterface check for `OpToFuncCallLowering` This PR adds a `FunctionOpInterface` check in `OpToFuncCallLowering` to resolve a crash when ops not in function. --- .../Conversion/GPUCommon/OpToFuncCallLowering.h | 5 +++++ .../Conversion/MathToROCDL/math-to-rocdl.mlir | 16 +++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 1cf8a1acb3193..3b94abd88f9ed 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -61,6 +61,11 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { SourceOp>::value, "expected op with same operand and result types"); + if (!op->template getParentOfType()) { + return rewriter.notifyMatchFailure( + op, "expected op to be within a function region"); + } + SmallVector castedOperands; for (Value operand : adaptor.getOperands()) castedOperands.push_back(maybeCast(operand, rewriter)); diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index ddd96bf797e6e..e0ea18d41f66d 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s module @test_module { // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16 @@ -481,3 +481,17 @@ module @test_module { func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16 } } + +// ----- + +// Math operation not inside function +// Ensure it not crash + +module { + "test.some_op_with_region"() ({ + ^bb0(%arg0: f64): + // CHECK: math.atan + %0 = math.atan %arg0 : f64 + "test.possible_terminator"() : () -> () + }) : () -> () +}