Skip to content

Commit cb75e98

Browse files
committed
math_opt: fixup
* remove stubby sample
1 parent 2d4b78f commit cb75e98

File tree

12 files changed

+122
-179
lines changed

12 files changed

+122
-179
lines changed

ortools/math_opt/elemental/python/elemental.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,14 @@ PYBIND11_MODULE(cpp_elemental, py_module) {
496496
},
497497
py::kw_only(), arg("remove_names") = false);
498498

499+
elemental.def_static(
500+
"from_model_proto",
501+
[](const ModelProto& proto) {
502+
return std::make_unique<Elemental>(
503+
ThrowIfError(Elemental::FromModelProto(proto)));
504+
},
505+
arg("proto"));
506+
499507
elemental.def("add_diff", [](Elemental& e) { return e.AddDiff().id(); });
500508

501509
elemental.def("delete_diff", [](Elemental& e, const int64_t diff_handle) {

ortools/math_opt/elemental/python/elemental_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,36 @@ def test_export_model(self):
461461
expected.variables.names[:] = []
462462
self.assert_protos_equal(e.export_model(remove_names=True), expected)
463463

464+
def test_from_model_proto(self):
465+
proto = model_pb2.ModelProto(
466+
name="model",
467+
variables=model_pb2.VariablesProto(
468+
ids=[2],
469+
lower_bounds=[4.0],
470+
upper_bounds=[math.inf],
471+
integers=[False],
472+
names=["x"],
473+
),
474+
)
475+
e = cpp_elemental.CppElemental.from_model_proto(proto)
476+
self.assertEqual(e.model_name, "model")
477+
x = 2
478+
np.testing.assert_array_equal(
479+
e.get_elements(_VARIABLE), np.array([x], dtype=np.int64), strict=True
480+
)
481+
self.assertEqual(e.get_element_name(_VARIABLE, x), "x")
482+
self.assertEqual(e.get_attr(_VARIABLE_LOWER_BOUND, (x,)), 4.0)
483+
self.assertEqual(e.get_next_element_id(_VARIABLE), 3)
484+
self.assert_protos_equal(e.export_model(), proto)
485+
486+
def test_from_model_proto_empty(self):
487+
proto = model_pb2.ModelProto()
488+
e = cpp_elemental.CppElemental.from_model_proto(proto)
489+
self.assertEqual(e.model_name, "")
490+
self.assertEqual(e.primary_objective_name, "")
491+
for element_type in enums.ElementType:
492+
self.assertEqual(e.get_num_elements(element_type), 0)
493+
464494
def test_repr(self):
465495
e = cpp_elemental.CppElemental()
466496
e.add_element(_VARIABLE, "xyz")

ortools/math_opt/io/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ cc_library(
8585
hdrs = ["lp_parser.h"],
8686
deps = [
8787
":mps_converter",
88-
"//ortools/base",
8988
"//ortools/base:file",
9089
"//ortools/base:path",
9190
"//ortools/base:status_macros",

ortools/math_opt/python/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ py_library(
6262
":objectives",
6363
":quadratic_constraints",
6464
":variables",
65+
requirement("typing-extensions"),
6566
"//ortools/math_opt:model_py_pb2",
6667
"//ortools/math_opt:model_update_py_pb2",
6768
"//ortools/math_opt/elemental/python:cpp_elemental",

ortools/math_opt/python/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
import math
4545
from typing import Iterator, Optional, Tuple, Union
4646

47+
# typing.Self is only in python 3.11+, for OR-tools supports down to 3.8.
48+
from typing_extensions import Self
49+
4750
from ortools.math_opt import model_pb2
4851
from ortools.math_opt import model_update_pb2
4952
from ortools.math_opt.elemental.python import cpp_elemental
@@ -910,6 +913,13 @@ def get_indicator_constraints(
910913
# Proto import/export
911914
##############################################################################
912915

916+
@classmethod
917+
def from_model_proto(cls, proto: model_pb2.ModelProto) -> Self:
918+
"""Returns a Model equivalent to the input model proto."""
919+
model = cls()
920+
model._elemental = cpp_elemental.CppElemental.from_model_proto(proto)
921+
return model
922+
913923
def export_model(self, *, remove_names: bool = False) -> model_pb2.ModelProto:
914924
"""Returns a protocol buffer equivalent to this model.
915925

ortools/math_opt/python/model_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,58 @@ def test_export(self, remove_names: Optional[bool]) -> None:
518518
),
519519
)
520520

521+
def test_from_model_proto(self) -> None:
522+
model_proto = model_pb2.ModelProto(
523+
name="test_model",
524+
variables=model_pb2.VariablesProto(
525+
ids=[0, 1],
526+
lower_bounds=[0.0, 1.0],
527+
upper_bounds=[2.0, 3.0],
528+
integers=[True, False],
529+
names=["x", "y"],
530+
),
531+
linear_constraints=model_pb2.LinearConstraintsProto(
532+
ids=[0],
533+
lower_bounds=[-1.0],
534+
upper_bounds=[2.0],
535+
names=["c"],
536+
),
537+
objective=model_pb2.ObjectiveProto(
538+
maximize=True,
539+
offset=2.0,
540+
linear_coefficients=sparse_containers_pb2.SparseDoubleVectorProto(
541+
ids=[1], values=[3.0]
542+
),
543+
),
544+
linear_constraint_matrix=sparse_containers_pb2.SparseDoubleMatrixProto(
545+
row_ids=[0, 0], column_ids=[0, 1], coefficients=[1.0, 2.0]
546+
),
547+
)
548+
mod = model.Model.from_model_proto(model_proto)
549+
self.assertEqual(mod.name, "test_model")
550+
self.assertEqual(mod.get_num_variables(), 2)
551+
x = mod.get_variable(0)
552+
y = mod.get_variable(1)
553+
self.assertEqual(x.name, "x")
554+
self.assertEqual(x.lower_bound, 0.0)
555+
self.assertEqual(x.upper_bound, 2.0)
556+
self.assertTrue(x.integer)
557+
self.assertEqual(y.name, "y")
558+
self.assertEqual(y.lower_bound, 1.0)
559+
self.assertEqual(y.upper_bound, 3.0)
560+
self.assertFalse(y.integer)
561+
self.assertEqual(mod.get_num_linear_constraints(), 1)
562+
c = mod.get_linear_constraint(0)
563+
self.assertEqual(c.name, "c")
564+
self.assertEqual(c.lower_bound, -1.0)
565+
self.assertEqual(c.upper_bound, 2.0)
566+
self.assertEqual(c.get_coefficient(x), 1.0)
567+
self.assertEqual(c.get_coefficient(y), 2.0)
568+
self.assertTrue(mod.objective.is_maximize)
569+
self.assertEqual(mod.objective.offset, 2.0)
570+
self.assertEqual(mod.objective.get_linear_coefficient(x), 0.0)
571+
self.assertEqual(mod.objective.get_linear_coefficient(y), 3.0)
572+
521573
def test_update_tracker_simple(self) -> None:
522574
mod = model.Model(name="test_model")
523575
x = mod.add_binary_variable(name="x")

ortools/math_opt/samples/cpp/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ cc_binary(
192192
srcs = ["tsp.cc"],
193193
deps = [
194194
"//ortools/base",
195+
"//ortools/base:file",
195196
"//ortools/base:status_macros",
196197
"//ortools/math_opt/cpp:math_opt",
197198
"//ortools/math_opt/solvers:gscip_solver",

ortools/math_opt/samples/python/BUILD.bazel

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,6 @@ py_binary(
9999
],
100100
)
101101

102-
#py_binary(
103-
# name = "stubby_remote_streaming_solve_example",
104-
# srcs = ["stubby_remote_streaming_solve_example.py"],
105-
# test_lib = True,
106-
# deps = [
107-
# requirement("absl-py"),
108-
# "//ortools/math_opt:rpc_stubby_pyclif",
109-
# "//ortools/math_opt/python:mathopt",
110-
# "//ortools/math_opt/python/ipc:solve_service_stubby_client",
111-
# "//ortools/math_opt/python/ipc:stubby_remote_streaming_solve",
112-
# ],
113-
#)
114-
115102
py_binary(
116103
name = "time_indexed_scheduling",
117104
srcs = ["time_indexed_scheduling.py"],

ortools/math_opt/samples/python/stubby_remote_streaming_solve_example.py

Lines changed: 0 additions & 144 deletions
This file was deleted.

ortools/math_opt/solvers/gscip/gscip_event_handler_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ TEST(GScipEventHandlerDeathTest, ErrorReturnedByInit) {
221221
const absl::Status status = gscip->Solve().status();
222222
LOG(ERROR) << "status: " << status;
223223
if (!status.ok() &&
224-
absl::StrContains(status.message(), "SCIP error code 0")) {
224+
absl::StrContains(status.message(), "SCIP error code -8")) {
225225
// Write the expected marker only if we see the expected error.
226226
LOG(FATAL) << kMarker;
227227
}

0 commit comments

Comments
 (0)