Skip to content

Commit bf8341e

Browse files
fix: add GPU implementations for the new routines
1 parent 8f392c8 commit bf8341e

File tree

2 files changed: 35 additions and 7 deletions.

2 files changed: 35 additions and 7 deletions.

src/manipulation.c

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,21 @@ NDArray_ConcatenateFlat(NDArray **arrays, int num_arrays)
337337

338338
for (iarrays = 0; iarrays < narrays; ++iarrays) {
339339
sliding_view->dimensions[0] = NDArray_NUMELEMENTS(arrays[iarrays]);
340-
341-
memcpy(sliding_view->data, arrays[iarrays]->data, sliding_view->strides[0] * NDArray_NUMELEMENTS(arrays[iarrays]));
342-
340+
if (NDArray_DEVICE(ret) == NDARRAY_DEVICE_CPU) {
341+
memcpy(sliding_view->data, arrays[iarrays]->data,
342+
sliding_view->strides[0] * NDArray_NUMELEMENTS(arrays[iarrays]));
343+
}
344+
#ifdef HAVE_CUBLAS
345+
if (NDArray_DEVICE(ret) == NDARRAY_DEVICE_GPU) {
346+
if (NDArray_NDIM(arrays[iarrays]) > 0) {
347+
vmemcpyd2d(arrays[iarrays]->data, sliding_view->data,
348+
sliding_view->strides[0] * NDArray_NUMELEMENTS(arrays[iarrays]));
349+
} else {
350+
vmemcpyh2d(arrays[iarrays]->data, sliding_view->data,
351+
sliding_view->strides[0] * NDArray_NUMELEMENTS(arrays[iarrays]));
352+
}
353+
}
354+
#endif
343355
/* Slide to the start of the next window */
344356
sliding_view->data += sliding_view->strides[0] * NDArray_NUMELEMENTS(arrays[iarrays]);
345357
}

src/ndarray.c

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,14 @@ NDArray_Load(char * filename)
14061406
NDArray*
14071407
NDArray_AssignRawScalar(NDArray *dst, NDArray *src)
14081408
{
1409-
NDArray_FDATA(dst)[0] = NDArray_FDATA(src)[0];
1409+
if (NDArray_DEVICE(dst) == NDARRAY_DEVICE_CPU) {
1410+
NDArray_FDATA(dst)[0] = NDArray_FDATA(src)[0];
1411+
}
1412+
#ifdef HAVE_CUBLAS
1413+
if (NDArray_DEVICE(dst) == NDARRAY_DEVICE_GPU) {
1414+
vmemcpyd2d(NDArray_DATA(src), NDArray_DATA(dst), sizeof(float));
1415+
}
1416+
#endif
14101417
return dst;
14111418
}
14121419

@@ -1426,7 +1433,8 @@ NDArray_CompareLists(int const *l1, int const *l2, int n)
14261433
int
14271434
raw_array_assign_array(int ndim, int const *shape,
14281435
NDArrayDescriptor *dst_dtype, char *dst_data, int const *dst_strides,
1429-
NDArrayDescriptor *src_dtype, char *src_data, int const *src_strides)
1436+
NDArrayDescriptor *src_dtype, char *src_data, int const *src_strides,
1437+
int device)
14301438
{
14311439
int idim;
14321440
int shape_it[NDARRAY_MAX_DIMS];
@@ -1461,7 +1469,15 @@ raw_array_assign_array(int ndim, int const *shape,
14611469
NDARRAY_RAW_ITER_START(idim, ndim, coord, shape_it) {
14621470
size = shape_it[0];
14631471
for (int i = 0; i < size; i++) {
1464-
memcpy(dst_data + (int)(i * dst_strides_it[0]), src_data + (int)(i * src_strides_it[0]), sizeof(float));
1472+
if (device == NDARRAY_DEVICE_CPU) {
1473+
memcpy(dst_data + (int) (i * dst_strides_it[0]), src_data + (int) (i * src_strides_it[0]),
1474+
sizeof(float));
1475+
}
1476+
if (device == NDARRAY_DEVICE_GPU) {
1477+
#ifdef HAVE_CUBLAS
1478+
vmemcpyd2d(src_data + (int) (i * src_strides_it[0]), dst_data + (int) (i * dst_strides_it[0]), sizeof(float));
1479+
#endif
1480+
}
14651481
}
14661482
} NDARRAY_RAW_ITER_TWO_NEXT(idim, ndim, coord, shape_it,
14671483
dst_data, dst_strides_it,
@@ -1524,7 +1540,7 @@ NDArray_AssignArray(NDArray *dst, NDArray *src)
15241540

15251541
if (raw_array_assign_array(NDArray_NDIM(dst), NDArray_SHAPE(dst),
15261542
NDArray_DESCRIPTOR(dst), NDArray_DATA(dst), NDArray_STRIDES(dst),
1527-
NDArray_DESCRIPTOR(src), NDArray_DATA(src), src_strides) < 0) {
1543+
NDArray_DESCRIPTOR(src), NDArray_DATA(src), src_strides, NDArray_DEVICE(src)) < 0) {
15281544
goto fail;
15291545
}
15301546

0 commit comments

Comments
 (0)