Skip to content

Commit 57f6b2a

Browse files
committed
Fix callable detection to support all callable objects (functools.partial, etc)
Fixes #2008 The old implementation checked for __code__ attribute directly, which doesn't exist for callable objects like functools.partial. This caused an AttributeError when trying to use such objects with find_max_global and find_min_global functions. Changed to use Python's inspect.signature() which properly handles all callable objects including: - Regular functions - Lambda functions - functools.partial objects - Class methods - Any object with __call__ method The fix includes a fallback to the old __code__ method for backward compatibility in case inspect.signature fails for any reason.
1 parent 20b2172 commit 57f6b2a

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

test_callable_fix.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/usr/bin/env python3
2+
"""Test script to verify the fix for issue #2008"""
3+
from functools import partial
4+
5+
def test_function(x, y):
6+
"""Simple test function"""
7+
return -(x**2 + y**2)
8+
9+
def test_single_arg(x):
10+
"""Single argument function"""
11+
return -(x**2)
12+
13+
# Test with functools.partial
14+
partial_func = partial(test_function, 2)
15+
16+
print("Testing functools.partial with dlib.find_max_global...")
17+
print(f"partial_func type: {type(partial_func)}")
18+
print(f"partial_func callable: {callable(partial_func)}")
19+
20+
# This should work after the fix
21+
try:
22+
import dlib
23+
result = dlib.find_max_global(partial_func, [0.], [10.], 100)
24+
print(f"✓ Success! Result: {result}")
25+
except AttributeError as e:
26+
print(f"✗ Failed with AttributeError: {e}")
27+
except Exception as e:
28+
print(f"✗ Failed with error: {e}")
29+
30+
# Test with regular function (should still work)
31+
print("\nTesting regular function with dlib.find_max_global...")
32+
try:
33+
import dlib
34+
result = dlib.find_max_global(test_single_arg, [0.], [10.], 100)
35+
print(f"✓ Success! Result: {result}")
36+
except Exception as e:
37+
print(f"✗ Failed with error: {e}")
38+
39+
# Test with lambda (should still work)
40+
print("\nTesting lambda with dlib.find_max_global...")
41+
try:
42+
import dlib
43+
result = dlib.find_max_global(lambda x: -(x**2), [0.], [10.], 100)
44+
print(f"✓ Success! Result: {result}")
45+
except Exception as e:
46+
print(f"✗ Failed with error: {e}")

tools/python/src/global_optimization.cpp

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,42 @@ py::list mat_to_list (
4848

4949
size_t num_function_arguments(py::object f, size_t expected_num)
5050
{
51-
const auto code_object = f.attr(hasattr(f,"func_code") ? "func_code" : "__code__");
52-
const auto num = code_object.attr("co_argcount").cast<std::size_t>();
53-
if (num < expected_num && (code_object.attr("co_flags").cast<int>() & CO_VARARGS))
54-
return expected_num;
55-
return num;
51+
// Use Python's inspect module to get signature, which works with all callable objects
52+
// including functools.partial, lambdas, and regular functions
53+
auto inspect = py::module_::import("inspect");
54+
55+
try {
56+
auto sig = inspect.attr("signature")(f);
57+
auto params = sig.attr("parameters");
58+
auto num = py::len(params);
59+
60+
// Check if function accepts *args (VAR_POSITIONAL)
61+
bool has_var_args = false;
62+
for (auto item : params) {
63+
auto param = item.second.cast<py::object>();
64+
auto kind = param.attr("kind");
65+
// inspect.Parameter.VAR_POSITIONAL == 2
66+
if (kind.cast<int>() == 2) {
67+
has_var_args = true;
68+
break;
69+
}
70+
}
71+
72+
if (num < expected_num && has_var_args)
73+
return expected_num;
74+
return num;
75+
} catch (const py::error_already_set&) {
76+
// Fallback to old method if inspect.signature fails
77+
// This maintains backward compatibility
78+
if (!hasattr(f, "__code__") && !hasattr(f, "func_code")) {
79+
throw;
80+
}
81+
const auto code_object = f.attr(hasattr(f,"func_code") ? "func_code" : "__code__");
82+
const auto num = code_object.attr("co_argcount").cast<std::size_t>();
83+
if (num < expected_num && (code_object.attr("co_flags").cast<int>() & CO_VARARGS))
84+
return expected_num;
85+
return num;
86+
}
5687
}
5788

5889
double call_func(py::object f, const matrix<double,0,1>& args)

0 commit comments

Comments
 (0)