@@ -1406,7 +1406,14 @@ NDArray_Load(char * filename)
14061406NDArray *
14071407NDArray_AssignRawScalar (NDArray * dst , NDArray * src )
14081408{
1409- NDArray_FDATA (dst )[0 ] = NDArray_FDATA (src )[0 ];
1409+ if (NDArray_DEVICE (dst ) == NDARRAY_DEVICE_CPU ) {
1410+ NDArray_FDATA (dst )[0 ] = NDArray_FDATA (src )[0 ];
1411+ }
1412+ #ifdef HAVE_CUBLAS
1413+ if (NDArray_DEVICE (dst ) == NDARRAY_DEVICE_GPU ) {
1414+ vmemcpyd2d (NDArray_DATA (src ), NDArray_DATA (dst ), sizeof (float ));
1415+ }
1416+ #endif
14101417 return dst ;
14111418}
14121419
@@ -1426,7 +1433,8 @@ NDArray_CompareLists(int const *l1, int const *l2, int n)
14261433int
14271434raw_array_assign_array (int ndim , int const * shape ,
14281435 NDArrayDescriptor * dst_dtype , char * dst_data , int const * dst_strides ,
1429- NDArrayDescriptor * src_dtype , char * src_data , int const * src_strides )
1436+ NDArrayDescriptor * src_dtype , char * src_data , int const * src_strides ,
1437+ int device )
14301438{
14311439 int idim ;
14321440 int shape_it [NDARRAY_MAX_DIMS ];
@@ -1461,7 +1469,15 @@ raw_array_assign_array(int ndim, int const *shape,
14611469 NDARRAY_RAW_ITER_START (idim , ndim , coord , shape_it ) {
14621470 size = shape_it [0 ];
14631471 for (int i = 0 ; i < size ; i ++ ) {
1464- memcpy (dst_data + (int )(i * dst_strides_it [0 ]), src_data + (int )(i * src_strides_it [0 ]), sizeof (float ));
1472+ if (device == NDARRAY_DEVICE_CPU ) {
1473+ memcpy (dst_data + (int ) (i * dst_strides_it [0 ]), src_data + (int ) (i * src_strides_it [0 ]),
1474+ sizeof (float ));
1475+ }
1476+ if (device == NDARRAY_DEVICE_GPU ) {
1477+ #ifdef HAVE_CUBLAS
1478+ vmemcpyd2d (src_data + (int ) (i * src_strides_it [0 ]), dst_data + (int ) (i * dst_strides_it [0 ]), sizeof (float ));
1479+ #endif
1480+ }
14651481 }
14661482 } NDARRAY_RAW_ITER_TWO_NEXT (idim , ndim , coord , shape_it ,
14671483 dst_data , dst_strides_it ,
@@ -1524,7 +1540,7 @@ NDArray_AssignArray(NDArray *dst, NDArray *src)
15241540
15251541 if (raw_array_assign_array (NDArray_NDIM (dst ), NDArray_SHAPE (dst ),
15261542 NDArray_DESCRIPTOR (dst ), NDArray_DATA (dst ), NDArray_STRIDES (dst ),
1527- NDArray_DESCRIPTOR (src ), NDArray_DATA (src ), src_strides ) < 0 ) {
1543+ NDArray_DESCRIPTOR (src ), NDArray_DATA (src ), src_strides , NDArray_DEVICE ( src ) ) < 0 ) {
15281544 goto fail ;
15291545 }
15301546
0 commit comments