
Commit 1e68d88

Merge pull request #393 from ricardosp4/nd-transpose
Performance optimization for `permute_dims`
2 parents 51cba77 + a6a86ce commit 1e68d88

3 files changed: +173 -100 lines changed
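The core of the optimization is that `permute_dims` now walks the source array chunk by chunk (via `iterchunks_info()`) and copies each transposed block straight into the matching region of the result, instead of pre-building every block slice with `np.indices`, as the two file diffs below show. The following sketch (plain Python/NumPy, not the library code; the helper `chunk_slices` is purely illustrative) shows how a chunk's source and destination slices follow from its grid coordinates, with `coords`, `chunks`, `shape`, and `axes` mirroring the names used in the diff:

import numpy as np

def chunk_slices(coords, chunks, shape, axes):
    # (start, stop) per dimension, clipped to the array boundary
    start_stop = [
        (coord * chunk, min(chunk * (coord + 1), dim))
        for coord, chunk, dim in zip(coords, chunks, shape)
    ]
    src = tuple(slice(start, stop) for start, stop in start_stop)
    # The destination slice reads the same (start, stop) pairs in permuted order
    dst = tuple(slice(start_stop[ax][0], start_stop[ax][1]) for ax in axes)
    return src, dst

# Example: the chunk at grid coordinates (1, 0) of a (5000, 5000) array
# chunked as (1000, 1000), permuted with axes=(1, 0)
src, dst = chunk_slices((1, 0), (1000, 1000), (5000, 5000), (1, 0))
print(src)  # (slice(1000, 2000, None), slice(0, 1000, None))
print(dst)  # (slice(0, 1000, None), slice(1000, 2000, None))

Because the result array is allocated with correspondingly permuted chunks (kwargs["chunks"] = tuple(arr.chunks[axis] for axis in axes) in the diff), each transposed block should map onto a whole destination chunk.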

bench/ndarray/transpose.ipynb

Lines changed: 121 additions & 46 deletions

@@ -8,7 +8,12 @@
 "import blosc2\n",
 "import time\n",
 "import plotly.express as px\n",
-"import pandas as pd"
+"import pandas as pd\n",
+"\n",
+"from blosc2 import NDArray\n",
+"from typing import Any\n",
+"\n",
+"import builtins"
 ],
 "id": "55765646130156ef",
 "outputs": [],
@@ -18,35 +23,104 @@
 "metadata": {},
 "cell_type": "code",
 "source": [
-"sizes = [(100, 100), (500, 500), (500, 1000), (1000, 1000), (2000, 2000), (3000, 3000), (4000, 4000), (5000, 5000)]\n",
-"sizes_mb = [(np.prod(size) * 8) / 2**20 for size in sizes] # Convert to MB\n",
-"results = {\"numpy\": [], \"blosc2\": []}"
+"def new_permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwargs: Any) -> NDArray:\n",
+"    if np.isscalar(arr) or arr.ndim < 2:\n",
+"        return arr\n",
+"\n",
+"    ndim = arr.ndim\n",
+"    if axes is None:\n",
+"        axes = tuple(range(ndim))[::-1]\n",
+"    else:\n",
+"        axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)\n",
+"        if sorted(axes) != list(range(ndim)):\n",
+"            raise ValueError(f\"axes {axes} is not a valid permutation of {ndim} dimensions\")\n",
+"\n",
+"    new_shape = tuple(arr.shape[axis] for axis in axes)\n",
+"    if \"chunks\" not in kwargs or kwargs[\"chunks\"] is None:\n",
+"        kwargs[\"chunks\"] = tuple(arr.chunks[axis] for axis in axes)\n",
+"\n",
+"    result = blosc2.empty(shape=new_shape, dtype=arr.dtype, **kwargs)\n",
+"\n",
+"    # Precompute per-dimension info\n",
+"    chunks = arr.chunks\n",
+"    shape = arr.shape\n",
+"\n",
+"    for info in arr.iterchunks_info():\n",
+"        coords = info.coords\n",
+"        start_stop = [\n",
+"            (coord * chunk, builtins.min(chunk * (coord + 1), dim))\n",
+"            for coord, chunk, dim in zip(coords, chunks, shape)\n",
+"        ]\n",
+"\n",
+"        src_slice = tuple(slice(start, stop) for start, stop in start_stop)\n",
+"        dst_slice = tuple(slice(start_stop[ax][0], start_stop[ax][1]) for ax in axes)\n",
+"\n",
+"        transposed = np.transpose(arr[src_slice], axes=axes)\n",
+"        result[dst_slice] = np.ascontiguousarray(transposed)\n",
+"\n",
+"    return result"
 ],
 "id": "1cfb7daa6eee1401",
 "outputs": [],
 "execution_count": null
 },
 {
-"metadata": {},
+"metadata": {
+  "jupyter": {
+    "is_executing": true
+  }
+},
 "cell_type": "code",
 "source": [
-"for method in [\"numpy\", \"blosc2\"]:\n",
-"    for size in sizes:\n",
-"        arr = np.random.rand(*size)\n",
-"        arr_b2 = blosc2.asarray(arr)\n",
+"def validate_results(result_orig, result_new, shape):\n",
+"    if not np.allclose(result_orig[:], result_new[:]):\n",
+"        raise ValueError(f\"Mismatch found for shape {shape}\")\n",
+"\n",
+"shapes = [\n",
+"    (100, 100), (2000, 2000), (3000, 3000), (4000, 4000), (3000, 7000),\n",
+"    (5000, 5000), (6000, 6000), (7000, 7000), (8000, 8000), (6000, 12000),\n",
+"    (9000, 9000), (10000, 10000), (10500, 10500), (11000, 11000), (11500, 11500),\n",
+"    (12000, 12000), (12500, 12500), (13000, 13000), (13500, 13500), (14000, 14000),\n",
+"    (14500, 14500), (15000, 15000), (16000, 16000), (16500, 16500), (17000, 17000),\n",
+"    (17500, 17500), (18000, 18000)\n",
+"]\n",
+"\n",
+"sizes = []\n",
+"time_total = []\n",
+"chunk_labels = []\n",
 "\n",
-"        start_time = time.perf_counter()\n",
+"def numpy_permute(arr: np.ndarray, axes: tuple[int] | list[int] | None = None) -> np.ndarray:\n",
+"    if axes is None:\n",
+"        axes = range(arr.ndim)[::-1]\n",
+"    return np.transpose(arr, axes=axes).copy()\n",
 "\n",
-"        if method == \"numpy\":\n",
-"            np.transpose(arr).copy()\n",
-"        elif method == \"blosc2\":\n",
-"            blosc2.transpose(arr_b2)\n",
+"for shape in shapes:\n",
+"    size_mb = (np.prod(shape) * 8) / (2 ** 20)\n",
+"\n",
+"    # NumPy transpose\n",
+"    matrix_numpy = np.linspace(0, 1, np.prod(shape)).reshape(shape)\n",
+"    t0 = time.perf_counter()\n",
+"    result_numpy = numpy_permute(matrix_numpy)\n",
+"    t1 = time.perf_counter()\n",
+"    time_total.append(t1 - t0)\n",
+"    sizes.append(size_mb)\n",
+"    chunk_labels.append(\"numpy.transpose()\")\n",
 "\n",
-"        end_time = time.perf_counter()\n",
-"        time_b = end_time - start_time\n",
+"    # New permute dims (optimized)\n",
+"    matrix_blosc2 = blosc2.linspace(0, 1, np.prod(shape), shape=shape)\n",
+"    t0 = time.perf_counter()\n",
+"    result_new_perm = new_permute_dims(matrix_blosc2)\n",
+"    t1 = time.perf_counter()\n",
+"    time_total.append(t1 - t0)\n",
+"    sizes.append(size_mb)\n",
+"    chunk_labels.append(\"blosc2.permute_dims()\")\n",
+"\n",
+"    try:\n",
+"        validate_results(result_new_perm, result_numpy, shape)\n",
+"    except ValueError as e:\n",
+"        print(e)\n",
 "\n",
-"        print(f\"{method}: shape={size}, Performance = {time_b:.6f} s\")\n",
-"        results[method].append(time_b)"
+"    print(f\"Shape={shape}, Chunk={matrix_blosc2.chunks}: permute_dims={time_total[-2]:.6f}s, numpy={time_total[-1]:.6f}s\")"
 ],
 "id": "384d0ad7983a8d26",
 "outputs": [],
@@ -57,21 +131,21 @@
 "cell_type": "code",
 "source": [
 "df = pd.DataFrame({\n",
-"    \"Matrix Size (MB)\": sizes_mb,\n",
-"    \"NumPy Time (s)\": results[\"numpy\"],\n",
-"    \"Blosc2 Time (s)\": results[\"blosc2\"]\n",
+"    \"Matrix Size (MB)\": sizes,\n",
+"    \"Time (s)\": time_total,\n",
+"    \"Implementation\": chunk_labels\n",
 "})\n",
 "\n",
 "fig = px.line(df,\n",
 "              x=\"Matrix Size (MB)\",\n",
-"              y=[\"NumPy Time (s)\", \"Blosc2 Time (s)\"],\n",
-"              title=\"Performance of Matrix Transposition (NumPy vs Blosc2)\",\n",
-"              labels={\"value\": \"Time (s)\", \"variable\": \"Method\"},\n",
+"              y=\"Time (s)\",\n",
+"              color=\"Implementation\",\n",
+"              title=\"Performance: NumPy vs Blosc2\",\n",
+"              width=1000, height=600,\n",
 "              markers=True)\n",
-"\n",
 "fig.show()"
 ],
-"id": "c71ffb39eb28992c",
+"id": "786b8b7b5ea95225",
 "outputs": [],
 "execution_count": null
 },
@@ -81,15 +155,14 @@
 "source": [
 "%%time\n",
 "shapes = [\n",
-"    (100, 100), (2000, 2000), (3000, 3000), (4000, 4000), (3000, 7000)\n",
-"    # (5000, 5000), (6000, 6000), (7000, 7000), (8000, 8000), (6000, 12000),\n",
-"    # (9000, 9000), (10000, 10000),\n",
-"    # (10500, 10500), (11000, 11000), (11500, 11500), (12000, 12000),\n",
-"    # (12500, 12500), (13000, 13000), (13500, 13500), (14000, 14000),\n",
-"    # (14500, 14500), (15000, 15000), (15500, 15500), (16000, 16000),\n",
-"    # (16500, 16500), (17000, 17000)\n",
+"    (100, 100), (1000, 1000), (2000, 2000), (3000, 3000), (4000, 4000),\n",
+"    (5000, 5000), (6000, 6000), (7000, 7000), (8000, 8000), (9000, 9000),\n",
+"    (9500, 9500), (10000, 10000), (10500, 10500), (11000, 11000), (11500, 11500),\n",
+"    (12000, 12000), (12500, 12500), (13000, 13000), (13500, 13500), (14000, 14000),\n",
+"    (14500, 14500), (15000, 15000), (16000, 16000), (16500, 16500), (17000, 17000)\n",
 "]\n",
-"chunkshapes = [None, (150, 300), (200, 500), (500, 200), (1000, 1000)]\n",
+"\n",
+"chunkshapes = [None, (150, 300), (1000, 1000), (4000, 4000)]\n",
 "\n",
 "sizes = []\n",
 "time_total = []\n",
@@ -111,18 +184,27 @@
 "    print(f\"NumPy: Shape={shape}, Time = {numpy_time:.6f} s\")\n",
 "\n",
 "    for chunk in chunkshapes:\n",
-"        matrix_blosc2 = blosc2.asarray(matrix_np, chunks=chunk)\n",
+"        matrix_blosc2 = blosc2.asarray(matrix_np)\n",
+"        matrix_blosc2 = blosc2.linspace(0, 1, np.prod(shape), shape=shape)\n",
 "\n",
 "        t0 = time.perf_counter()\n",
-"        result_blosc2 = blosc2.transpose(matrix_blosc2)\n",
+"        result_blosc2 = new_permute_dims(matrix_blosc2, chunks=chunk)\n",
 "        blosc2_time = time.perf_counter() - t0\n",
 "\n",
 "        sizes.append(size_mb)\n",
 "        time_total.append(blosc2_time)\n",
 "        chunk_labels.append(f\"{chunk[0]}x{chunk[1]}\" if chunk else \"Auto\")\n",
 "\n",
-"        print(f\"Blosc2: Shape={shape}, Chunks = {matrix_blosc2.chunks}, Time = {blosc2_time:.6f} s\")\n",
-"\n",
+"        print(f\"Blosc2: Shape={shape}, Chunks = {result_blosc2.chunks}, Time = {blosc2_time:.6f} s\")"
+],
+"id": "bcdd8aa5f65df561",
+"outputs": [],
+"execution_count": null
+},
+{
+"metadata": {},
+"cell_type": "code",
+"source": [
 "df = pd.DataFrame({\n",
 "    \"Matrix Size (MB)\": sizes,\n",
 "    \"Time (s)\": time_total,\n",
@@ -135,17 +217,10 @@
 "              color=\"Chunk Shape\",\n",
 "              title=\"Performance of Matrix Transposition (Blosc2 vs NumPy)\",\n",
 "              labels={\"value\": \"Time (s)\", \"variable\": \"Metric\"},\n",
+"              width=1000, height=600,\n",
 "              markers=True)\n",
 "fig.show()"
 ],
-"id": "bcdd8aa5f65df561",
-"outputs": [],
-"execution_count": null
-},
-{
-"metadata": {},
-"cell_type": "code",
-"source": "",
 "id": "1d2f48f370ba7e7a",
 "outputs": [],
 "execution_count": null

src/blosc2/ndarray.py

Lines changed: 25 additions & 22 deletions

@@ -3876,11 +3876,11 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
     r = x2.chunks[-1]

     for row in range(0, n, p):
-        row_end = (row + p) if (row + p) < n else n
+        row_end = builtins.min(row + p, n)
         for col in range(0, m, q):
-            col_end = (col + q) if (col + q) < m else m
+            col_end = builtins.min(col + q, m)
             for aux in range(0, k, r):
-                aux_end = (aux + r) if (aux + r) < k else k
+                aux_end = builtins.min(aux + r, k)
                 bx1 = x1[row:row_end, aux:aux_end]
                 bx2 = x2[aux:aux_end, col:col_end]
                 result[row:row_end, col:col_end] += np.matmul(bx1, bx2)
@@ -3951,6 +3951,7 @@ def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwargs: Any) -> NDArray:
            [[13, 14, 15, 16],
             [17, 18, 19, 20],
             [21, 22, 23, 24]]])
+
     >>> at = blosc2.permute_dims(a, axes=(1, 0, 2))
     >>> at[:]
     array([[[ 1,  2,  3,  4],
@@ -3960,37 +3961,39 @@ def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwargs: Any) -> NDArray:
           [[ 9, 10, 11, 12],
            [21, 22, 23, 24]]])
     """
-
     if np.isscalar(arr) or arr.ndim < 2:
         return arr

+    ndim = arr.ndim
+
     if axes is None:
-        axes = range(arr.ndim)[::-1]
+        axes = tuple(range(ndim))[::-1]
     else:
-        axes_transformed = tuple(axis if axis >= 0 else arr.ndim + axis for axis in axes)
-        if sorted(axes_transformed) != list(range(arr.ndim)):
-            raise ValueError(f"axes {axes} is not a valid permutation of {arr.ndim} dimensions")
-        axes = axes_transformed
+        axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
+        if sorted(axes) != list(range(ndim)):
+            raise ValueError(f"axes {axes} is not a valid permutation of {ndim} dimensions")

     new_shape = tuple(arr.shape[axis] for axis in axes)
-    if "chunks" not in kwargs:
+    if "chunks" not in kwargs or kwargs["chunks"] is None:
         kwargs["chunks"] = tuple(arr.chunks[axis] for axis in axes)

     result = blosc2.empty(shape=new_shape, dtype=arr.dtype, **kwargs)

-    chunk_slices = [
-        [slice(start, builtins.min(dim, start + chunk)) for start in range(0, dim, chunk)]
-        for dim, chunk in zip(arr.shape, arr.chunks, strict=False)
-    ]
+    chunks = arr.chunks
+    shape = arr.shape
+
+    for info in arr.iterchunks_info():
+        coords = info.coords
+        start_stop = [
+            (coord * chunk, builtins.min(chunk * (coord + 1), dim))
+            for coord, chunk, dim in zip(coords, chunks, shape, strict=False)
+        ]

-    block_counts = [len(s) for s in chunk_slices]
-    grid = np.indices(block_counts).reshape(len(block_counts), -1).T
+        src_slice = tuple(slice(start, stop) for start, stop in start_stop)
+        dst_slice = tuple(slice(start_stop[ax][0], start_stop[ax][1]) for ax in axes)

-    for idx in grid:
-        block_slices = tuple(chunk_slices[axis][i] for axis, i in enumerate(idx))
-        block = arr[block_slices]
-        target_slices = tuple(block_slices[axis] for axis in axes)
-        result[target_slices] = np.transpose(block, axes=axes).copy()
+        transposed = np.transpose(arr[src_slice], axes=axes)
+        result[dst_slice] = np.ascontiguousarray(transposed)

     return result

@@ -4024,7 +4027,7 @@ def transpose(x, **kwargs: Any) -> NDArray:
         stacklevel=2,
     )

-    # If arguments are dimension < 2 they are returned
+    # If arguments are dimension < 2, they are returned
     if np.isscalar(x) or x.ndim < 2:
         return x