Skip to content

Commit deeb48e

Browse files
committed
[FEAT] Optimized permute_dims()
1 parent e06f954 commit deeb48e

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

src/blosc2/ndarray.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3871,11 +3871,11 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
38713871
r = x2.chunks[-1]
38723872

38733873
for row in range(0, n, p):
3874-
row_end = (row + p) if (row + p) < n else n
3874+
row_end = builtins.min(row + p, n)
38753875
for col in range(0, m, q):
3876-
col_end = (col + q) if (col + q) < m else m
3876+
col_end = builtins.min(col + q, m)
38773877
for aux in range(0, k, r):
3878-
aux_end = (aux + r) if (aux + r) < k else k
3878+
aux_end = builtins.min(aux + r, k)
38793879
bx1 = x1[row:row_end, aux:aux_end]
38803880
bx2 = x2[aux:aux_end, col:col_end]
38813881
result[row:row_end, col:col_end] += np.matmul(bx1, bx2)
@@ -3905,8 +3905,8 @@ def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwa
39053905
39063906
Returns
39073907
-------
3908-
out: :ref:`NDArray`
3909-
A Blosc2 :ref:`NDArray` with axes transposed.
3908+
out:: ref:`NDArray`
3909+
A Blosc2: ref:`NDArray` with axes transposed.
39103910
39113911
Raises
39123912
------
@@ -3947,8 +3947,6 @@ def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwa
39473947
[17, 18, 19, 20],
39483948
[21, 22, 23, 24]]])
39493949
3950-
3951-
39523950
>>> at = blosc2.permute_dims(a, axes=(1, 0, 2))
39533951
>>> at[:]
39543952
array([[[ 1, 2, 3, 4],
@@ -3959,37 +3957,39 @@ def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwa
39593957
[21, 22, 23, 24]]])
39603958
39613959
"""
3962-
39633960
if np.isscalar(arr) or arr.ndim < 2:
39643961
return arr
39653962

3963+
ndim = arr.ndim
3964+
39663965
if axes is None:
3967-
axes = range(arr.ndim)[::-1]
3966+
axes = tuple(range(ndim))[::-1]
39683967
else:
3969-
axes_transformed = tuple(axis if axis >= 0 else arr.ndim + axis for axis in axes)
3970-
if sorted(axes_transformed) != list(range(arr.ndim)):
3971-
raise ValueError(f"axes {axes} is not a valid permutation of {arr.ndim} dimensions")
3972-
axes = axes_transformed
3968+
axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
3969+
if sorted(axes) != list(range(ndim)):
3970+
raise ValueError(f"axes {axes} is not a valid permutation of {ndim} dimensions")
39733971

39743972
new_shape = tuple(arr.shape[axis] for axis in axes)
3975-
if "chunks" not in kwargs:
3973+
if "chunks" not in kwargs or kwargs["chunks"] is None:
39763974
kwargs["chunks"] = tuple(arr.chunks[axis] for axis in axes)
39773975

39783976
result = blosc2.empty(shape=new_shape, dtype=arr.dtype, **kwargs)
39793977

3980-
chunk_slices = [
3981-
[slice(start, builtins.min(dim, start + chunk)) for start in range(0, dim, chunk)]
3982-
for dim, chunk in zip(arr.shape, arr.chunks, strict=False)
3983-
]
3978+
chunks = arr.chunks
3979+
shape = arr.shape
3980+
3981+
for info in arr.iterchunks_info():
3982+
coords = info.coords
3983+
start_stop = [
3984+
(coord * chunk, builtins.min(chunk * (coord + 1), dim))
3985+
for coord, chunk, dim in zip(coords, chunks, shape, strict=False)
3986+
]
39843987

3985-
block_counts = [len(s) for s in chunk_slices]
3986-
grid = np.indices(block_counts).reshape(len(block_counts), -1).T
3988+
src_slice = tuple(slice(start, stop) for start, stop in start_stop)
3989+
dst_slice = tuple(slice(start_stop[ax][0], start_stop[ax][1]) for ax in axes)
39873990

3988-
for idx in grid:
3989-
block_slices = tuple(chunk_slices[axis][i] for axis, i in enumerate(idx))
3990-
block = arr[block_slices]
3991-
target_slices = tuple(block_slices[axis] for axis in axes)
3992-
result[target_slices] = np.transpose(block, axes=axes).copy()
3991+
transposed = np.transpose(arr[src_slice], axes=axes)
3992+
result[dst_slice] = np.ascontiguousarray(transposed)
39933993

39943994
return result
39953995

@@ -4002,14 +4002,14 @@ def transpose(x, **kwargs: Any) -> NDArray:
40024002
40034003
Parameters
40044004
----------
4005-
x: :ref:`NDArray`
4005+
x:: ref:`NDArray`
40064006
The input array.
40074007
kwargs: Any, optional
40084008
Keyword arguments that are supported by the :func:`empty` constructor.
40094009
40104010
Returns
40114011
-------
4012-
out: :ref:`NDArray`
4012+
out:: ref:`NDArray`
40134013
The Blosc2 NDArray with axes transposed.
40144014
40154015
References
@@ -4023,7 +4023,7 @@ def transpose(x, **kwargs: Any) -> NDArray:
40234023
stacklevel=2,
40244024
)
40254025

4026-
# If arguments are dimension < 2 they are returned
4026+
# If arguments are dimension < 2, they are returned
40274027
if np.isscalar(x) or x.ndim < 2:
40284028
return x
40294029

@@ -4039,14 +4039,14 @@ def matrix_transpose(arr: NDArray, **kwargs: Any) -> NDArray:
40394039
40404040
Parameters
40414041
----------
4042-
arr: :ref:`NDArray`
4042+
arr:: ref:`NDArray`
40434043
The input NDArray having shape ``(..., M, N)`` and whose innermost two dimensions form
40444044
``MxN`` matrices.
40454045
40464046
Returns
40474047
-------
4048-
out: :ref:`NDArray`
4049-
A new :ref:`NDArray` containing the transpose for each matrix and having shape
4048+
out:: ref:`NDArray`
4049+
A new: ref:`NDArray` containing the transpose for each matrix and having shape
40504050
``(..., N, M)``.
40514051
"""
40524052
axes = None

0 commit comments

Comments
 (0)