Skip to content

Commit 87931c3

Browse files
committed
Fix bug of loading scalar np value
1 parent 0343a38 commit 87931c3

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tinyloader/loader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,10 @@ def share_ndarray(array: np.ndarray, buffer: SharedBuffer) -> SharedNDArray:
9292
f"Expected data ndarray size {array.nbytes} should be equal to {buffer.actual_block_size}"
9393
)
9494
shared_ndarray = np.ndarray(shape=array.shape, dtype=array.dtype, buffer=buffer.buf)
95-
shared_ndarray[:] = array[:]
95+
if np.isscalar(array) or array.ndim == 0:
96+
shared_ndarray[()] = array
97+
else:
98+
shared_ndarray[:] = array[:]
9699
return SharedNDArray(
97100
shape=shared_ndarray.shape,
98101
dtype=shared_ndarray.dtype,

0 commit comments

Comments
 (0)