Skip to content

Commit 6e8282c

Browse files
committed
Add test that uses call batch solve api
1 parent 95136b3 commit 6e8282c

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/* clang-format off */
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
/* clang-format on */
7+
8+
#include <cuopt/routing/cython/cython.hpp>
9+
#include <cuopt/routing/solve.hpp>
10+
#include <utilities/copy_helpers.hpp>
11+
12+
#include <raft/core/handle.hpp>
13+
14+
#include <gtest/gtest.h>
15+
16+
#include <vector>
17+
18+
namespace cuopt {
19+
namespace routing {
20+
namespace test {
21+
22+
using i_t = int;
23+
using f_t = float;
24+
25+
/**
26+
* @brief Creates a small symmetric cost matrix for TSP
27+
* @param n_locations Number of locations
28+
* @return Cost matrix as a flattened vector
29+
*/
30+
std::vector<f_t> create_small_tsp_cost_matrix(i_t n_locations)
31+
{
32+
std::vector<f_t> cost_matrix(n_locations * n_locations, 0.0f);
33+
34+
// Create a simple distance matrix based on coordinates on a line
35+
for (i_t i = 0; i < n_locations; ++i) {
36+
for (i_t j = 0; j < n_locations; ++j) {
37+
cost_matrix[i * n_locations + j] = static_cast<f_t>(std::abs(i - j));
38+
}
39+
}
40+
return cost_matrix;
41+
}
42+
43+
/**
44+
* @brief Test running TSPs of varying sizes in parallel using call_batch_solve API
45+
*/
46+
TEST(batch_tsp, varying_sizes)
47+
{
48+
std::vector<i_t> tsp_sizes = {5, 8, 10, 6, 7, 9};
49+
const i_t n_problems = static_cast<i_t>(tsp_sizes.size());
50+
51+
// Create handles and cost matrices for each problem
52+
std::vector<std::unique_ptr<raft::handle_t>> handles;
53+
std::vector<rmm::device_uvector<f_t>> cost_matrices_d;
54+
std::vector<std::unique_ptr<cuopt::routing::data_model_view_t<i_t, f_t>>> data_models;
55+
std::vector<cuopt::routing::data_model_view_t<i_t, f_t>*> data_model_ptrs;
56+
57+
for (i_t i = 0; i < n_problems; ++i) {
58+
handles.push_back(std::make_unique<raft::handle_t>());
59+
auto& handle = *handles.back();
60+
61+
auto cost_matrix_h = create_small_tsp_cost_matrix(tsp_sizes[i]);
62+
cost_matrices_d.push_back(cuopt::device_copy(cost_matrix_h, handle.get_stream()));
63+
64+
data_models.push_back(std::make_unique<cuopt::routing::data_model_view_t<i_t, f_t>>(
65+
&handle, tsp_sizes[i], 1, tsp_sizes[i]));
66+
data_models.back()->add_cost_matrix(cost_matrices_d.back().data());
67+
data_model_ptrs.push_back(data_models.back().get());
68+
}
69+
70+
// Configure solver settings
71+
cuopt::routing::solver_settings_t<i_t, f_t> settings;
72+
settings.set_time_limit(5);
73+
74+
// Call batch solve
75+
auto [solutions, solve_time] = cuopt::cython::call_batch_solve(data_model_ptrs, &settings);
76+
77+
// Verify all solutions
78+
ASSERT_EQ(solutions.size(), n_problems);
79+
for (i_t i = 0; i < n_problems; ++i) {
80+
EXPECT_EQ(solutions[i]->status_, cuopt::routing::solution_status_t::SUCCESS)
81+
<< "TSP " << i << " (size " << tsp_sizes[i] << ") failed";
82+
EXPECT_EQ(solutions[i]->vehicle_count_, 1)
83+
<< "TSP " << i << " (size " << tsp_sizes[i] << ") used multiple vehicles";
84+
}
85+
86+
// Verify solve time is reasonable
87+
EXPECT_GT(solve_time, 0.0) << "Solve time should be positive";
88+
}
89+
90+
} // namespace test
91+
} // namespace routing
92+
} // namespace cuopt

0 commit comments

Comments
 (0)