src/frontends/pytorch/src/op/callfunction.cpp (69 additions, 0 deletions)
@@ -0,0 +1,69 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/relu.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;
using namespace std;


OutputVector translate_prim_CallFunction(const NodeContext& context) {
    num_inputs_check(context, 2, context.get_input_size());

    auto function_input = context.get_input(0);

    // Get function arguments
    OutputVector args;
    for (size_t i = 1; i < context.get_input_size(); i++) {
        args.push_back(context.get_input(i));
    }

    Output<Node> result;

    if (auto const_op = std::dynamic_pointer_cast<v0::Constant>(function_input.get_node_shared_ptr())) {
        if (args.size() == 1) {
            auto arg_type = args[0].get_element_type();
            if (arg_type.is_signed()) {
                result = context.mark_node(std::make_shared<v0::Abs>(args[0]));
            } else {
                result = context.mark_node(std::make_shared<v0::Relu>(args[0]));
Review comment (Contributor) on lines +37 to +39:
"You expect function to be either Abs or Relu, but it can be any subgraph" (see the sketch after this file's diff for an example of such a subgraph).

            }
        } else if (args.size() == 2) {
            result = args[0];
        } else {
            result = args[0];
        }
    } else {
        PYTORCH_OP_CONVERSION_CHECK(args.size() > 0, "prim::CallFunction: No arguments provided");
        result = args[0];
    }

    auto out_type = context.get_output_type(0);
    if (out_type.is<element::Type>()) {
        auto dtype = out_type.as<element::Type>();
        if (dtype.is_static() && dtype != result.get_element_type()) {
            result = context.mark_node(std::make_shared<v0::Convert>(result, dtype));
        }
    }

    return {result};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
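To illustrate the reviewer's point above: the callee of prim::CallFunction can be any scripted function, so the target may be an arbitrary subgraph rather than something that lowers to aten::abs or aten::relu. A minimal, hypothetical TorchScript sketch (the names scaled_shift and CallsSubgraph are illustrative and not part of this PR):

import torch

# A scripted helper that is neither relu nor abs; calling it from a scripted
# module's forward typically appears as a prim::CallFunction node whose callee
# is this whole subgraph (multiply followed by add).
@torch.jit.script
def scaled_shift(x):
    return x * 2.0 + 1.0

class CallsSubgraph(torch.nn.Module):
    def forward(self, x):
        return scaled_shift(x)

scripted = torch.jit.script(CallsSubgraph())
print(scripted.graph)  # non-inlined graph; expected to contain a prim::CallFunction node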
src/frontends/pytorch/src/op_table.cpp (2 additions, 1 deletion)
@@ -361,7 +361,7 @@ OP_CONVERTER(translate_embedding_ext);
OP_CONVERTER(translate_linear_awq);
OP_CONVERTER(translate_linear_bitnet);
OP_CONVERTER(translate_linear_ext);

OP_CONVERTER(translate_prim_CallFunction);
} // namespace op

// Supported ops for TorchScript
@@ -801,6 +801,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"prim::TupleIndex", op::translate_tuple_index},
// prim::TupleUnpack - Supported in limited set of patterns
{"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode.
{"prim::CallFunction", op::translate_prim_CallFunction},
{"quantized::add", op::translate_quantized_add},
{"quantized::add_relu", op::translate_quantized_add_relu},
{"quantized::cat", op::translate_quantized_cat},
tests/layer_tests/pytorch_tests/test_prim_callfunction.py (78 additions, 0 deletions)
@@ -0,0 +1,78 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import numpy as np
from pytorch_layer_test_class import PytorchLayerTest

# Models that generate aten ops directly
class CallFunctionReLUModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

class CallFunctionAbsModel(torch.nn.Module):
    def forward(self, x):
        return torch.abs(x)

# Custom function that becomes aten::relu
@torch.jit.script
def custom_activation(x):
    return torch.relu(x)

class CallFunctionCustomModel(torch.nn.Module):
    def forward(self, x):
        return custom_activation(x)

class TestCallFunction(PytorchLayerTest):
    def _prepare_input(self, dtype=np.float32, data=None):
        # Default method for generating random inputs; a fixed array can be
        # supplied through kwargs_to_prepare_input (see the fixed-input tests below)
        if data is not None:
            return (np.array(data, dtype=dtype),)
        return (np.random.randn(2, 3, 4, 5).astype(dtype),)

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_relu(self, ie_device, precision, ir_version):
        model = CallFunctionReLUModel()
        # ref_net is not used; inputs come from the _prepare_input method
        self._test(model, None, "aten::relu", ie_device, precision, ir_version)

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_abs(self, ie_device, precision, ir_version):
        model = CallFunctionAbsModel()
        self._test(model, None, "aten::abs", ie_device, precision, ir_version)

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dtype", [np.float32, np.float16])
    def test_relu_types(self, ie_device, precision, ir_version, dtype):
        model = CallFunctionReLUModel()
        # The runner will call _prepare_input(dtype=dtype)
        self._test(model, None, "aten::relu", ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"dtype": dtype})

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_custom(self, ie_device, precision, ir_version):
        model = CallFunctionCustomModel()
        self._test(model, None, "aten::relu", ie_device, precision, ir_version)

    @pytest.mark.nightly
    def test_relu_zeros(self, ie_device, precision, ir_version):
        model = CallFunctionReLUModel()
        # Feed a fixed all-zeros input through _prepare_input
        self._test(model, None, "aten::relu", ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"data": np.zeros((2, 3), dtype=np.float32)})

    @pytest.mark.nightly
    def test_relu_ones(self, ie_device, precision, ir_version):
        model = CallFunctionReLUModel()
        self._test(model, None, "aten::relu", ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"data": np.ones((2, 3), dtype=np.float32)})

    @pytest.mark.nightly
    def test_abs_negative(self, ie_device, precision, ir_version):
        model = CallFunctionAbsModel()
        self._test(model, None, "aten::abs", ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"data": np.array([[-1.0, 2.0, -3.0], [4.0, -5.0, 6.0]], dtype=np.float32)})
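
For context, a minimal end-to-end sketch of exercising such a model through the public OpenVINO API rather than the PytorchLayerTest harness. It reuses CallFunctionCustomModel from the tests above; the input shape and the "CPU" device are illustrative assumptions:

import numpy as np
import openvino as ov
import torch

# Scripting keeps the call to custom_activation as a prim::CallFunction node
# in the TorchScript graph handed to the frontend.
scripted = torch.jit.script(CallFunctionCustomModel())

ov_model = ov.convert_model(scripted, example_input=torch.randn(2, 3, 4, 5))
compiled = ov.compile_model(ov_model, "CPU")

out = compiled(np.random.randn(2, 3, 4, 5).astype(np.float32))[0]
print(out.shape)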