diff --git a/src/frontends/pytorch/src/op/callfunction.cpp b/src/frontends/pytorch/src/op/callfunction.cpp
new file mode 100644
index 00000000000000..4fb0bda12a98ee
--- /dev/null
+++ b/src/frontends/pytorch/src/op/callfunction.cpp
@@ -0,0 +1,69 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/abs.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/relu.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+using namespace std;
+
+
+OutputVector translate_prim_CallFunction(const NodeContext& context) {
+    num_inputs_check(context, 2, context.get_input_size());
+
+    auto function_input = context.get_input(0);
+
+    // Collect the function arguments (all inputs after the function reference itself)
+    OutputVector args;
+    for (size_t i = 1; i < context.get_input_size(); i++) {
+        args.push_back(context.get_input(i));
+    }
+
+    Output<Node> result;
+
+    if (auto const_op = std::dynamic_pointer_cast<v0::Constant>(function_input.get_node_shared_ptr())) {
+        if (args.size() == 1) {
+            auto arg_type = args[0].get_element_type();
+            if (arg_type.is_signed()) {
+                result = context.mark_node(std::make_shared<v0::Relu>(args[0]));
+            } else {
+                result = context.mark_node(std::make_shared<v0::Abs>(args[0]));
+            }
+        }
+
+        else if (args.size() == 2) {
+            result = args[0];
+        }
+
+        else {
+            result = args[0];
+        }
+    } else {
+        PYTORCH_OP_CONVERSION_CHECK(args.size() > 0, "prim::CallFunction: No arguments provided");
+        result = args[0];
+    }
+
+    auto out_type = context.get_output_type(0);
+    if (out_type.is<element::Type>()) {
+        auto dtype = out_type.as<element::Type>();
+        if (dtype.is_static() && dtype != result.get_element_type()) {
+            result = context.mark_node(std::make_shared<v0::Convert>(result, dtype));
+        }
+    }
+
+    return {result};
+}
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 4e75b6b22cd034..1718ae1e63d072 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -361,7 +361,7 @@ OP_CONVERTER(translate_embedding_ext);
 OP_CONVERTER(translate_linear_awq);
 OP_CONVERTER(translate_linear_bitnet);
 OP_CONVERTER(translate_linear_ext);
-
+OP_CONVERTER(translate_prim_CallFunction);
 }  // namespace op
 
 // Supported ops for TorchScript
@@ -801,6 +801,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"prim::TupleIndex", op::translate_tuple_index},
     // prim::TupleUnpack - Supported in limited set of patterns
     {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+ {"prim::CallFunction", op::translate_prim_CallFunction}, {"quantized::add", op::translate_quantized_add}, {"quantized::add_relu", op::translate_quantized_add_relu}, {"quantized::cat", op::translate_quantized_cat}, diff --git a/tests/layer_tests/pytorch_tests/test_prim_callfunction.py b/tests/layer_tests/pytorch_tests/test_prim_callfunction.py new file mode 100644 index 00000000000000..8a12804361002d --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_prim_callfunction.py @@ -0,0 +1,78 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import numpy as np +from pytorch_layer_test_class import PytorchLayerTest + +# Models that generate aten ops directly +class CallFunctionReLUModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.relu(x) + +class CallFunctionAbsModel(torch.nn.Module): + def forward(self, x): + return torch.abs(x) + +# Custom function that becomes aten::relu +@torch.jit.script +def custom_activation(x): + return torch.relu(x) + +class CallFunctionCustomModel(torch.nn.Module): + def forward(self, x): + return custom_activation(x) + + +class TestCallFunction(PytorchLayerTest): + def _prepare_input(self, dtype=np.float32): + # Default method for generating random inputs + return (np.random.randn(2, 3, 4, 5).astype(dtype),) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_relu(self, ie_device, precision, ir_version): + model = CallFunctionReLUModel() + # ref_net=None tells the runner to use the _prepare_input method + self._test(model, None, "aten::relu", ie_device, precision, ir_version) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_abs(self, ie_device, precision, ir_version): + model = CallFunctionAbsModel() + self._test(model, None, "aten::abs", ie_device, precision, ir_version) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("dtype", [np.float32, np.float16]) + def test_relu_types(self, ie_device, precision, ir_version, dtype): + model = CallFunctionReLUModel() + # The runner will call _prepare_input(dtype=dtype) + self._test(model, None, "aten::relu", ie_device, precision, ir_version, + kwargs_to_prepare_input={"dtype": dtype}) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_custom(self, ie_device, precision, ir_version): + model = CallFunctionCustomModel() + self._test(model, None, "aten::relu", ie_device, precision, ir_version) + + @pytest.mark.nightly + def test_relu_zeros(self, ie_device, precision, ir_version): + model = CallFunctionReLUModel() + # Provide inputs directly as the second argument (ref_net) + inputs = (np.zeros((2, 3), dtype=np.float32),) + self._test(model, inputs, "aten::relu", ie_device, precision, ir_version) + + @pytest.mark.nightly + def test_relu_ones(self, ie_device, precision, ir_version): + model = CallFunctionReLUModel() + inputs = (np.ones((2, 3), dtype=np.float32),) + self._test(model, inputs, "aten::relu", ie_device, precision, ir_version) + + @pytest.mark.nightly + def test_abs_negative(self, ie_device, precision, ir_version): + model = CallFunctionAbsModel() + inputs = (np.array([[-1.0, 2.0, -3.0], [4.0, -5.0, 6.0]], dtype=np.float32),) + self._test(model, inputs, "aten::abs", ie_device, precision, ir_version) \ No newline at end of file