Skip to content

Commit 0f54f97

Browse files
committed
Update Matlab bindings to use zero-copy transfer.
1 parent 92194cd commit 0f54f97

File tree

1 file changed

+98
-91
lines changed

1 file changed

+98
-91
lines changed

bindings/matlab/binsparse_read.c

Lines changed: 98 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -19,107 +19,112 @@
1919
#include <binsparse/binsparse.h>
2020
#include <string.h>
2121

22-
/**
23-
* Convert bsp_array_t to MATLAB array
24-
*/
25-
mxArray* bsp_array_to_matlab(const bsp_array_t* array) {
26-
if (array->data == NULL || array->size == 0) {
27-
// Return empty array
28-
return mxCreateDoubleMatrix(0, 1, mxREAL);
29-
}
30-
31-
mxArray* mx_array = NULL;
32-
33-
switch (array->type) {
34-
case BSP_FLOAT64:
35-
mx_array = mxCreateNumericMatrix(array->size, 1, mxDOUBLE_CLASS, mxREAL);
36-
memcpy(mxGetPr(mx_array), array->data, array->size * sizeof(double));
37-
break;
38-
39-
case BSP_FLOAT32:
40-
mx_array = mxCreateNumericMatrix(array->size, 1, mxSINGLE_CLASS, mxREAL);
41-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(float));
42-
break;
43-
44-
case BSP_UINT64:
45-
mx_array = mxCreateNumericMatrix(array->size, 1, mxUINT64_CLASS, mxREAL);
46-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(uint64_t));
47-
break;
48-
49-
case BSP_UINT32:
50-
mx_array = mxCreateNumericMatrix(array->size, 1, mxUINT32_CLASS, mxREAL);
51-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(uint32_t));
52-
break;
22+
static inline void* bsp_matlab_malloc(size_t size) {
23+
void* ptr = mxMalloc(size);
24+
mexMakeMemoryPersistent(ptr);
25+
return ptr;
26+
}
5327

54-
case BSP_UINT16:
55-
mx_array = mxCreateNumericMatrix(array->size, 1, mxUINT16_CLASS, mxREAL);
56-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(uint16_t));
57-
break;
28+
static const bsp_allocator_t bsp_matlab_allocator = {
29+
.malloc = bsp_matlab_malloc, .free = mxFree};
5830

31+
static inline mxClassID get_mxClassID(bsp_type_t type) {
32+
switch (type) {
5933
case BSP_UINT8:
60-
mx_array = mxCreateNumericMatrix(array->size, 1, mxUINT8_CLASS, mxREAL);
61-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(uint8_t));
62-
break;
63-
64-
case BSP_INT64:
65-
mx_array = mxCreateNumericMatrix(array->size, 1, mxINT64_CLASS, mxREAL);
66-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(int64_t));
67-
break;
68-
69-
case BSP_INT32:
70-
mx_array = mxCreateNumericMatrix(array->size, 1, mxINT32_CLASS, mxREAL);
71-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(int32_t));
72-
break;
73-
34+
return mxUINT8_CLASS;
35+
case BSP_UINT16:
36+
return mxUINT16_CLASS;
37+
case BSP_UINT32:
38+
return mxUINT32_CLASS;
39+
case BSP_UINT64:
40+
return mxUINT64_CLASS;
41+
case BSP_INT8:
42+
return mxINT8_CLASS;
7443
case BSP_INT16:
75-
mx_array = mxCreateNumericMatrix(array->size, 1, mxINT16_CLASS, mxREAL);
76-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(int16_t));
77-
break;
44+
return mxINT16_CLASS;
45+
case BSP_INT32:
46+
return mxINT32_CLASS;
47+
case BSP_INT64:
48+
return mxINT64_CLASS;
49+
case BSP_FLOAT32:
50+
return mxSINGLE_CLASS;
51+
case BSP_FLOAT64:
52+
return mxDOUBLE_CLASS;
53+
case BSP_BINT8: // Treat BSP_BINT8 as UINT8
54+
return mxUINT8_CLASS;
55+
case BSP_COMPLEX_FLOAT32:
56+
return mxSINGLE_CLASS;
57+
case BSP_COMPLEX_FLOAT64:
58+
return mxDOUBLE_CLASS;
59+
default:
60+
return mxUNKNOWN_CLASS;
61+
}
62+
}
7863

