Skip to content

Commit 95136b3

Browse files
committed
Implement batch solve api
1 parent 052e454 commit 95136b3

File tree

4 files changed

+194
-0
lines changed

4 files changed

+194
-0
lines changed

cpp/include/cuopt/routing/cython/cython.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include <raft/core/handle.hpp>
1717

1818
#include <memory>
19+
#include <utility>
20+
#include <vector>
1921

2022
namespace cuopt {
2123
namespace cython {
@@ -82,6 +84,10 @@ struct dataset_ret_t {
8284
std::unique_ptr<vehicle_routing_ret_t> call_solve(routing::data_model_view_t<int, float>*,
8385
routing::solver_settings_t<int, float>*);
8486

87+
// Wrapper for batch solve to expose the API to cython.
88+
std::pair<std::vector<std::unique_ptr<vehicle_routing_ret_t>>, double> call_batch_solve(
89+
std::vector<routing::data_model_view_t<int, float>*>, routing::solver_settings_t<int, float>*);
90+
8591
// Wrapper for dataset to expose the API to cython.
8692
std::unique_ptr<dataset_ret_t> call_generate_dataset(
8793
raft::handle_t const& handle, routing::generator::dataset_params_t<int, float> const& params);

cpp/src/routing/utilities/cython.cu

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77

88
#include <cuopt/routing/cython/cython.hpp>
99
#include <cuopt/routing/solve.hpp>
10+
#include <raft/common/nvtx.hpp>
1011
#include <raft/core/handle.hpp>
1112
#include <rmm/device_buffer.hpp>
1213
#include <routing/generator/generator.hpp>
1314

15+
#include <omp.h>
16+
#include <chrono>
17+
1418
namespace cuopt {
1519
namespace cython {
1620

@@ -86,6 +90,54 @@ std::unique_ptr<vehicle_routing_ret_t> call_solve(
8690
return std::make_unique<vehicle_routing_ret_t>(std::move(vr_ret));
8791
}
8892

93+
/**
94+
* @brief Wrapper for batch vehicle_routing to expose the API to cython
95+
*
96+
* @param data_models Vector of data model pointers
97+
* @param settings Composable solver settings object
98+
* @return std::pair<std::vector<std::unique_ptr<vehicle_routing_ret_t>>, double>
99+
*/
100+
std::pair<std::vector<std::unique_ptr<vehicle_routing_ret_t>>, double> call_batch_solve(
101+
std::vector<routing::data_model_view_t<int, float>*> data_models,
102+
routing::solver_settings_t<int, float>* settings)
103+
{
104+
raft::common::nvtx::range fun_scope("Call batch solve routing");
105+
106+
const std::size_t size = data_models.size();
107+
std::vector<std::unique_ptr<vehicle_routing_ret_t>> list(size);
108+
109+
auto start_solver = std::chrono::high_resolution_clock::now();
110+
111+
// Use OpenMP for parallel execution
112+
const int max_thread = std::min(static_cast<int>(size), omp_get_max_threads());
113+
114+
#pragma omp parallel for num_threads(max_thread)
115+
for (std::size_t i = 0; i < size; ++i) {
116+
auto routing_solution = cuopt::routing::solve(*data_models[i], *settings);
117+
vehicle_routing_ret_t vr_ret{
118+
routing_solution.get_vehicle_count(),
119+
routing_solution.get_total_objective(),
120+
routing_solution.get_objectives(),
121+
std::make_unique<rmm::device_buffer>(routing_solution.get_route().release()),
122+
std::make_unique<rmm::device_buffer>(routing_solution.get_order_locations().release()),
123+
std::make_unique<rmm::device_buffer>(routing_solution.get_arrival_stamp().release()),
124+
std::make_unique<rmm::device_buffer>(routing_solution.get_truck_id().release()),
125+
std::make_unique<rmm::device_buffer>(routing_solution.get_node_types().release()),
126+
std::make_unique<rmm::device_buffer>(routing_solution.get_unserviced_nodes().release()),
127+
std::make_unique<rmm::device_buffer>(routing_solution.get_accepted().release()),
128+
routing_solution.get_status(),
129+
routing_solution.get_status_string(),
130+
routing_solution.get_error_status().get_error_type(),
131+
routing_solution.get_error_status().what()};
132+
list[i] = std::make_unique<vehicle_routing_ret_t>(std::move(vr_ret));
133+
}
134+
135+
auto end = std::chrono::high_resolution_clock::now();
136+
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_solver);
137+
138+
return {std::move(list), duration.count() / 1000.0};
139+
}
140+
89141
/**
90142
* @brief Wrapper for dataset_t to expose the API to cython.
91143
* @param solver Composable solver object

python/cuopt/cuopt/routing/vehicle_routing.pxd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
# cython: language_level = 3
99

1010
from libcpp cimport bool
11+
from libcpp.pair cimport pair
1112
from libcpp.string cimport string
13+
from libcpp.vector cimport vector
1214

1315
from pylibraft.common.handle cimport *
1416

@@ -133,3 +135,8 @@ cdef extern from "cuopt/routing/cython/cython.hpp" namespace "cuopt::cython": #
133135
data_model_view_t[int, float]* data_model,
134136
solver_settings_t[int, float]* solver_settings
135137
) except +
138+
139+
cdef pair[vector[unique_ptr[vehicle_routing_ret_t]], double] call_batch_solve(
140+
vector[data_model_view_t[int, float] *] data_models,
141+
solver_settings_t[int, float]* solver_settings,
142+
) except +

python/cuopt/cuopt/routing/vehicle_routing_wrapper.pyx

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from pylibraft.common.handle cimport *
1111

1212
from cuopt.routing.structure.routing_utilities cimport *
1313
from cuopt.routing.vehicle_routing cimport (
14+
call_batch_solve,
1415
call_solve,
1516
data_model_view_t,
1617
node_type_t,
@@ -32,8 +33,10 @@ from libc.stdlib cimport free, malloc
3233
from libc.string cimport memcpy, strcpy, strlen
3334
from libcpp cimport bool
3435
from libcpp.memory cimport unique_ptr
36+
from libcpp.pair cimport pair
3537
from libcpp.string cimport string
3638
from libcpp.utility cimport move
39+
from libcpp.vector cimport vector
3740

3841
from rmm.pylibrmm.device_buffer cimport DeviceBuffer
3942

@@ -834,3 +837,129 @@ def Solve(DataModel data_model, SolverSettings solver_settings):
834837
error_message,
835838
unserviced_nodes
836839
)
840+
841+
842+
cdef create_assignment_from_vr_ret(vehicle_routing_ret_t& vr_ret):
843+
"""Helper function to create an Assignment from a vehicle_routing_ret_t"""
844+
vehicle_count = vr_ret.vehicle_count_
845+
total_objective_value = vr_ret.total_objective_value_
846+
847+
objective_values = {}
848+
for k in vr_ret.objective_values_:
849+
obj = Objective(int(k.first))
850+
objective_values[obj] = k.second
851+
852+
status = vr_ret.status_
853+
cdef char* c_sol_string = c_get_string(vr_ret.solution_string_)
854+
try:
855+
solver_status_string = \
856+
c_sol_string[:vr_ret.solution_string_.length()].decode('UTF-8')
857+
finally:
858+
free(c_sol_string)
859+
860+
route = DeviceBuffer.c_from_unique_ptr(move(vr_ret.d_route_))
861+
route_locations = DeviceBuffer.c_from_unique_ptr(
862+
move(vr_ret.d_route_locations_)
863+
)
864+
arrival_stamp = DeviceBuffer.c_from_unique_ptr(
865+
move(vr_ret.d_arrival_stamp_)
866+
)
867+
truck_id = DeviceBuffer.c_from_unique_ptr(move(vr_ret.d_truck_id_))
868+
node_types = DeviceBuffer.c_from_unique_ptr(move(vr_ret.d_node_types_))
869+
unserviced_nodes_buf = \
870+
DeviceBuffer.c_from_unique_ptr(move(vr_ret.d_unserviced_nodes_))
871+
accepted_buf = \
872+
DeviceBuffer.c_from_unique_ptr(move(vr_ret.d_accepted_))
873+
874+
route_df = cudf.DataFrame()
875+
route_df['route'] = series_from_buf(route, pa.int32())
876+
route_df['arrival_stamp'] = series_from_buf(arrival_stamp, pa.float64())
877+
route_df['truck_id'] = series_from_buf(truck_id, pa.int32())
878+
route_df['location'] = series_from_buf(route_locations, pa.int32())
879+
route_df['type'] = series_from_buf(node_types, pa.int32())
880+
881+
unserviced_nodes = cudf.Series._from_column(
882+
series_from_buf(unserviced_nodes_buf, pa.int32())
883+
)
884+
accepted = cudf.Series._from_column(
885+
series_from_buf(accepted_buf, pa.int32())
886+
)
887+
888+
def get_type_from_int(type_in_int):
889+
if type_in_int == int(NodeType.DEPOT):
890+
return "Depot"
891+
elif type_in_int == int(NodeType.PICKUP):
892+
return "Pickup"
893+
elif type_in_int == int(NodeType.DELIVERY):
894+
return "Delivery"
895+
elif type_in_int == int(NodeType.BREAK):
896+
return "Break"
897+
898+
node_types_string = [
899+
get_type_from_int(type_in_int)
900+
for type_in_int in route_df['type'].to_pandas()]
901+
route_df['type'] = node_types_string
902+
error_status = vr_ret.error_status_
903+
error_message = vr_ret.error_message_
904+
905+
return Assignment(
906+
vehicle_count,
907+
total_objective_value,
908+
objective_values,
909+
route_df,
910+
accepted,
911+
<solution_status_t> status,
912+
solver_status_string,
913+
<error_type_t> error_status,
914+
error_message,
915+
unserviced_nodes
916+
)
917+
918+
919+
def BatchSolve(py_data_model_list, SolverSettings solver_settings):
920+
"""
921+
Solve multiple routing problems in batch mode using parallel execution.
922+
923+
Parameters
924+
----------
925+
py_data_model_list : list of DataModel
926+
List of data model objects representing routing problems to solve.
927+
solver_settings : SolverSettings
928+
Solver settings to use for all problems.
929+
930+
Returns
931+
-------
932+
tuple
933+
A tuple containing:
934+
- list of Assignment: Solutions for each routing problem
935+
- float: Total solve time in seconds
936+
"""
937+
cdef solver_settings_t[int, float]* c_solver_settings = (
938+
solver_settings.c_solver_settings.get()
939+
)
940+
941+
cdef vector[data_model_view_t[int, float] *] data_model_views
942+
943+
for data_model_obj in py_data_model_list:
944+
data_model_views.push_back(
945+
(<DataModel>data_model_obj).c_data_model_view.get()
946+
)
947+
948+
cdef pair[
949+
vector[unique_ptr[vehicle_routing_ret_t]],
950+
double] batch_solve_result = (
951+
move(call_batch_solve(data_model_views, c_solver_settings))
952+
)
953+
954+
cdef vector[unique_ptr[vehicle_routing_ret_t]] c_solutions = (
955+
move(batch_solve_result.first)
956+
)
957+
cdef double solve_time = batch_solve_result.second
958+
959+
solutions = []
960+
for i in range(c_solutions.size()):
961+
solutions.append(
962+
create_assignment_from_vr_ret(c_solutions[i].get()[0])
963+
)
964+
965+
return solutions, solve_time

0 commit comments

Comments
 (0)