Skip to content

Commit 4fcc3a6

Browse files
committed
Update esa_kernel to the latest version
1 parent 248b7fe commit 4fcc3a6

File tree

4 files changed

+337
-90
lines changed

4 files changed

+337
-90
lines changed

ucm/sparse/esa/build_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,5 @@ def build_shared(src_files, target, mode = "release"):
6464
subprocess.run(cmd, check=True)
6565

6666
if __name__ == "__main__":
67-
build_shared(["./diff_map_thrust_pybind.cu"], "diff_map.so")
67+
# build_shared(["./diff_map_thrust_pybind.cu"], "diff_map.so")
6868
build_shared(["./esa_interface.cc", "./esa_kernels.cu", "./esa_sm_copy.cu"], "esa_interface.so")

ucm/sparse/esa/esa_interface.cc

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@
88

99
namespace py = pybind11;
1010

11-
extern "C" void esa_retrieval_launcher(torch::Tensor query, torch::Tensor repre_cache, torch::Tensor q_index, torch::Tensor repre_index,
12-
torch::Tensor batch_offset, torch::Tensor workspace, torch::Tensor score, torch::Tensor score_sorted, torch::Tensor index, torch::Tensor index_sorted, int batch, int s);
11+
extern "C" int esa_retrieval_launcher(torch::Tensor query, torch::Tensor repre_cache, torch::Tensor q_index, torch::Tensor repre_index, torch::Tensor repre_index_cpu,
12+
torch::Tensor batch_offset, torch::Tensor score, torch::Tensor score_cpu, torch::Tensor score_sorted_cpu, torch::Tensor index_sorted_cpu,
13+
int batch, int s);
14+
15+
extern "C" int esa_retrieval_poll(int handle);
16+
extern "C" int esa_retrieval_cleanup(int handle);
17+
extern "C" int esa_retrieval_pending();
18+
extern "C" void esa_retrieval_shutdown();
1319

1420
extern "C" void esa_topk(torch::Tensor score, torch::Tensor index, torch::Tensor offsets, torch::Tensor score_out, torch::Tensor index_out, torch::Tensor workspace);
1521

@@ -26,33 +32,41 @@ struct RetrievalInputTensor{
2632
torch::Tensor repre_cache;
2733
torch::Tensor q_index;
2834
torch::Tensor repre_index;
35+
torch::Tensor repre_index_cpu;
2936
torch::Tensor batch_offset;
30-
torch::Tensor workspace;
3137
int batch;
3238
int s;
3339
};
3440

3541
struct RetrievalOutputTensor{
3642
torch::Tensor score;
37-
torch::Tensor index;
38-
torch::Tensor score_sorted;
39-
torch::Tensor index_sorted;
43+
// New CPU pinned outputs for async D2H + host callback argsort
44+
torch::Tensor score_cpu; // 1D pinned CPU tensor [s], same dtype as score
45+
torch::Tensor score_sorted_cpu; // 1D pinned CPU tensor [s], same dtype as score
46+
torch::Tensor index_sorted_cpu; // 1D pinned CPU tensor [s], int32
4047
};
4148

4249

43-
void esa_retrieval(RetrievalInputTensor input, RetrievalOutputTensor output){
50+
int esa_retrieval(RetrievalInputTensor input, RetrievalOutputTensor output){
4451
auto query = input.query;
4552
auto repre_cache = input.repre_cache;
4653
auto q_index = input.q_index;
4754
auto repre_index = input.repre_index;
55+
auto repre_index_cpu = input.repre_index_cpu;
4856
auto batch_offset = input.batch_offset;
49-
auto workspace = input.workspace;
5057

5158
auto score = output.score;
52-
auto index = output.index;
53-
auto score_sorted = output.score_sorted;
54-
auto index_sorted = output.index_sorted;
55-
esa_retrieval_launcher(query, repre_cache, q_index, repre_index, batch_offset, workspace, score, score_sorted, index, index_sorted, input.batch, input.s);
59+
// CPU pinned outputs
60+
auto score_cpu = output.score_cpu;
61+
auto score_sorted_cpu = output.score_sorted_cpu;
62+
auto index_sorted_cpu = output.index_sorted_cpu;
63+
64+
return esa_retrieval_launcher(
65+
query, repre_cache, q_index, repre_index, repre_index_cpu,
66+
batch_offset, score,
67+
score_cpu, score_sorted_cpu, index_sorted_cpu,
68+
input.batch, input.s
69+
);
5670
}
5771

5872

@@ -68,22 +82,28 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6882
.def_readwrite("repre_cache", &RetrievalInputTensor::repre_cache)
6983
.def_readwrite("q_index", &RetrievalInputTensor::q_index)
7084
.def_readwrite("repre_index", &RetrievalInputTensor::repre_index)
85+
.def_readwrite("repre_index_cpu", &RetrievalInputTensor::repre_index_cpu)
7186
.def_readwrite("batch_offset", &RetrievalInputTensor::batch_offset)
72-
.def_readwrite("workspace", &RetrievalInputTensor::workspace)
7387
.def_readwrite("batch", &RetrievalInputTensor::batch)
7488
.def_readwrite("s", &RetrievalInputTensor::s);
7589

7690
py::class_<RetrievalOutputTensor>(m, "RetrievalOutputTensor")
7791
.def(py::init<>())
7892
.def_readwrite("score", &RetrievalOutputTensor::score)
79-
.def_readwrite("score_sorted", &RetrievalOutputTensor::score_sorted)
80-
.def_readwrite("index", &RetrievalOutputTensor::index)
81-
.def_readwrite("index_sorted", &RetrievalOutputTensor::index_sorted);
93+
.def_readwrite("score_cpu", &RetrievalOutputTensor::score_cpu)
94+
.def_readwrite("score_sorted_cpu", &RetrievalOutputTensor::score_sorted_cpu)
95+
.def_readwrite("index_sorted_cpu", &RetrievalOutputTensor::index_sorted_cpu);
8296

8397
TORCH_BINDING_COMMON_EXTENSION(esa_retrieval);
8498
TORCH_BINDING_COMMON_EXTENSION(esa_topk);
8599
TORCH_BINDING_COMMON_EXTENSION(esa_repre);
86100
TORCH_BINDING_COMMON_EXTENSION(esa_copy);
87101
TORCH_BINDING_COMMON_EXTENSION(esa_scatter_copy);
88102
TORCH_BINDING_COMMON_EXTENSION(esa_copy_batch);
103+
104+
// Async retrieval helpers
105+
m.def("esa_retrieval_poll", &esa_retrieval_poll, "Poll whether CPU argsort finished (returns 0/1)");
106+
m.def("esa_retrieval_cleanup", &esa_retrieval_cleanup, "Cleanup a retrieval handle");
107+
m.def("esa_retrieval_pending", &esa_retrieval_pending, "Number of pending retrieval contexts");
108+
m.def("esa_retrieval_shutdown", &esa_retrieval_shutdown, "Shutdown retrieval worker/callback streams");
89109
}

0 commit comments

Comments
 (0)