Skip to content

Commit d18ca9a

Browse files
committed
math_opt: export from google3
1 parent 435bc09 commit d18ca9a

20 files changed

+501
-118
lines changed

ortools/math_opt/callback.proto

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ enum CallbackEventProto {
4747
// node). Useful for early termination. Note that this event does not provide
4848
// information on LP relaxations nor about new incumbent solutions.
4949
//
50-
// This event is supported for MIP models by SOLVER_TYPE_GUROBI only.
50+
// This event is fully supported for MIP models by SOLVER_TYPE_GUROBI only. If
51+
// used with SOLVER_TYPE_CP_SAT, it is called when the dual bound is improved.
5152
CALLBACK_EVENT_MIP = 3;
5253

5354
// Called every time a new MIP incumbent is found.
@@ -127,7 +128,8 @@ message CallbackDataProto {
127128
BarrierStats barrier_stats = 6;
128129

129130
// MIP B&B stats. Only available during CALLBACK_EVENT_MIPxxxx events.
130-
// Not supported for CP-SAT.
131+
// When using CP-SAT, only primal_bound, dual_bound and
132+
// number_of_solutions_found are populated.
131133
message MipStats {
132134
optional double primal_bound = 1;
133135
optional double dual_bound = 2;

ortools/math_opt/core/base_solver.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ namespace operations_research::math_opt {
3333

3434
// The API of solvers (in-process, sub-process and streaming RPC ones).
3535
//
36-
// Thread-safety: methods Solve() and Update() must not be called concurrently;
37-
// they should immediately return with an error status if this happens.
36+
// Thread-safety: methods Solve(), ComputeInfeasibleSubsystem() and Update()
37+
// must not be called concurrently; they should immediately return with an error
38+
// status if this happens.
3839
//
3940
// TODO: b/350984134 - Rename `Solver` into `InProcessSolver` and then rename
4041
// `BaseSolver` into `Solver`.
@@ -65,7 +66,14 @@ class BaseSolver {
6566
// printed on stdout/stderr/logs anymore.
6667
MessageCallback message_callback = nullptr;
6768

69+
// Registration parameter controlling calls to user_cb.
6870
CallbackRegistrationProto callback_registration;
71+
72+
// An optional MIP/LP callback. Only called for events registered in
73+
// callback_registration.
74+
//
75+
// Solve() returns an error if called without a user_cb but with some
76+
// non-empty callback_registration.request_registration.
6977
Callback user_cb = nullptr;
7078

7179
// An optional interrupter that the solver can use to interrupt the solve

ortools/math_opt/core/solver.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,16 @@ absl::StatusOr<SolveResultProto> Solver::Solve(const SolveArgs& arguments) {
120120
ValidateModelSolveParameters(arguments.model_parameters, model_summary_))
121121
<< "invalid model_parameters";
122122

123+
RETURN_IF_ERROR(ValidateCallbackRegistration(arguments.callback_registration,
124+
model_summary_));
123125
SolverInterface::Callback cb = nullptr;
126+
if (!arguments.callback_registration.request_registration().empty() &&
127+
arguments.user_cb == nullptr) {
128+
return absl::InvalidArgumentError(
129+
"no callback function was provided but callback events were "
130+
"registered");
131+
}
124132
if (arguments.user_cb != nullptr) {
125-
RETURN_IF_ERROR(ValidateCallbackRegistration(
126-
arguments.callback_registration, model_summary_));
127133
cb = [&](const CallbackDataProto& callback_data)
128134
-> absl::StatusOr<CallbackResultProto> {
129135
RETURN_IF_ERROR(ValidateCallbackDataProto(

ortools/math_opt/core/solver_interface.h

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ namespace math_opt {
4343
//
4444
// This interface is not meant to be used directly. The actual API is the one of
4545
// the Solver class. The Solver class validates the models before calling this
46-
// interface. It makes sure no concurrent calls happen on Solve(), CanUpdate()
47-
// and Update(). It makes sure no other function is called after Solve(),
46+
// interface. It makes sure no concurrent calls happen on Solve(),
47+
// ComputeInfeasibleSubsystem(), CanUpdate() and Update(). It makes sure no
48+
// other function is called after Solve(), ComputeInfeasibleSubsystem(),
4849
// Update() or a callback have failed.
4950
//
5051
// Implementations of this interface should not have public constructors but
@@ -69,12 +70,28 @@ class SolverInterface {
6970
// See Solver::MessageCallback documentation for details.
7071
using MessageCallback = std::function<void(const std::vector<std::string>&)>;
7172

72-
// A callback function (if non null) is a function that validates its input
73-
// and its output, and if fails, return a status. The invariant is that the
74-
// solver implementation can rely on receiving valid data. The implementation
75-
// of this interface must provide valid input (which will be validated) and
76-
// in error, it will return a status (without actually calling the callback
77-
// function). This is enforced in the solver.cc layer.
73+
// A callback function (if non null) provided by the Solver class to its
74+
// SolverInterface that wraps the user callback function
75+
// (BaseSolver::Callback) and validates its inputs (provided by the
76+
// SolverInterface implementation) and outputs (provided by the user). A
77+
// failing status is returned if those inputs or outputs are invalid.
78+
//
79+
// To be clear the SolverInterface::Callback is implemented by the Solver
80+
// class and looks like:
81+
//
82+
// absl::Status Callback(const CallbackDataProto& callback_data) {
83+
// RETURN_IF_ERROR(ValidateCallbackDataProto(callback_data, ...));
84+
// CallbackResultProto result = user_cb(callback_data);
85+
// RETURN_IF_ERROR(ValidateCallbackResultProto(result));
86+
// return result;
87+
// }
88+
//
89+
// As a consequence SolverInterface implementations can rely on receiving a
90+
// valid CallbackResultProto.
91+
//
92+
// When the SolverInterface::Callback returns an error the SolverInterface
93+
// implementation must interrupt the Solve() as soon as possible and return
94+
// this error.
7895
using Callback = std::function<absl::StatusOr<CallbackResultProto>(
7996
const CallbackDataProto&)>;
8097

@@ -114,7 +131,11 @@ class SolverInterface {
114131
// When parameter `message_cb` is not null and the underlying solver does not
115132
// supports message callbacks, it should ignore it.
116133
//
117-
// Solvers should return a InvalidArgumentError when called with events on
134+
// The parameter `cb` won't be null when
135+
// callback_registration.request_registration is not empty (solver.cc will
136+
// return an error in that case before calling SolverInterface::Solve()).
137+
//
138+
// Solvers should return an InvalidArgumentError when called with events on
118139
// callback_registration that are not supported by the solver for the type of
119140
// model being solved (for example MIP events if the model is an LP, or events
120141
// that are not emitted by the solver). Solvers should use

ortools/math_opt/cpp/BUILD.bazel

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,7 @@ cc_library(
895895

896896
cc_library(
897897
name = "incremental_solver",
898+
srcs = ["incremental_solver.cc"],
898899
hdrs = ["incremental_solver.h"],
899900
deps = [
900901
":compute_infeasible_subsystem_arguments",
@@ -903,10 +904,25 @@ cc_library(
903904
":solve_arguments",
904905
":solve_result",
905906
":update_result",
907+
"//ortools/base:status_macros",
906908
"@abseil-cpp//absl/status:statusor",
907909
],
908910
)
909911

912+
cc_test(
913+
name = "incremental_solver_test",
914+
srcs = ["incremental_solver_test.cc"],
915+
deps = [
916+
":incremental_solver",
917+
":matchers",
918+
":math_opt",
919+
"//ortools/base:gmock_main",
920+
"@abseil-cpp//absl/status",
921+
"@abseil-cpp//absl/status:statusor",
922+
"@abseil-cpp//absl/strings:string_view",
923+
],
924+
)
925+
910926
cc_library(
911927
name = "remote_streaming_mode",
912928
srcs = ["remote_streaming_mode.cc"],

ortools/math_opt/cpp/callback.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ enum class CallbackEvent {
109109
// node). Useful for early termination. Note that this event does not provide
110110
// information on LP relaxations nor about new incumbent solutions.
111111
//
112-
// This event is supported for MIP models with SolverType::kGurobi only.
112+
// This event is fully supported for MIP models with SolverType::kGurobi only.
113+
// If used with SolverType::kCpSat, it is called when the dual bound is
114+
// improved.
113115
kMip = CALLBACK_EVENT_MIP,
114116

115117
// Called every time a new MIP incumbent is found.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2010-2025 Google LLC
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
#include "ortools/math_opt/cpp/incremental_solver.h"
15+
16+
#include "absl/status/statusor.h"
17+
#include "ortools/base/status_macros.h"
18+
19+
namespace operations_research::math_opt {
20+
21+
absl::StatusOr<SolveResult> IncrementalSolver::Solve(
22+
const SolveArguments& arguments) {
23+
RETURN_IF_ERROR(Update().status());
24+
return SolveWithoutUpdate(arguments);
25+
}
26+
27+
absl::StatusOr<ComputeInfeasibleSubsystemResult>
28+
IncrementalSolver::ComputeInfeasibleSubsystem(
29+
const ComputeInfeasibleSubsystemArguments& arguments) {
30+
RETURN_IF_ERROR(Update().status());
31+
return ComputeInfeasibleSubsystemWithoutUpdate(arguments);
32+
}
33+
34+
} // namespace operations_research::math_opt

ortools/math_opt/cpp/incremental_solver.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,14 @@ class IncrementalSolver {
112112
//
113113
// See callback.h for documentation on arguments.callback and
114114
// arguments.callback_registration.
115-
virtual absl::StatusOr<SolveResult> Solve(
116-
const SolveArguments& arguments) = 0;
117-
absl::StatusOr<SolveResult> Solve() { return Solve({}); }
115+
absl::StatusOr<SolveResult> Solve(const SolveArguments& arguments = {});
118116

119117
// Updates the underlying solver with latest model changes and runs the
120118
// computation.
121119
//
122120
// Same as Solve() but compute the infeasible subsystem.
123-
virtual absl::StatusOr<ComputeInfeasibleSubsystemResult>
124-
ComputeInfeasibleSubsystem(
125-
const ComputeInfeasibleSubsystemArguments& arguments) = 0;
126-
absl::StatusOr<ComputeInfeasibleSubsystemResult>
127-
ComputeInfeasibleSubsystem() {
128-
return ComputeInfeasibleSubsystem({});
129-
}
121+
absl::StatusOr<ComputeInfeasibleSubsystemResult> ComputeInfeasibleSubsystem(
122+
const ComputeInfeasibleSubsystemArguments& arguments = {});
130123

131124
// Updates the model to solve.
132125
//
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright 2010-2025 Google LLC
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
#include "ortools/math_opt/cpp/incremental_solver.h"
15+
16+
#include <string>
17+
18+
#include "absl/status/status.h"
19+
#include "absl/status/statusor.h"
20+
#include "absl/strings/string_view.h"
21+
#include "gtest/gtest.h"
22+
#include "ortools/base/gmock.h"
23+
#include "ortools/math_opt/cpp/matchers.h"
24+
#include "ortools/math_opt/cpp/math_opt.h"
25+
26+
namespace operations_research::math_opt {
27+
namespace {
28+
29+
using ::testing::_;
30+
using ::testing::Return;
31+
using ::testing::status::IsOkAndHolds;
32+
using ::testing::status::StatusIs;
33+
34+
class MockIncrementalSolver final : public IncrementalSolver {
35+
public:
36+
MOCK_METHOD(absl::StatusOr<UpdateResult>, Update, (), (override));
37+
MOCK_METHOD(absl::StatusOr<SolveResult>, SolveWithoutUpdate,
38+
(const SolveArguments&), (const, override));
39+
MOCK_METHOD(absl::StatusOr<ComputeInfeasibleSubsystemResult>,
40+
ComputeInfeasibleSubsystemWithoutUpdate,
41+
(const ComputeInfeasibleSubsystemArguments&), (const, override));
42+
MOCK_METHOD(SolverType, solver_type, (), (const, override));
43+
};
44+
45+
TEST(IncrementalSolverTest, SolveWithFailingUpdate) {
46+
MockIncrementalSolver incremental_solver;
47+
EXPECT_CALL(incremental_solver, Update())
48+
.WillOnce(Return(absl::InternalError("oops")));
49+
EXPECT_THAT(incremental_solver.Solve(),
50+
StatusIs(absl::StatusCode::kInternal, "oops"));
51+
}
52+
53+
TEST(IncrementalSolverTest, SolveWithFailingSolveWithoutUpdate) {
54+
MockIncrementalSolver incremental_solver;
55+
EXPECT_CALL(incremental_solver, Update())
56+
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
57+
EXPECT_CALL(incremental_solver, SolveWithoutUpdate(_))
58+
.WillOnce(Return(absl::InternalError("oops")));
59+
EXPECT_THAT(incremental_solver.Solve(),
60+
StatusIs(absl::StatusCode::kInternal, "oops"));
61+
}
62+
63+
TEST(IncrementalSolverTest, SuccessfulSolve) {
64+
MockIncrementalSolver incremental_solver;
65+
EXPECT_CALL(incremental_solver, Update())
66+
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
67+
constexpr double kObjectiveValue = 3.5;
68+
constexpr absl::string_view kDetail = "found the optimum!";
69+
EXPECT_CALL(incremental_solver, SolveWithoutUpdate(_))
70+
.WillOnce(Return(
71+
SolveResult(Termination::Optimal(/*objective_value=*/kObjectiveValue,
72+
/*detail=*/std::string(kDetail)))));
73+
ASSERT_OK_AND_ASSIGN(const SolveResult solve_result,
74+
incremental_solver.Solve());
75+
EXPECT_THAT(solve_result.termination,
76+
TerminationIsOptimal(/*primal_objective_value=*/kObjectiveValue));
77+
EXPECT_EQ(solve_result.termination.detail, kDetail);
78+
}
79+
80+
TEST(IncrementalSolverTest, ComputeInfeasibleSubsystemWithFailingUpdate) {
81+
MockIncrementalSolver incremental_solver;
82+
EXPECT_CALL(incremental_solver, Update())
83+
.WillOnce(Return(absl::InternalError("oops")));
84+
EXPECT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
85+
StatusIs(absl::StatusCode::kInternal, "oops"));
86+
}
87+
88+
TEST(IncrementalSolverTest,
89+
ComputeInfeasibleSubsystemWithFailingComputeWithoutUpdate) {
90+
MockIncrementalSolver incremental_solver;
91+
EXPECT_CALL(incremental_solver, Update())
92+
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
93+
EXPECT_CALL(incremental_solver, ComputeInfeasibleSubsystemWithoutUpdate(_))
94+
.WillOnce(Return(absl::InternalError("oops")));
95+
EXPECT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
96+
StatusIs(absl::StatusCode::kInternal, "oops"));
97+
}
98+
99+
TEST(IncrementalSolverTest, SuccessfulComputeInfeasibleSubsystem) {
100+
MockIncrementalSolver incremental_solver;
101+
EXPECT_CALL(incremental_solver, Update())
102+
.WillOnce(Return(UpdateResult(/*did_update=*/true)));
103+
Model model;
104+
const Variable v = model.AddBinaryVariable("v");
105+
const ModelSubset model_subset = {
106+
.variable_integrality = {v},
107+
};
108+
EXPECT_CALL(incremental_solver, ComputeInfeasibleSubsystemWithoutUpdate(_))
109+
.WillOnce(Return(ComputeInfeasibleSubsystemResult{
110+
.feasibility = FeasibilityStatus::kInfeasible,
111+
.infeasible_subsystem = model_subset,
112+
.is_minimal = false,
113+
}));
114+
ASSERT_THAT(incremental_solver.ComputeInfeasibleSubsystem(),
115+
IsOkAndHolds(IsInfeasible(
116+
/*expected_is_minimal=*/false,
117+
/*expected_infeasible_subsystem=*/model_subset)));
118+
}
119+
120+
} // namespace
121+
} // namespace operations_research::math_opt

ortools/math_opt/cpp/matchers.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ class MapToDoubleMatcher
209209

210210
} // namespace
211211

212-
Matcher<VariableMap<double>> IsNearlySubsetOf(VariableMap<double> expected,
213-
double tolerance) {
212+
Matcher<VariableMap<double>> IsNearlySupersetOf(VariableMap<double> expected,
213+
double tolerance) {
214214
return Matcher<VariableMap<double>>(new MapToDoubleMatcher<Variable>(
215215
std::move(expected), /*all_keys=*/false, tolerance));
216216
}
@@ -221,7 +221,7 @@ Matcher<VariableMap<double>> IsNear(VariableMap<double> expected,
221221
std::move(expected), /*all_keys=*/true, tolerance));
222222
}
223223

224-
Matcher<LinearConstraintMap<double>> IsNearlySubsetOf(
224+
Matcher<LinearConstraintMap<double>> IsNearlySupersetOf(
225225
LinearConstraintMap<double> expected, double tolerance) {
226226
return Matcher<LinearConstraintMap<double>>(
227227
new MapToDoubleMatcher<LinearConstraint>(std::move(expected),
@@ -243,7 +243,7 @@ Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNear(
243243
std::move(expected), /*all_keys=*/true, tolerance));
244244
}
245245

246-
Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNearlySubsetOf(
246+
Matcher<absl::flat_hash_map<QuadraticConstraint, double>> IsNearlySupersetOf(
247247
absl::flat_hash_map<QuadraticConstraint, double> expected,
248248
double tolerance) {
249249
return Matcher<absl::flat_hash_map<QuadraticConstraint, double>>(
@@ -260,7 +260,7 @@ Matcher<absl::flat_hash_map<K, double>> IsNear(
260260
}
261261

262262
template <typename K>
263-
Matcher<absl::flat_hash_map<K, double>> IsNearlySubsetOf(
263+
Matcher<absl::flat_hash_map<K, double>> IsNearlySupersetOf(
264264
absl::flat_hash_map<K, double> expected, const double tolerance) {
265265
return Matcher<absl::flat_hash_map<K, double>>(new MapToDoubleMatcher<K>(
266266
std::move(expected), /*all_keys=*/false, tolerance));

0 commit comments

Comments
 (0)