Skip to content

Commit c4a0461

Browse files
committed
CSR Data transfer from Daphne to Python in Scipy csr form
1 parent ecf741a commit c4a0461

File tree

3 files changed

+115
-8
lines changed

3 files changed

+115
-8
lines changed

src/api/python/daphne/operator/operation_node.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929

3030
import numpy as np
3131
import pandas as pd
32+
try:
33+
import scipy.sparse as sp
34+
except ImportError as e:
35+
sp = e
3236
try:
3337
import torch as torch
3438
except ImportError as e:
@@ -188,11 +192,37 @@ def compute(self, type="shared memory", verbose=False, asTensorFlow=False, asPyT
188192
self.clear_tmp()
189193
elif self._output_type == OutputType.MATRIX and type=="shared memory":
190194
daphneLibResult = DaphneLib.getResult()
191-
result = np.ctypeslib.as_array(
192-
ctypes.cast(daphneLibResult.address, ctypes.POINTER(self.getType(daphneLibResult.vtc))),
193-
shape=[daphneLibResult.rows, daphneLibResult.cols]
194-
)
195-
self.clear_tmp()
195+
if not daphneLibResult.isSparse:
196+
# Dense Matrix
197+
daphneLibResult = DaphneLib.getResult()
198+
result = np.ctypeslib.as_array(
199+
ctypes.cast(daphneLibResult.address, ctypes.POINTER(self.getType(daphneLibResult.vtc))),
200+
shape=[daphneLibResult.rows, daphneLibResult.cols]
201+
)
202+
else:
203+
# CSR Matrix
204+
VT = self.getType(daphneLibResult.vtc)
205+
206+
# wrap each pointer into a numpy array
207+
indptr = np.ctypeslib.as_array(
208+
ctypes.cast(daphneLibResult.row_related, ctypes.POINTER(ctypes.c_size_t)),
209+
shape=(daphneLibResult.rows + 1,)
210+
)
211+
nnz = int(indptr[-1] - indptr[0])
212+
213+
214+
data = np.ctypeslib.as_array(
215+
ctypes.cast(daphneLibResult.data, ctypes.POINTER(VT)),
216+
shape=(nnz,)
217+
)
218+
219+
indices = np.ctypeslib.as_array(
220+
ctypes.cast(daphneLibResult.col_related, ctypes.POINTER(ctypes.c_size_t)),
221+
shape=(nnz,)
222+
)
223+
# build scipy CSR
224+
result = sp.csr_matrix((data, indices, indptr), shape=(daphneLibResult.rows, daphneLibResult.cols))
225+
self.clear_tmp()
196226
elif self._output_type == OutputType.MATRIX and type=="files":
197227
arr = np.genfromtxt(result, delimiter=',')
198228
self.clear_tmp()

src/api/python/daphne/utils/daphnelib.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,28 @@
2020
# Python representation of the struct DaphneLibResult.
2121
class DaphneLibResult(ctypes.Structure):
2222
_fields_ = [
23-
# For matrices.
24-
("address", ctypes.c_void_p),
23+
# For matrices
2524
("rows", ctypes.c_int64),
2625
("cols", ctypes.c_int64),
2726
("vtc", ctypes.c_int64),
27+
("isSparse", ctypes.c_bool),
28+
29+
# For dense matrices
30+
("address", ctypes.c_void_p),
31+
32+
# For sparse matrices
33+
("data", ctypes.c_void_p),
34+
("row_related",ctypes.c_void_p),
35+
("col_related",ctypes.c_void_p),
36+
2837
# For frames.
2938
("vtcs", ctypes.POINTER(ctypes.c_int64)),
3039
("labels", ctypes.POINTER(ctypes.c_char_p)),
3140
("columns", ctypes.POINTER(ctypes.c_void_p)),
41+
3242
# To pass error messages to Python code.
3343
("error_message", ctypes.c_char_p)
3444
]
3545

3646
DaphneLib = ctypes.CDLL(os.path.join(PROTOTYPE_PATH, DAPHNELIB_FILENAME))
37-
DaphneLib.getResult.restype = DaphneLibResult
47+
DaphneLib.getResult.restype = DaphneLibResult

src/runtime/local/kernels/SaveDaphneLibResult.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#define SRC_RUNTIME_LOCAL_KERNELS_SAVEDAPHNELIBRESULT_H
1919

2020
#include <runtime/local/context/DaphneContext.h>
21+
#include <runtime/local/datastructures/CSRMatrix.h>
2122
#include <runtime/local/datastructures/DenseMatrix.h>
2223
#include <runtime/local/datastructures/Frame.h>
2324

@@ -64,6 +65,72 @@ template <typename VT> struct SaveDaphneLibResult<DenseMatrix<VT>> {
6465
}
6566
};
6667

