88
99namespace py = pybind11;
1010
11- extern " C" void esa_retrieval_launcher (torch::Tensor query, torch::Tensor repre_cache, torch::Tensor q_index, torch::Tensor repre_index,
12- torch::Tensor batch_offset, torch::Tensor workspace, torch::Tensor score, torch::Tensor score_sorted, torch::Tensor index, torch::Tensor index_sorted, int batch, int s);
11+ extern " C" int esa_retrieval_launcher (torch::Tensor query, torch::Tensor repre_cache, torch::Tensor q_index, torch::Tensor repre_index, torch::Tensor repre_index_cpu,
12+ torch::Tensor batch_offset, torch::Tensor score, torch::Tensor score_cpu, torch::Tensor score_sorted_cpu, torch::Tensor index_sorted_cpu,
13+ int batch, int s);
14+
15+ extern " C" int esa_retrieval_poll (int handle);
16+ extern " C" int esa_retrieval_cleanup (int handle);
17+ extern " C" int esa_retrieval_pending ();
18+ extern " C" void esa_retrieval_shutdown ();
1319
1420extern " C" void esa_topk (torch::Tensor score, torch::Tensor index, torch::Tensor offsets, torch::Tensor score_out, torch::Tensor index_out, torch::Tensor workspace);
1521
@@ -26,33 +32,41 @@ struct RetrievalInputTensor{
2632 torch::Tensor repre_cache;
2733 torch::Tensor q_index;
2834 torch::Tensor repre_index;
35+ torch::Tensor repre_index_cpu;
2936 torch::Tensor batch_offset;
30- torch::Tensor workspace;
3137 int batch;
3238 int s;
3339};
3440
3541struct RetrievalOutputTensor {
3642 torch::Tensor score;
37- torch::Tensor index;
38- torch::Tensor score_sorted;
39- torch::Tensor index_sorted;
43+ // New CPU pinned outputs for async D2H + host callback argsort
44+ torch::Tensor score_cpu; // 1D pinned CPU tensor [s], same dtype as score
45+ torch::Tensor score_sorted_cpu; // 1D pinned CPU tensor [s], same dtype as score
46+ torch::Tensor index_sorted_cpu; // 1D pinned CPU tensor [s], int32
4047};
4148
4249
43- void esa_retrieval (RetrievalInputTensor input, RetrievalOutputTensor output){
50+ int esa_retrieval (RetrievalInputTensor input, RetrievalOutputTensor output){
4451 auto query = input.query ;
4552 auto repre_cache = input.repre_cache ;
4653 auto q_index = input.q_index ;
4754 auto repre_index = input.repre_index ;
55+ auto repre_index_cpu = input.repre_index_cpu ;
4856 auto batch_offset = input.batch_offset ;
49- auto workspace = input.workspace ;
5057
5158 auto score = output.score ;
52- auto index = output.index ;
53- auto score_sorted = output.score_sorted ;
54- auto index_sorted = output.index_sorted ;
55- esa_retrieval_launcher (query, repre_cache, q_index, repre_index, batch_offset, workspace, score, score_sorted, index, index_sorted, input.batch , input.s );
59+ // CPU pinned outputs
60+ auto score_cpu = output.score_cpu ;
61+ auto score_sorted_cpu = output.score_sorted_cpu ;
62+ auto index_sorted_cpu = output.index_sorted_cpu ;
63+
64+ return esa_retrieval_launcher (
65+ query, repre_cache, q_index, repre_index, repre_index_cpu,
66+ batch_offset, score,
67+ score_cpu, score_sorted_cpu, index_sorted_cpu,
68+ input.batch , input.s
69+ );
5670}
5771
5872
@@ -68,22 +82,28 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6882 .def_readwrite (" repre_cache" , &RetrievalInputTensor::repre_cache)
6983 .def_readwrite (" q_index" , &RetrievalInputTensor::q_index)
7084 .def_readwrite (" repre_index" , &RetrievalInputTensor::repre_index)
85+ .def_readwrite (" repre_index_cpu" , &RetrievalInputTensor::repre_index_cpu)
7186 .def_readwrite (" batch_offset" , &RetrievalInputTensor::batch_offset)
72- .def_readwrite (" workspace" , &RetrievalInputTensor::workspace)
7387 .def_readwrite (" batch" , &RetrievalInputTensor::batch)
7488 .def_readwrite (" s" , &RetrievalInputTensor::s);
7589
7690 py::class_<RetrievalOutputTensor>(m, " RetrievalOutputTensor" )
7791 .def (py::init<>())
7892 .def_readwrite (" score" , &RetrievalOutputTensor::score)
79- .def_readwrite (" score_sorted " , &RetrievalOutputTensor::score_sorted )
80- .def_readwrite (" index " , &RetrievalOutputTensor::index )
81- .def_readwrite (" index_sorted " , &RetrievalOutputTensor::index_sorted );
93+ .def_readwrite (" score_cpu " , &RetrievalOutputTensor::score_cpu )
94+ .def_readwrite (" score_sorted_cpu " , &RetrievalOutputTensor::score_sorted_cpu )
95+ .def_readwrite (" index_sorted_cpu " , &RetrievalOutputTensor::index_sorted_cpu );
8296
8397 TORCH_BINDING_COMMON_EXTENSION (esa_retrieval);
8498 TORCH_BINDING_COMMON_EXTENSION (esa_topk);
8599 TORCH_BINDING_COMMON_EXTENSION (esa_repre);
86100 TORCH_BINDING_COMMON_EXTENSION (esa_copy);
87101 TORCH_BINDING_COMMON_EXTENSION (esa_scatter_copy);
88102 TORCH_BINDING_COMMON_EXTENSION (esa_copy_batch);
103+
104+ // Async retrieval helpers
105+ m.def (" esa_retrieval_poll" , &esa_retrieval_poll, " Poll whether CPU argsort finished (returns 0/1)" );
106+ m.def (" esa_retrieval_cleanup" , &esa_retrieval_cleanup, " Cleanup a retrieval handle" );
107+ m.def (" esa_retrieval_pending" , &esa_retrieval_pending, " Number of pending retrieval contexts" );
108+ m.def (" esa_retrieval_shutdown" , &esa_retrieval_shutdown, " Shutdown retrieval worker/callback streams" );
89109}
0 commit comments