Skip to content

Commit de848af

Browse files
authored
Spectral Clustering dataset api (rapidsai#1653)
Resolves rapidsai#1580 Adding an API to support passing in a dataset to spectral clustering. Authors: - Anupam (https://github.com/aamijar) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#1653
1 parent e497d97 commit de848af

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

cpp/include/cuvs/cluster/spectral.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,38 @@ void fit_predict(raft::resources const& handle,
124124
raft::device_coo_matrix_view<double, int, int, int> connectivity_graph,
125125
raft::device_vector_view<int, int> labels);
126126

127+
/**
128+
* @brief Perform spectral clustering on a dense dataset
129+
*
130+
* This overload automatically constructs the connectivity graph from the input dataset
131+
* using k-nearest neighbors.
132+
*
133+
* @param[in] handle RAFT resource handle
134+
* @param[in] config Spectral clustering parameters
135+
* @param[in] dataset Dense row-major matrix of shape (n_samples, n_features)
136+
* @param[out] labels Device vector of size n_samples to store cluster assignments (0 to
137+
* n_clusters-1)
138+
*
139+
* @code{.cpp}
140+
* #include <cuvs/cluster/spectral.hpp>
141+
*
142+
* raft::resources handle;
143+
*
144+
* // Configure spectral clustering
145+
* cuvs::cluster::spectral::params params;
146+
* params.n_clusters = 5;
147+
* params.n_components = 5;
148+
* params.n_neighbors = 15;
149+
* params.n_init = 10;
150+
*
151+
* auto labels = raft::make_device_vector<int>(handle, n_samples);
152+
* cuvs::cluster::spectral::fit_predict(handle, params, X.view(), labels.view());
153+
* @endcode
154+
*/
155+
void fit_predict(raft::resources const& handle,
156+
params config,
157+
raft::device_matrix_view<float, int, raft::row_major> dataset,
158+
raft::device_vector_view<int, int> labels);
127159
/**
128160
* @}
129161
*/

cpp/src/cluster/detail/spectral.cuh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,22 @@ void fit_predict(raft::resources const& handle,
6161
raft::make_host_scalar_view(&n_iter));
6262
}
6363

64+
void fit_predict(raft::resources const& handle,
65+
params config,
66+
raft::device_matrix_view<float, int, raft::row_major> dataset,
67+
raft::device_vector_view<int, int> labels)
68+
{
69+
int n_samples = dataset.extent(0);
70+
71+
auto graph = raft::make_device_coo_matrix<float, int, int, int>(handle, n_samples, n_samples);
72+
73+
cuvs::preprocessing::spectral_embedding::params embed_params;
74+
embed_params.n_neighbors = config.n_neighbors;
75+
76+
cuvs::preprocessing::spectral_embedding::helpers::create_connectivity_graph(
77+
handle, embed_params, dataset, graph);
78+
79+
fit_predict(handle, config, graph.view(), labels);
80+
}
81+
6482
} // namespace cuvs::cluster::spectral::detail

cpp/src/cluster/spectral.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,12 @@ CUVS_INST_SPECTRAL(double);
2323

2424
#undef CUVS_INST_SPECTRAL
2525

26+
void fit_predict(raft::resources const& handle,
27+
params config,
28+
raft::device_matrix_view<float, int, raft::row_major> dataset,
29+
raft::device_vector_view<int, int> labels)
30+
{
31+
detail::fit_predict(handle, config, dataset, labels);
32+
}
33+
2634
} // namespace cuvs::cluster::spectral

0 commit comments

Comments
 (0)