Skip to content

Commit ee40e7d

Browse files
committed
Changed callback support to use C++ rather than C for increased memory safety and ownership. e.g., avoid out-of-bounds, especially after changes to the model.
Includes a small breaking change for the C API: renamed HighsCallbackDataOut to be Highs*C*CallbackDataOut. Similar with HighsCCallbackDataIn, but it also requires user_solution_size to be set. We will only access elements in memory up to user_solution_size, however we won't do anything unless user_solution_size == num_cols. Also added minor usability changes to highspy callbacks + a setObjective function that should've been previously added. Minor breaking change to highspy, the mip_solution no longer needs/supports to_array(n).
1 parent 9ac9478 commit ee40e7d

File tree

14 files changed

+223
-101
lines changed

14 files changed

+223
-101
lines changed

check/TestCAPI.c

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
const HighsInt dev_run = 0;
1313
const double double_equal_tolerance = 1e-5;
1414

15-
void checkGetCallbackDataOutPointer(const HighsCallbackDataOut* data_out,
15+
void checkGetCallbackDataOutPointer(const HighsCCallbackDataOut* data_out,
1616
const char* name, HighsInt valid) {
1717
const void* name_p = Highs_getCallbackDataOutItem(data_out, name);
1818
if (valid) {
@@ -28,7 +28,7 @@ void checkGetCallbackDataOutPointer(const HighsCallbackDataOut* data_out,
2828
}
2929
}
3030

31-
void checkGetCallbackDataOutHighsInt(const HighsCallbackDataOut* data_out,
31+
void checkGetCallbackDataOutHighsInt(const HighsCCallbackDataOut* data_out,
3232
const char* name, HighsInt value) {
3333
const void* name_p = Highs_getCallbackDataOutItem(data_out, name);
3434
if (!name_p) {
@@ -46,7 +46,7 @@ void checkGetCallbackDataOutHighsInt(const HighsCallbackDataOut* data_out,
4646
}
4747
}
4848

49-
void checkGetCallbackDataOutInt(const HighsCallbackDataOut* data_out,
49+
void checkGetCallbackDataOutInt(const HighsCCallbackDataOut* data_out,
5050
const char* name, int value) {
5151
const void* name_p = Highs_getCallbackDataOutItem(data_out, name);
5252
if (!name_p) {
@@ -64,7 +64,7 @@ void checkGetCallbackDataOutInt(const HighsCallbackDataOut* data_out,
6464
}
6565
}
6666

67-
void checkGetCallbackDataOutInt64(const HighsCallbackDataOut* data_out,
67+
void checkGetCallbackDataOutInt64(const HighsCCallbackDataOut* data_out,
6868
const char* name, int64_t value) {
6969
const void* name_p = Highs_getCallbackDataOutItem(data_out, name);
7070
if (!name_p) {
@@ -82,7 +82,7 @@ void checkGetCallbackDataOutInt64(const HighsCallbackDataOut* data_out,
8282
}
8383
}
8484

85-
void checkGetCallbackDataOutDouble(const HighsCallbackDataOut* data_out,
85+
void checkGetCallbackDataOutDouble(const HighsCCallbackDataOut* data_out,
8686
const char* name, double value) {
8787
const void* name_p = Highs_getCallbackDataOutItem(data_out, name);
8888
if (!name_p) {
@@ -101,8 +101,8 @@ void checkGetCallbackDataOutDouble(const HighsCallbackDataOut* data_out,
101101
}
102102

103103
static void userCallback(const int callback_type, const char* message,
104-
const HighsCallbackDataOut* data_out,
105-
HighsCallbackDataIn* data_in,
104+
const HighsCCallbackDataOut* data_out,
105+
HighsCCallbackDataIn* data_in,
106106
void* user_callback_data) {
107107
// Extract the double value pointed to from void* user_callback_data
108108
const double local_callback_data =

src/highs_bindings.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class readonly_ptr_wrapper {
3333

3434
py::array_t<T, py::array::c_style> to_array(std::size_t size) {
3535
return py::array_t<T, py::array::c_style>(py::buffer_info(
36-
ptr, sizeof(T), py::format_descriptor<T>::format(), 1, {size}, {1}));
36+
ptr, sizeof(T), py::format_descriptor<T>::format(), 1, {size}, { sizeof(double) }));
3737
}
3838

3939
private:
@@ -1585,12 +1585,8 @@ PYBIND11_MODULE(_core, m, py::mod_gil_not_used()) {
15851585
.value("kNumCallbackType", HighsCallbackType::kNumCallbackType)
15861586
.export_values();
15871587
// Classes
1588-
py::class_<readonly_ptr_wrapper<double>>(m, "readonly_ptr_wrapper_double", py::module_local())
1589-
.def(py::init<double*>())
1590-
.def("__getitem__", &readonly_ptr_wrapper<double>::operator[])
1591-
.def("__bool__", &readonly_ptr_wrapper<double>::is_valid)
1592-
.def("to_array", &readonly_ptr_wrapper<double>::to_array);
1593-
py::class_<HighsCallbackDataOut>(callbacks, "HighsCallbackDataOut", py::module_local())
1588+
py::class_<HighsCallbackDataOut>(callbacks, "HighsCallbackDataOut",
1589+
py::module_local())
15941590
.def(py::init<>())
15951591
.def_readwrite("log_type", &HighsCallbackDataOut::log_type)
15961592
.def_readwrite("running_time", &HighsCallbackDataOut::running_time)
@@ -1607,12 +1603,15 @@ PYBIND11_MODULE(_core, m, py::mod_gil_not_used()) {
16071603
&HighsCallbackDataOut::mip_primal_bound)
16081604
.def_readwrite("mip_dual_bound", &HighsCallbackDataOut::mip_dual_bound)
16091605
.def_readwrite("mip_gap", &HighsCallbackDataOut::mip_gap)
1610-
.def_property_readonly(
1611-
"mip_solution",
1612-
[](const HighsCallbackDataOut& self) -> readonly_ptr_wrapper<double> {
1613-
return readonly_ptr_wrapper<double>(self.mip_solution);
1614-
});
1606+
.def_readwrite("mip_solution", &HighsCallbackDataOut::mip_solution)
1607+
.def_readwrite("cutpool_num_col", &HighsCallbackDataOut::cutpool_num_col)
1608+
.def_readwrite("cutpool_index", &HighsCallbackDataOut::cutpool_index)
1609+
.def_readwrite("cutpool_value", &HighsCallbackDataOut::cutpool_value)
1610+
.def_readwrite("cutpool_lower", &HighsCallbackDataOut::cutpool_lower)
1611+
.def_readwrite("cutpool_upper", &HighsCallbackDataOut::cutpool_upper);
1612+
16151613
py::class_<HighsCallbackDataIn>(callbacks, "HighsCallbackDataIn", py::module_local())
16161614
.def(py::init<>())
1617-
.def_readwrite("user_interrupt", &HighsCallbackDataIn::user_interrupt);
1615+
.def_readwrite("user_interrupt", &HighsCallbackDataIn::user_interrupt)
1616+
.def_readwrite("user_solution", &HighsCallbackDataIn::user_solution);
16181617
}

src/highspy/_core/cb.pyi

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ __all__ = [
2626

2727
class HighsCallbackDataIn:
2828
user_interrupt: int
29+
user_solution: list[float]
2930
def __init__(self) -> None: ...
3031

3132
class HighsCallbackDataOut:
@@ -39,9 +40,14 @@ class HighsCallbackDataOut:
3940
pdlp_iteration_count: int
4041
running_time: float
4142
simplex_iteration_count: int
43+
mip_solution: list[float]
44+
cutpool_num_col: int
45+
cutpool_start: list[int]
46+
cutpool_index: list[int]
47+
cutpool_value: list[float]
48+
cutpool_lower: list[float]
49+
cutpool_upper: list[float]
4250
def __init__(self) -> None: ...
43-
@property
44-
def mip_solution(self) -> highspy._core.readonly_ptr_wrapper_double: ...
4551

4652
class HighsCallbackType:
4753
"""

src/highspy/highs.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
HighsLinearObjective,
1414
cb, # type: ignore
1515
_Highs, # type: ignore
16-
readonly_ptr_wrapper_double,
1716
kHighsInf,
1817
)
1918

@@ -179,19 +178,17 @@ def optimize(self):
179178
"""
180179
return self.solve()
181180

182-
# reset the objective and sense, then solve
183-
def minimize(self, obj: Optional[Union[highs_var, highs_linear_expression]] = None):
181+
# reset the objective
182+
def setObjective(self, obj: Optional[Union[highs_var, highs_linear_expression]] = None, sense: Optional[ObjSense] = None):
184183
"""
185-
Solves a minimization of the objective and optionally updates the costs.
184+
Updates the costs.
186185
187186
Args:
188187
obj: An optional highs_linear_expression representing the new objective function.
188+
sense: An optional ObjSense value representing the new objective sense.
189189
190190
Raises:
191191
Exception: If obj is an inequality or not a highs_linear_expression.
192-
193-
Returns:
194-
A HighsStatus object containing the solve status after minimization.
195192
"""
196193
if obj is not None:
197194
# if we have a single variable, wrap it in a linear expression
@@ -212,13 +209,13 @@ def minimize(self, obj: Optional[Union[highs_var, highs_linear_expression]] = No
212209
super().changeColsCost(len(idxs), idxs, vals)
213210
super().changeObjectiveOffset(expr.constant or 0.0)
214211

215-
super().changeObjectiveSense(ObjSense.kMinimize)
216-
return self.solve()
212+
if sense is not None:
213+
super().changeObjectiveSense(sense)
217214

218215
# reset the objective and sense, then solve
219-
def maximize(self, obj: Optional[Union[highs_var, highs_linear_expression]] = None):
216+
def minimize(self, obj: Optional[Union[highs_var, highs_linear_expression]] = None):
220217
"""
221-
Solves a maximization of the objective and optionally updates the costs.
218+
Solves a minimization of the objective and optionally updates the costs.
222219
223220
Args:
224221
obj: An optional highs_linear_expression representing the new objective function.
@@ -227,32 +224,31 @@ def maximize(self, obj: Optional[Union[highs_var, highs_linear_expression]] = No
227224
Exception: If obj is an inequality or not a highs_linear_expression.
228225
229226
Returns:
230-
A HighsStatus object containing the solve status after maximization.
227+
A HighsStatus object containing the solve status after minimization.
231228
"""
232-
if obj is not None:
233-
# if we have a single variable, wrap it in a linear expression
234-
expr = highs_linear_expression(obj) if isinstance(obj, highs_var) else obj
229+
self.setObjective(obj, ObjSense.kMinimize)
230+
return self.solve()
235231

236-
if expr.bounds is not None:
237-
raise Exception("Objective cannot be an inequality")
232+
# reset the objective and sense, then solve
233+
def maximize(self, obj: Optional[Union[highs_var, highs_linear_expression]] = None):
234+
"""
235+
Solves a maximization of the objective and optionally updates the costs.
238236
239-
# reset objective
240-
super().changeColsCost(
241-
self.numVariables,
242-
np.arange(self.numVariables, dtype=np.int32),
243-
np.full(self.numVariables, 0, dtype=np.float64),
244-
)
237+
Args:
238+
obj: An optional highs_linear_expression representing the new objective function.
245239
246-
# if we have duplicate variables, add the vals
247-
idxs, vals = expr.unique_elements()
248-
super().changeColsCost(len(idxs), idxs, vals)
249-
super().changeObjectiveOffset(expr.constant or 0.0)
240+
Raises:
241+
Exception: If obj is an inequality or not a highs_linear_expression.
250242
251-
super().changeObjectiveSense(ObjSense.kMaximize)
243+
Returns:
244+
A HighsStatus object containing the solve status after maximization.
245+
"""
246+
self.setObjective(obj, ObjSense.kMaximize)
252247
return self.solve()
248+
253249
@staticmethod
254250
def internal_get_value(
255-
array_values: Union[Sequence[float], np.ndarray[Any, np.dtype[np.float64]], readonly_ptr_wrapper_double],
251+
array_values: Union[Sequence[float], np.ndarray[Any, np.dtype[np.float64]]],
256252
index_collection: Union[
257253
Integral, highs_var, highs_cons, highs_linear_expression, Mapping[Any, Any], Sequence[Any], np.ndarray[Any, np.dtype[Any]]
258254
],
@@ -1203,7 +1199,7 @@ def __internal_callback(
12031199
data_in: Optional[cb.HighsCallbackDataIn],
12041200
user_callback_data: Any,
12051201
):
1206-
user_callback_data.callbacks[int(callback_type)].fire(message, data_out, data_in)
1202+
user_callback_data.callbacks[int(callback_type)].fire(callback_type, message, data_out, data_in)
12071203

12081204
def enableCallbacks(self):
12091205
"""
@@ -1216,9 +1212,16 @@ def enableCallbacks(self):
12161212
if len(c.callbacks) > 0:
12171213
self.startCallback(c.callback_type)
12181214

1215+
def clearCallbacks(self):
1216+
"""
1217+
Clears all callbacks.
1218+
"""
1219+
for c in self.callbacks:
1220+
c.clear()
1221+
12191222
def disableCallbacks(self):
12201223
"""
1221-
Disables all callbacks.
1224+
Disables all callbacks, but does not clear them.
12221225
"""
12231226
status = super().setCallback(None, None) # this will also stop all callbacks
12241227

@@ -1264,8 +1267,8 @@ def HandleUserInterrupt(self, value: bool):
12641267
self.cbMipInterrupt -= self.__user_interrupt_event
12651268

12661269
def __user_interrupt_event(self, e: HighsCallbackEvent):
1267-
if self.__solver_should_stop and e.data_in is not None:
1268-
e.data_in.user_interrupt = True
1270+
if self.__solver_should_stop:
1271+
e.interrupt()
12691272

12701273
@property
12711274
def cbLogging(self):
@@ -1359,20 +1362,29 @@ def cbMipDefineLazyConstraints(self, value: HighsCallback):
13591362
## Callback support
13601363
##
13611364
class HighsCallbackEvent(object):
1362-
__slots__ = ["message", "data_out", "data_in", "user_data"]
1365+
__slots__ = ["callback_type", "message", "data_out", "data_in", "user_data"]
13631366

13641367
def __init__(
13651368
self,
1369+
callback_type: cb.HighsCallbackType,
13661370
message: str,
13671371
data_out: cb.HighsCallbackDataOut,
13681372
data_in: Optional[cb.HighsCallbackDataIn],
1369-
user_data: Optional[Any],
1373+
user_data: Optional[Any]
13701374
):
1375+
self.callback_type = callback_type
13711376
self.message = message
13721377
self.data_out = data_out
13731378
self.data_in = data_in
13741379
self.user_data = user_data
13751380

1381+
def interrupt(self, interrupt_value: bool = True):
1382+
"""
1383+
Sets the user interrupt flag in the callback data.
1384+
"""
1385+
if self.data_in is not None:
1386+
self.data_in.user_interrupt = interrupt_value
1387+
13761388
def val(
13771389
self,
13781390
var_expr: Union[Integral, highs_var, highs_cons, highs_linear_expression, Mapping[Any, Any], np.ndarray[Any, np.dtype[Any]]],
@@ -1468,14 +1480,15 @@ def clear(self):
14681480

14691481
def fire(
14701482
self,
1483+
callback_type: cb.HighsCallbackType,
14711484
message: str,
14721485
data_out: cb.HighsCallbackDataOut,
14731486
data_in: cb.HighsCallbackDataIn,
14741487
):
14751488
"""
14761489
Fires the event, executing all subscribed callbacks.
14771490
"""
1478-
e = HighsCallbackEvent(message, data_out, data_in, None)
1491+
e = HighsCallbackEvent(callback_type, message, data_out, data_in, None)
14791492

14801493
for fn, user_data in zip(self.callbacks, self.user_callback_data):
14811494
e.user_data = user_data
@@ -1792,7 +1805,7 @@ def copy(self):
17921805

17931806
def evaluate(
17941807
self,
1795-
values: Union[Sequence[float], np.ndarray[Any, np.dtype[np.float64]], readonly_ptr_wrapper_double],
1808+
values: Union[Sequence[float], np.ndarray[Any, np.dtype[np.float64]]],
17961809
) -> Union[float, bool]:
17971810
"""
17981811
Evaluates the linear expression given a solution array (values).

src/interfaces/highs_c_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,7 @@ void Highs_resetGlobalScheduler(HighsInt blocking) {
14331433
Highs::resetGlobalScheduler(blocking != 0);
14341434
}
14351435

1436-
const void* Highs_getCallbackDataOutItem(const HighsCallbackDataOut* data_out,
1436+
const void* Highs_getCallbackDataOutItem(const HighsCCallbackDataOut* data_out,
14371437
const char* item_name) {
14381438
// Accessor function for HighsCallbackDataOut
14391439
//

src/interfaces/highs_c_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2359,7 +2359,7 @@ void Highs_resetGlobalScheduler(const HighsInt blocking);
23592359
* @returns A void* pointer to the callback data item, or NULL if item_name not
23602360
* valid
23612361
*/
2362-
const void* Highs_getCallbackDataOutItem(const HighsCallbackDataOut* data_out,
2362+
const void* Highs_getCallbackDataOutItem(const HighsCCallbackDataOut* data_out,
23632363
const char* item_name);
23642364

23652365
// *********************

src/io/HighsIO.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ void highsLogUser(const HighsLogOptions& log_options_, const HighsLogType type,
151151
if (log_options_.user_callback_active) {
152152
assert(log_options_.user_callback);
153153
HighsCallbackDataOut data_out;
154-
data_out.log_type = int(type);
154+
data_out.log_type = type;
155155
log_options_.user_callback(kCallbackLogging, msgbuffer.data(), &data_out,
156156
nullptr, log_options_.user_callback_data);
157157
}
@@ -210,7 +210,7 @@ void highsLogDev(const HighsLogOptions& log_options_, const HighsLogType type,
210210
} else if (log_options_.user_callback_active) {
211211
assert(log_options_.user_callback);
212212
HighsCallbackDataOut data_out;
213-
data_out.log_type = int(type);
213+
data_out.log_type = type;
214214
log_options_.user_callback(kCallbackLogging, msgbuffer.data(), &data_out,
215215
nullptr, log_options_.user_callback_data);
216216
}

src/lp_data/Highs.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2227,9 +2227,18 @@ HighsStatus Highs::setCallback(HighsCCallbackType c_callback,
22272227
void* user_callback_data) {
22282228
this->callback_.clear();
22292229
this->callback_.user_callback =
2230-
[c_callback](int a, const std::string& b, const HighsCallbackDataOut* c,
2231-
HighsCallbackDataIn* d,
2232-
void* e) { c_callback(a, b.c_str(), c, d, e); };
2230+
[c_callback](int a, const std::string& b, const HighsCallbackDataOut* cb_out,
2231+
HighsCallbackDataIn* cb_in,
2232+
void* e) {
2233+
HighsCCallbackDataOut cc_out = static_cast<HighsCCallbackDataOut>(*cb_out);
2234+
HighsCCallbackDataIn cc_in;
2235+
cc_in.user_interrupt = 0;
2236+
cc_in.user_solution_size = 0;
2237+
cc_in.user_solution = nullptr;
2238+
2239+
c_callback(a, b.c_str(), &cc_out, &cc_in, e);
2240+
*cb_in = cc_in; // copy the data in
2241+
};
22332242
this->callback_.user_callback_data = user_callback_data;
22342243

22352244
options_.log_options.user_callback = this->callback_.user_callback;

0 commit comments

Comments
 (0)