79-
case BSP_INT8:
80-
mx_array = mxCreateNumericMatrix(array->size, 1, mxINT8_CLASS, mxREAL);
81-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(int8_t));
82-
break;
83-
84-
case BSP_BINT8:
85-
// Treat BSP_BINT8 as UINT8 as suggested
86-
mx_array = mxCreateNumericMatrix(array->size, 1, mxUINT8_CLASS, mxREAL);
87-
memcpy(mxGetData(mx_array), array->data, array->size * sizeof(int8_t));
88-
break;
89-
90-
case BSP_COMPLEX_FLOAT64: {
91-
mx_array = mxCreateNumericMatrix(array->size, 1, mxDOUBLE_CLASS, mxCOMPLEX);
92-
double* in_data =
93-
(double*) array->data; // Treat as array of adjacent real/imag pairs
94-
double* real_data = mxGetPr(mx_array);
95-
double* imag_data = mxGetPi(mx_array);
96-
for (size_t i = 0; i < array->size; i++) {
97-
real_data[i] = in_data[2 * i]; // Real part
98-
imag_data[i] = in_data[2 * i + 1]; // Imaginary part
99-
}
100-
break;
64+
static inline mxComplexity get_mxComplexity(bsp_type_t type) {
65+
if (type == BSP_COMPLEX_FLOAT32 || type == BSP_COMPLEX_FLOAT64) {
66+
return mxCOMPLEX;
67+
} else {
68+
return mxREAL;
10169
}
70+
}
10271

103-
case BSP_COMPLEX_FLOAT32: {
104-
mx_array = mxCreateNumericMatrix(array->size, 1, mxSINGLE_CLASS, mxCOMPLEX);
105-
float* in_data =
106-
(float*) array->data; // Treat as array of adjacent real/imag pairs
107-
float* real_data = (float*) mxGetData(mx_array);
108-
float* imag_data = (float*) mxGetImagData(mx_array);
109-
for (size_t i = 0; i < array->size; i++) {
110-
real_data[i] = in_data[2 * i]; // Real part
111-
imag_data[i] = in_data[2 * i + 1]; // Imaginary part
112-
}
113-
break;
72+
mxArray* bsp_array_to_matlab(bsp_array_t* array) {
73+
if (array->data == NULL || array->size == 0) {
74+
// Return empty array
75+
return mxCreateDoubleMatrix(1, 1, mxREAL);
11476
}
11577

116-
default:
117-
// Fallback: create empty array
118-
mx_array = mxCreateDoubleMatrix(0, 1, mxREAL);
78+
if (get_mxClassID(array->type) == mxUNKNOWN_CLASS) {
11979
mexWarnMsgIdAndTxt("BinSparse:UnsupportedType",
12080
"Unsupported array type %d, returning empty array",
12181
(int) array->type);
122-
break;
82+
return mxCreateDoubleMatrix(1, 1, mxREAL);
83+
}
84+
85+
mxArray* mx_array = NULL;
86+
87+
if ((array->allocator.malloc == bsp_matlab_allocator.malloc &&
88+
array->allocator.free == bsp_matlab_allocator.free) &&
89+
get_mxComplexity(array->type) == mxREAL) {
90+
// Create mx_array in a zero-copy fashion.
91+
92+
mx_array = mxCreateNumericMatrix(0, 1, get_mxClassID(array->type),
93+
get_mxComplexity(array->type));
94+
95+
mxSetData(mx_array, array->data);
96+
mxSetM(mx_array, array->size);
97+
98+
array->data = NULL;
99+
array->size = 0;
100+
} else {
101+
mx_array = mxCreateNumericMatrix(array->size, 1, get_mxClassID(array->type),
102+
get_mxComplexity(array->type));
103+
104+
if (get_mxComplexity(array->type) == mxREAL) {
105+
memcpy(mxGetData(mx_array), array->data,
106+
array->size * bsp_type_size(array->type));
107+
} else {
108+
if (array->type == BSP_COMPLEX_FLOAT32) {
109+
float* in_data =
110+
(float*) array->data; // Treat as array of adjacent real/imag pairs
111+
float* real_data = (float*) mxGetData(mx_array);
112+
float* imag_data = (float*) mxGetImagData(mx_array);
113+
for (size_t i = 0; i < array->size; i++) {
114+
real_data[i] = in_data[2 * i]; // Real part
115+
imag_data[i] = in_data[2 * i + 1]; // Imaginary part
116+
}
117+
} else {
118+
double* in_data =
119+
(double*) array->data; // Treat as array of adjacent real/imag pairs
120+
double* real_data = mxGetPr(mx_array);
121+
double* imag_data = mxGetPi(mx_array);
122+
for (size_t i = 0; i < array->size; i++) {
123+
real_data[i] = in_data[2 * i]; // Real part
124+
imag_data[i] = in_data[2 * i + 1]; // Imaginary part
125+
}
126+
}
127+
}
123128
}
124129

125130
return mx_array;
@@ -128,7 +133,7 @@ mxArray* bsp_array_to_matlab(const bsp_array_t* array) {
128133
/**
129134
* Convert bsp_matrix_t to MATLAB struct
130135
*/
131-
mxArray* bsp_matrix_to_matlab_struct(const bsp_matrix_t* matrix) {
136+
mxArray* bsp_matrix_to_matlab_struct(bsp_matrix_t* matrix) {
132137
const char* field_names[] = {
133138
"values", "indices_0", "indices_1", "pointers_to_1", "nrows",
134139
"ncols", "nnz", "is_iso", "format", "structure"};
@@ -210,7 +215,8 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
210215
}
211216

212217
// Read the matrix using Binsparse
213-
error = bsp_read_matrix(&matrix, filename, group);
218+
error =
219+
bsp_read_matrix_allocator(&matrix, filename, group, bsp_matlab_allocator);
214220

215221
if (error != BSP_SUCCESS) {
216222
// Clean up
@@ -228,6 +234,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
228234

229235
// Clean up
230236
bsp_destroy_matrix_t(&matrix);
237+
231238
if (filename)
232239
mxFree(filename);
233240
if (group)

0 commit comments

Comments
 (0)