Skip to content

Commit c322f06

Browse files
committed
linear_solver: export from google3
1 parent 2b3439a commit c322f06

File tree

4 files changed

+143
-121
lines changed

4 files changed

+143
-121
lines changed

ortools/linear_solver/proto_solver/sat_proto_solver.cc

Lines changed: 123 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -157,68 +157,13 @@ MPSolutionResponse TimeLimitResponse(SolverLogger& logger) {
157157

158158
} // namespace
159159

160-
MPSolutionResponse SatSolveProto(
161-
LazyMutableCopy<MPModelRequest> request, std::atomic<bool>* interrupt_solve,
162-
std::function<void(const std::string&)> logging_callback,
163-
std::function<void(const MPSolution&)> solution_callback,
164-
std::function<void(const double)> best_bound_callback) {
165-
sat::SatParameters params;
166-
params.set_log_search_progress(request->enable_internal_solver_output());
167-
168-
// TODO(user): We do not support all the parameters here. In particular the
169-
// logs before the solver is called will not be appended to the response. Fix
170-
// that, and remove code duplication for the logger config. One way should be
171-
// to not touch/configure anything if the logger is already created while
172-
// calling SolveCpModel() and call a common config function from here or from
173-
// inside Solve()?
174-
SolverLogger logger;
175-
if (logging_callback != nullptr) {
176-
logger.AddInfoLoggingCallback(logging_callback);
177-
}
178-
logger.EnableLogging(params.log_search_progress());
179-
logger.SetLogToStdOut(params.log_to_stdout());
180-
181-
// Set it now so that it can be overwritten by the solver specific parameters.
182-
if (request->has_solver_specific_parameters()) {
183-
// See EncodeSatParametersAsString() documentation.
184-
if constexpr (!std::is_base_of<Message, sat::SatParameters>::value) {
185-
if (!params.MergeFromString(request->solver_specific_parameters())) {
186-
return InvalidParametersResponse(
187-
logger,
188-
"solver_specific_parameters is not a valid binary stream of the "
189-
"SatParameters proto");
190-
}
191-
} else {
192-
if (!ProtobufTextFormatMergeFromString(
193-
request->solver_specific_parameters(), &params)) {
194-
return InvalidParametersResponse(
195-
logger,
196-
"solver_specific_parameters is not a valid textual representation "
197-
"of the SatParameters proto");
198-
}
199-
}
200-
}
201-
202-
// Validate parameters.
203-
{
204-
const std::string error = sat::ValidateParameters(params);
205-
if (!error.empty()) {
206-
return InvalidParametersResponse(
207-
logger, absl::StrCat("Invalid CP-SAT parameters: ", error));
208-
}
209-
}
210-
211-
// Reconfigure the logger in case the solver_specific_parameters overwrite its
212-
// configuration. Note that the invalid parameter message will be logged
213-
// before that though according to request.enable_internal_solver_output().
214-
logger.EnableLogging(params.log_search_progress());
215-
logger.SetLogToStdOut(params.log_to_stdout());
216-
217-
if (request->has_solver_time_limit_seconds()) {
218-
params.set_max_time_in_seconds(request->solver_time_limit_seconds());
219-
}
220-
221-
std::unique_ptr<TimeLimit> time_limit = TimeLimit::FromParameters(params);
160+
MPSolutionResponse SatSolveProtoInternal(
161+
LazyMutableCopy<MPModelRequest> request, sat::Model* sat_model,
162+
sat::CpSolverResponse* cp_response,
163+
std::function<void(const MPSolution&)> solution_callback) {
164+
SolverLogger* logger = sat_model->GetOrCreate<SolverLogger>();
165+
sat::SatParameters& params = *sat_model->GetOrCreate<sat::SatParameters>();
166+
TimeLimit* time_limit = sat_model->GetOrCreate<TimeLimit>();
222167

223168
// Model validation and delta handling.
224169
MPSolutionResponse response;
@@ -231,10 +176,10 @@ MPSolutionResponse SatSolveProto(
231176
//
232177
// The logging is only needed for our benchmark script, so we use UNKNOWN
233178
// here, but we could log the proper status instead.
234-
if (logger.LoggingIsEnabled()) {
179+
if (logger->LoggingIsEnabled()) {
235180
sat::CpSolverResponse cp_response;
236181
cp_response.set_status(FromMPSolverResponseStatus(response.status()));
237-
SOLVER_LOG(&logger, CpSolverResponseStats(cp_response));
182+
SOLVER_LOG(logger, CpSolverResponseStats(cp_response));
238183
}
239184
return response;
240185
}
@@ -252,49 +197,48 @@ MPSolutionResponse SatSolveProto(
252197
// of input.
253198
if (params.mip_treat_high_magnitude_bounds_as_infinity()) {
254199
sat::ChangeLargeBoundsToInfinity(params.mip_max_valid_magnitude(),
255-
mp_model.get(), &logger);
200+
mp_model.get(), logger);
256201
}
257-
if (!sat::MPModelProtoValidationBeforeConversion(params, *mp_model,
258-
&logger)) {
259-
return InvalidModelResponse(logger, "Extra CP-SAT validation failed.");
202+
if (!sat::MPModelProtoValidationBeforeConversion(params, *mp_model, logger)) {
203+
return InvalidModelResponse(*logger, "Extra CP-SAT validation failed.");
260204
}
261205

262206
// This is good to do before any presolve.
263207
if (!sat::MakeBoundsOfIntegerVariablesInteger(params, mp_model.get(),
264-
&logger)) {
265-
return InfeasibleResponse(logger,
208+
logger)) {
209+
return InfeasibleResponse(*logger,
266210
"An integer variable has an empty domain");
267211
}
268212

269213
// Coefficients really close to zero can cause issues.
270214
// We remove them right away according to our parameters.
271-
RemoveNearZeroTerms(params, mp_model.get(), &logger);
215+
RemoveNearZeroTerms(params, mp_model.get(), logger);
272216

273217
// Note(user): the LP presolvers API is a bit weird and keep a reference to
274218
// the given GlopParameters, so we need to make sure it outlive them.
275219
const glop::GlopParameters glop_params;
276220
std::vector<std::unique_ptr<glop::Preprocessor>> for_postsolve;
277221
if (!params.enumerate_all_solutions() && params.mip_presolve_level() > 0) {
278222
const glop::ProblemStatus status = ApplyMipPresolveSteps(
279-
glop_params, mp_model.get(), &for_postsolve, &logger);
223+
glop_params, mp_model.get(), &for_postsolve, logger);
280224
switch (status) {
281225
case glop::ProblemStatus::INIT:
282226
// Continue with the solve.
283227
break;
284228
case glop::ProblemStatus::PRIMAL_INFEASIBLE:
285229
return InfeasibleResponse(
286-
logger, "Problem proven infeasible during MIP presolve");
230+
*logger, "Problem proven infeasible during MIP presolve");
287231
case glop::ProblemStatus::INVALID_PROBLEM:
288232
return InvalidModelResponse(
289-
logger, "Problem detected invalid during MIP presolve");
233+
*logger, "Problem detected invalid during MIP presolve");
290234
default:
291235
// TODO(user): We put the INFEASIBLE_OR_UNBOUNBED case here since there
292236
// is no return status that exactly matches it.
293237
if (params.log_search_progress()) {
294238
// This is needed for our benchmark scripts.
295239
sat::CpSolverResponse cp_response;
296240
cp_response.set_status(sat::CpSolverStatus::UNKNOWN);
297-
SOLVER_LOG(&logger, "MIP presolve: problem infeasible or unbounded.");
241+
SOLVER_LOG(logger, "MIP presolve: problem infeasible or unbounded.");
298242
LOG(INFO) << CpSolverResponseStats(cp_response);
299243
}
300244
response.set_status(MPSolverResponseStatus::MPSOLVER_UNKNOWN_STATUS);
@@ -307,22 +251,22 @@ MPSolutionResponse SatSolveProto(
307251
}
308252

309253
if (time_limit->LimitReached()) {
310-
return TimeLimitResponse(logger);
254+
return TimeLimitResponse(*logger);
311255
}
312256
// We need to do that before the automatic detection of integers.
313-
RemoveNearZeroTerms(params, mp_model.get(), &logger);
257+
RemoveNearZeroTerms(params, mp_model.get(), logger);
314258

315-
SOLVER_LOG(&logger, "");
316-
SOLVER_LOG(&logger, "Scaling to pure integer problem.");
259+
SOLVER_LOG(logger, "");
260+
SOLVER_LOG(logger, "Scaling to pure integer problem.");
317261

318262
const int num_variables = mp_model->variable_size();
319263
std::vector<double> var_scaling(num_variables, 1.0);
320264
if (params.mip_automatically_scale_variables()) {
321-
var_scaling = sat::DetectImpliedIntegers(mp_model.get(), &logger);
265+
var_scaling = sat::DetectImpliedIntegers(mp_model.get(), logger);
322266
if (!sat::MakeBoundsOfIntegerVariablesInteger(params, mp_model.get(),
323-
&logger)) {
267+
logger)) {
324268
return InfeasibleResponse(
325-
logger, "A detected integer variable has an empty domain");
269+
*logger, "A detected integer variable has an empty domain");
326270
}
327271
}
328272
if (params.mip_var_scaling() != 1.0) {
@@ -347,7 +291,7 @@ MPSolutionResponse SatSolveProto(
347291
}
348292
if (!all_integer) {
349293
return InvalidModelResponse(
350-
logger,
294+
*logger,
351295
"The model contains non-integer variables but the parameter "
352296
"'only_solve_ip' was set. Change this parameter if you "
353297
"still want to solve a more constrained version of the original MIP "
@@ -357,8 +301,8 @@ MPSolutionResponse SatSolveProto(
357301

358302
sat::CpModelProto cp_model;
359303
if (!ConvertMPModelProtoToCpModelProto(params, *mp_model, &cp_model,
360-
&logger)) {
361-
return InvalidModelResponse(logger,
304+
logger)) {
305+
return InvalidModelResponse(*logger,
362306
"Failed to convert model into CP-SAT model");
363307
}
364308
DCHECK_EQ(cp_model.variables().size(), var_scaling.size());
@@ -391,30 +335,16 @@ MPSolutionResponse SatSolveProto(
391335
const bool is_maximize = mp_model->maximize();
392336
mp_model.reset();
393337

394-
params.set_max_time_in_seconds(time_limit->GetTimeLeft());
395-
if (time_limit->GetDeterministicTimeLeft() !=
396-
std::numeric_limits<double>::infinity()) {
397-
params.set_max_deterministic_time(time_limit->GetDeterministicTimeLeft());
398-
}
399-
400338
// Configure model.
401-
sat::Model sat_model;
402-
sat_model.Register<SolverLogger>(&logger);
403-
sat_model.Add(NewSatParameters(params));
404-
if (interrupt_solve != nullptr) {
405-
sat_model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(
406-
interrupt_solve);
407-
}
408-
409-
auto post_solve = [&](const sat::CpSolverResponse& cp_response) {
339+
auto post_solve = [&](const sat::CpSolverResponse& sat_response) {
410340
MPSolution mp_solution;
411-
mp_solution.set_objective_value(cp_response.objective_value());
341+
mp_solution.set_objective_value(sat_response.objective_value());
412342
// Postsolve the bound shift and scaling.
413343
glop::ProblemSolution glop_solution((glop::RowIndex(old_num_constraints)),
414344
(glop::ColIndex(old_num_variables)));
415345
for (int v = 0; v < glop_solution.primal_values.size(); ++v) {
416346
glop_solution.primal_values[glop::ColIndex(v)] =
417-
static_cast<double>(cp_response.solution(v)) / var_scaling[v];
347+
static_cast<double>(sat_response.solution(v)) / var_scaling[v];
418348
}
419349
for (int i = for_postsolve.size(); --i >= 0;) {
420350
for_postsolve[i]->RecoverSolution(&glop_solution);
@@ -427,33 +357,29 @@ MPSolutionResponse SatSolveProto(
427357
};
428358

429359
if (solution_callback != nullptr) {
430-
sat_model.Add(sat::NewFeasibleSolutionObserver(
431-
[&](const sat::CpSolverResponse& cp_response) {
432-
solution_callback(post_solve(cp_response));
360+
sat_model->Add(sat::NewFeasibleSolutionObserver(
361+
[&](const sat::CpSolverResponse& sat_response) {
362+
solution_callback(post_solve(sat_response));
433363
}));
434364
}
435-
if (best_bound_callback != nullptr) {
436-
sat_model.Add(sat::NewBestBoundCallback(best_bound_callback));
437-
}
438365

439366
// Solve.
440-
const sat::CpSolverResponse cp_response =
441-
sat::SolveCpModel(cp_model, &sat_model);
367+
*cp_response = sat::SolveCpModel(cp_model, sat_model);
442368

443369
// Convert the response.
444370
//
445371
// TODO(user): Implement the row and column status.
446372
response.mutable_solve_info()->set_solve_wall_time_seconds(
447-
cp_response.wall_time());
373+
cp_response->wall_time());
448374
response.mutable_solve_info()->set_solve_user_time_seconds(
449-
cp_response.user_time());
450-
response.set_status(
451-
ToMPSolverResponseStatus(cp_response.status(), cp_model.has_objective()));
375+
cp_response->user_time());
376+
response.set_status(ToMPSolverResponseStatus(cp_response->status(),
377+
cp_model.has_objective()));
452378
if (response.status() == MPSOLVER_FEASIBLE ||
453379
response.status() == MPSOLVER_OPTIMAL) {
454-
response.set_objective_value(cp_response.objective_value());
455-
response.set_best_objective_bound(cp_response.best_objective_bound());
456-
MPSolution post_solved_solution = post_solve(cp_response);
380+
response.set_objective_value(cp_response->objective_value());
381+
response.set_best_objective_bound(cp_response->best_objective_bound());
382+
MPSolution post_solved_solution = post_solve(*cp_response);
457383
*response.mutable_variable_value() =
458384
std::move(*post_solved_solution.mutable_variable_value());
459385
}
@@ -462,9 +388,9 @@ MPSolutionResponse SatSolveProto(
462388
//
463389
// TODO(user): Remove the postsolve hack of copying to a response.
464390
for (const sat::CpSolverSolution& additional_solution :
465-
cp_response.additional_solutions()) {
391+
cp_response->additional_solutions()) {
466392
if (absl::MakeConstSpan(additional_solution.values()) ==
467-
absl::MakeConstSpan(cp_response.solution())) {
393+
absl::MakeConstSpan(cp_response->solution())) {
468394
continue;
469395
}
470396
double obj = cp_model.floating_point_objective().offset();
@@ -494,6 +420,84 @@ MPSolutionResponse SatSolveProto(
494420
return response;
495421
}
496422

423+
MPSolutionResponse SatSolveProto(
424+
LazyMutableCopy<MPModelRequest> request, std::atomic<bool>* interrupt_solve,
425+
std::function<void(const std::string&)> logging_callback,
426+
std::function<void(const MPSolution&)> solution_callback,
427+
std::function<void(const double)> best_bound_callback) {
428+
sat::Model sat_model;
429+
sat::SatParameters& params = *sat_model.GetOrCreate<sat::SatParameters>();
430+
params.set_log_search_progress(request->enable_internal_solver_output());
431+
432+
// TODO(user): We do not support all the parameters here. In particular the
433+
// logs before the solver is called will not be appended to the response. Fix
434+
// that, and remove code duplication for the logger config. One way should be
435+
// to not touch/configure anything if the logger is already created while
436+
// calling SolveCpModel() and call a common config function from here or from
437+
// inside Solve()?
438+
SolverLogger* logger = sat_model.GetOrCreate<SolverLogger>();
439+
if (logging_callback != nullptr) {
440+
logger->AddInfoLoggingCallback(logging_callback);
441+
}
442+
logger->EnableLogging(params.log_search_progress());
443+
logger->SetLogToStdOut(params.log_to_stdout());
444+
445+
// Set it now so that it can be overwritten by the solver specific parameters.
446+
if (request->has_solver_specific_parameters()) {
447+
// See EncodeSatParametersAsString() documentation.
448+
if constexpr (!std::is_base_of<Message, sat::SatParameters>::value) {
449+
if (!params.MergeFromString(request->solver_specific_parameters())) {
450+
return InvalidParametersResponse(
451+
*logger,
452+
"solver_specific_parameters is not a valid binary stream of the "
453+
"SatParameters proto");
454+
}
455+
} else {
456+
if (!ProtobufTextFormatMergeFromString(
457+
request->solver_specific_parameters(), &params)) {
458+
return InvalidParametersResponse(
459+
*logger,
460+
"solver_specific_parameters is not a valid textual representation "
461+
"of the SatParameters proto");
462+
}
463+
}
464+
}
465+
466+
// Validate parameters.
467+
{
468+
const std::string error = sat::ValidateParameters(params);
469+
if (!error.empty()) {
470+
return InvalidParametersResponse(
471+
*logger, absl::StrCat("Invalid CP-SAT parameters: ", error));
472+
}
473+
}
474+
475+
// Reconfigure the logger in case the solver_specific_parameters overwrite its
476+
// configuration. Note that the invalid parameter message will be logged
477+
// before that though according to request.enable_internal_solver_output().
478+
logger->EnableLogging(params.log_search_progress());
479+
logger->SetLogToStdOut(params.log_to_stdout());
480+
481+
if (request->has_solver_time_limit_seconds()) {
482+
params.set_max_time_in_seconds(request->solver_time_limit_seconds());
483+
}
484+
485+
sat_model.GetOrCreate<TimeLimit>()->ResetLimitFromParameters(params);
486+
487+
if (interrupt_solve != nullptr) {
488+
sat_model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(
489+
interrupt_solve);
490+
}
491+
492+
if (best_bound_callback != nullptr) {
493+
sat_model.Add(sat::NewBestBoundCallback(best_bound_callback));
494+
}
495+
496+
sat::CpSolverResponse cp_response;
497+
return SatSolveProtoInternal(std::move(request), &sat_model, &cp_response,
498+
solution_callback);
499+
}
500+
497501
std::string SatSolverVersion() { return sat::CpSatSolverVersion(); }
498502

499503
} // namespace operations_research

ortools/linear_solver/proto_solver/sat_proto_solver.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <string>
2020

2121
#include "ortools/linear_solver/linear_solver.pb.h"
22+
#include "ortools/sat/cp_model.pb.h"
23+
#include "ortools/sat/model.h"
2224
#include "ortools/util/lazy_mutable_copy.h"
2325
#include "ortools/util/logging.h"
2426

@@ -65,6 +67,13 @@ MPSolutionResponse SatSolveProto(
6567
// Returns a string that describes the version of the CP-SAT solver.
6668
std::string SatSolverVersion();
6769

70+
// Internal version of SatSolveProto that can configure a sat::Model object
71+
// before the solve and return the CpSolverResponse proto to extract statistics.
72+
MPSolutionResponse SatSolveProtoInternal(
73+
LazyMutableCopy<MPModelRequest> request, sat::Model* sat_model,
74+
sat::CpSolverResponse* cp_response,
75+
std::function<void(const MPSolution&)> solution_callback = nullptr);
76+
6877
} // namespace operations_research
6978

7079
#endif // ORTOOLS_LINEAR_SOLVER_PROTO_SOLVER_SAT_PROTO_SOLVER_H_

0 commit comments

Comments
 (0)