68+
// ----------------------------------------------------------------------------
69+
// CSRMatrix
70+
// ----------------------------------------------------------------------------
71+
72+
template <typename VT> struct SaveDaphneLibResult<CSRMatrix<VT>> {
73+
static void apply(const CSRMatrix<VT> *arg, DCTX(ctx)) {
74+
75+
const_cast<CSRMatrix<VT> *>(arg)->increaseRefCounter();
76+
77+
DaphneLibResult *daphneLibRes = ctx->getUserConfig().result_struct;
78+
if (!daphneLibRes)
79+
throw std::runtime_error("saveDaphneLibRes(): daphneLibRes is nullptr");
80+
81+
// Result is a sparse matrix
82+
daphneLibRes->isSparse = true;
83+
84+
const size_t rows = arg->getNumRows();
85+
const size_t cols = arg->getNumCols();
86+
daphneLibRes->rows = rows;
87+
daphneLibRes->cols = cols;
88+
daphneLibRes->vtc = (int64_t)ValueTypeUtils::codeFor<VT>;
89+
90+
// original raw pointers
91+
auto origVals = arg->getValuesSharedPtr().get();
92+
auto origCols = arg->getColIdxsSharedPtr().get();
93+
auto origOffsets = arg->getRowOffsetsSharedPtr().get();
94+
95+
// first pass: count *actual* non-zeros
96+
size_t actualNNZ = 0;
97+
for (size_t r = 0; r < rows; r++) {
98+
auto start = origOffsets[r];
99+
auto end = origOffsets[r + 1];
100+
for (size_t i = start; i < end; i++)
101+
if (origVals[i] != VT(0))
102+
actualNNZ++;
103+
}
104+
105+
// allocate new tight buffers
106+
VT *cleanVals = new VT[actualNNZ];
107+
size_t *cleanCols = new size_t[actualNNZ];
108+
size_t *cleanOffs = new size_t[rows + 1];
109+
110+
// second pass: copy only non-zeros
111+
size_t p = 0;
112+
cleanOffs[0] = 0;
113+
for (size_t r = 0; r < rows; r++) {
114+
auto start = origOffsets[r];
115+
auto end = origOffsets[r + 1];
116+
for (size_t i = start; i < end; i++) {
117+
VT v = origVals[i];
118+
if (v != VT(0)) {
119+
cleanVals[p] = v;
120+
cleanCols[p] = origCols[i];
121+
p++;
122+
}
123+
}
124+
cleanOffs[r + 1] = p;
125+
}
126+
127+
// hand them off in the result struct
128+
daphneLibRes->data = static_cast<void *>(cleanVals);
129+
daphneLibRes->col_related = static_cast<void *>(cleanCols);
130+
daphneLibRes->row_related = static_cast<void *>(cleanOffs);
131+
}
132+
};
133+
67134
// ----------------------------------------------------------------------------
68135
// Frame
69136
// ----------------------------------------------------------------------------

0 commit comments

Comments
 (0)