Skip to content

Commit 23aa4c6

Browse files
authored
Improve memory usage in build_mr_linkage (rapidsai#1550)
Closes rapidsai#1542. This PR improves memory usage in the `build_mr_linkage` function by dropping resources early. Authors: Jinsol Park (https://github.com/jinsolp). Approvers: Corey J. Nolet (https://github.com/cjnolet). URL: rapidsai#1550
1 parent 923cce5 commit 23aa4c6

File tree

1 file changed

+74
-67
lines changed

1 file changed

+74
-67
lines changed

cpp/src/cluster/detail/single_linkage.cuh

Lines changed: 74 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -61,73 +61,80 @@ void build_mr_linkage(
6161
size_t n = X.extent(1);
6262
auto stream = raft::resource::get_cuda_stream(handle);
6363

64-
auto mr_indptr = raft::make_device_vector<value_idx, value_idx>(handle, m + 1);
65-
raft::sparse::COO<value_t, value_idx, nnz_t> mr_coo(stream, min_samples * m * 2);
66-
67-
auto inds = raft::make_device_matrix<value_idx, value_idx>(handle, m, min_samples);
68-
auto dists = raft::make_device_matrix<value_t, value_idx>(handle, m, min_samples);
69-
70-
if (all_neighbors_p.metric != metric) {
71-
RAFT_LOG_WARN("Setting all neighbors metric to given metrix for build_mr_linkage");
72-
all_neighbors_p.metric = metric;
73-
}
74-
cuvs::neighbors::all_neighbors::build(
75-
handle, all_neighbors_p, X, inds.view(), dists.view(), core_dists, alpha);
76-
77-
// self-loops get max distance
78-
auto coo_rows = raft::make_device_vector<value_idx, value_idx>(handle, min_samples * m);
79-
raft::linalg::map_offset(handle, coo_rows.view(), raft::div_const_op<value_idx>(min_samples));
80-
81-
raft::sparse::linalg::symmetrize(handle,
82-
coo_rows.data_handle(),
83-
inds.data_handle(),
84-
dists.data_handle(),
85-
static_cast<value_idx>(m),
86-
static_cast<value_idx>(m),
87-
static_cast<nnz_t>(min_samples * m),
88-
mr_coo);
89-
90-
raft::sparse::convert::sorted_coo_to_csr(
91-
mr_coo.rows(), mr_coo.nnz, mr_indptr.data_handle(), m + 1, stream);
92-
93-
auto rows_view = raft::make_device_vector_view<const value_idx, nnz_t>(mr_coo.rows(), mr_coo.nnz);
94-
auto cols_view = raft::make_device_vector_view<const value_idx, nnz_t>(mr_coo.cols(), mr_coo.nnz);
95-
auto vals_in_view =
96-
raft::make_device_vector_view<const value_t, nnz_t>(mr_coo.vals(), mr_coo.nnz);
97-
auto vals_out_view = raft::make_device_vector_view<value_t, nnz_t>(mr_coo.vals(), mr_coo.nnz);
98-
99-
raft::linalg::map(
100-
handle,
101-
vals_out_view,
102-
[=] __device__(const value_idx row, const value_idx col, const value_t val) {
103-
return row == col ? std::numeric_limits<value_t>::max() : val;
104-
},
105-
rows_view,
106-
cols_view,
107-
vals_in_view);
108-
109-
rmm::device_uvector<value_idx> color(m, raft::resource::get_cuda_stream(handle));
110-
cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t>
111-
reduction_op(core_dists.data_handle(), m);
112-
113-
size_t nnz = m * min_samples;
114-
115-
detail::build_sorted_mst<value_idx, value_t>(handle,
116-
X.data_handle(),
117-
mr_indptr.data_handle(),
118-
mr_coo.cols(),
119-
mr_coo.vals(),
120-
m,
121-
n,
122-
out_mst.structure_view().get_rows().data(),
123-
out_mst.structure_view().get_cols().data(),
124-
out_mst.get_elements().data(),
125-
color.data(),
126-
mr_coo.nnz,
127-
reduction_op,
128-
metric,
129-
10);
130-
64+
{ // scope to drop mr_coo and mr_indptr early
65+
std::optional<raft::sparse::COO<value_t, value_idx, nnz_t>> mr_coo;
66+
67+
{ // scope to drop inds and dists matrices early
68+
auto inds = raft::make_device_matrix<value_idx, value_idx>(handle, m, min_samples);
69+
auto dists = raft::make_device_matrix<value_t, value_idx>(handle, m, min_samples);
70+
71+
if (all_neighbors_p.metric != metric) {
72+
RAFT_LOG_WARN("Setting all neighbors metric to given metrix for build_mr_linkage");
73+
all_neighbors_p.metric = metric;
74+
}
75+
cuvs::neighbors::all_neighbors::build(
76+
handle, all_neighbors_p, X, inds.view(), dists.view(), core_dists, alpha);
77+
78+
// allocate memory after all neighbors build
79+
mr_coo.emplace(stream, min_samples * m * 2);
80+
// self-loops get max distance
81+
auto coo_rows = raft::make_device_vector<value_idx, value_idx>(handle, min_samples * m);
82+
raft::linalg::map_offset(handle, coo_rows.view(), raft::div_const_op<value_idx>(min_samples));
83+
84+
raft::sparse::linalg::symmetrize(handle,
85+
coo_rows.data_handle(),
86+
inds.data_handle(),
87+
dists.data_handle(),
88+
static_cast<value_idx>(m),
89+
static_cast<value_idx>(m),
90+
static_cast<nnz_t>(min_samples * m),
91+
mr_coo.value());
92+
} // scope to drop inds and dists matrices early
93+
auto mr_indptr = raft::make_device_vector<value_idx, value_idx>(handle, m + 1);
94+
raft::sparse::convert::sorted_coo_to_csr(
95+
mr_coo.value().rows(), mr_coo.value().nnz, mr_indptr.data_handle(), m + 1, stream);
96+
97+
auto rows_view = raft::make_device_vector_view<const value_idx, nnz_t>(mr_coo.value().rows(),
98+
mr_coo.value().nnz);
99+
auto cols_view = raft::make_device_vector_view<const value_idx, nnz_t>(mr_coo.value().cols(),
100+
mr_coo.value().nnz);
101+
auto vals_in_view = raft::make_device_vector_view<const value_t, nnz_t>(mr_coo.value().vals(),
102+
mr_coo.value().nnz);
103+
auto vals_out_view =
104+
raft::make_device_vector_view<value_t, nnz_t>(mr_coo.value().vals(), mr_coo.value().nnz);
105+
106+
raft::linalg::map(
107+
handle,
108+
vals_out_view,
109+
[=] __device__(const value_idx row, const value_idx col, const value_t val) {
110+
return row == col ? std::numeric_limits<value_t>::max() : val;
111+
},
112+
rows_view,
113+
cols_view,
114+
vals_in_view);
115+
116+
rmm::device_uvector<value_idx> color(m, raft::resource::get_cuda_stream(handle));
117+
cuvs::sparse::neighbors::MutualReachabilityFixConnectivitiesRedOp<value_idx, value_t>
118+
reduction_op(core_dists.data_handle(), m);
119+
120+
size_t nnz = m * min_samples;
121+
122+
detail::build_sorted_mst<value_idx, value_t>(handle,
123+
X.data_handle(),
124+
mr_indptr.data_handle(),
125+
mr_coo.value().cols(),
126+
mr_coo.value().vals(),
127+
m,
128+
n,
129+
out_mst.structure_view().get_rows().data(),
130+
out_mst.structure_view().get_cols().data(),
131+
out_mst.get_elements().data(),
132+
color.data(),
133+
mr_coo.value().nnz,
134+
reduction_op,
135+
metric,
136+
10);
137+
} // scope to drop mr_coo and mr_indptr early
131138
/**
132139
* Perform hierarchical labeling
133140
*/

0 commit comments

Comments
 (0)