Skip to content

Commit 5a54cc1

Browse files
committed
Bench numba and make a performance plot
1 parent 1a39739 commit 5a54cc1

File tree

1 file changed

+82
-6
lines changed

1 file changed

+82
-6
lines changed

bench/ndarray/jit-numpy-funcs.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010

1111
import numpy as np
1212
from time import time
13+
import numba
1314

1415
import blosc2
1516

16-
N = 5_000 # working size of ~200 MB
17+
N = 30_000 # working size is N * N * 4 * 2 bytes = 7.2 GB
1718

1819
# Create some sample data
1920
t0 = time()
@@ -29,7 +30,7 @@
2930
print(f"Time to create data (NDArray): {time() - t0:.3f} s")
3031
#print("a.chunks: ", a.chunks, "a.blocks: ", a.blocks)
3132

32-
# Compare with NumPy
33+
# Take NumPy as reference
3334
def expr_numpy(a, b, c):
3435
# return np.cumsum(((na**3 + np.sin(na * 2)) < nc) & (nb > 0), axis=0)
3536
# The next is equally illustrative, but can achieve better speedups
@@ -40,18 +41,93 @@ def expr_jit(a, b, c):
4041
# return np.cumsum(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=0)
4142
return np.sum(((a**3 + np.sin(a * 2)) < np.cumulative_sum(c)) & (b > 0), axis=1)
4243

44+
@numba.jit
45+
def expr_numba(a, b, c):
46+
# numba fails with the next with:
47+
# """No implementation of function Function(<function cumsum at 0x101a30720>) found for signature:
48+
# >>> cumsum(array(bool, 2d, C), axis=Literal[int](0))"""
49+
# return np.cumsum(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=0)
50+
# The np.cumulative_sum() is not supported yet by numba
51+
# return np.sum(((a**3 + np.sin(a * 2)) < np.cumulative_sum(c)) & (b > 0), axis=1)
52+
return np.sum(((a**3 + np.sin(a * 2)) < np.cumsum(c)) & (b > 0), axis=1)
53+
54+
times = []
4355
# Call the NumPy function natively on NumPy containers
4456
t0 = time()
4557
result = expr_numpy(a, b, c)
4658
tref = time() - t0
59+
times.append(tref)
4760
print(f"Time for native NumPy: {tref:.3f} s")
4861

49-
# Call the function with the jit decorator, using NumPy containers
62+
# Call the function with the blosc2.jit decorator, using NumPy containers
5063
t0 = time()
5164
result = expr_jit(na, nb, nc)
52-
print(f"Time for blosc2.jit (np.ndarray): {time() - t0:.3f} s, speedup: {tref / (time() - t0):.2f}x")
65+
times.append(time() - t0)
66+
print(f"Time for blosc2.jit (np.ndarray): {times[-1]:.3f} s, speedup: {tref / times[-1]:.2f}x")
5367

54-
# Call the function with the jit decorator, using Blosc2 containers
68+
# Call the function with the blosc2.jit decorator, using Blosc2 containers
5569
t0 = time()
5670
result = expr_jit(a, b, c)
57-
print(f"Time for blosc2.jit (NDArray): {time() - t0:.3f} s, speedup: {tref / (time() - t0):.2f}x")
71+
times.append(time() - t0)
72+
print(f"Time for blosc2.jit (blosc2.NDArray): {times[-1]:.3f} s, speedup: {tref / times[-1]:.2f}x")
73+
74+
# Call the function with the jit decorator, using NumPy containers
75+
t0 = time()
76+
result = expr_numba(na, nb, nc)
77+
times.append(time() - t0)
78+
print(f"Time for numba.jit (np.ndarray, first run): {times[-1]:.3f} s, speedup: {tref / times[-1]:.2f}x")
79+
t0 = time()
80+
result = expr_numba(na, nb, nc)
81+
times.append(time() - t0)
82+
print(f"Time for numba.jit (np.ndarray): {times[-1]:.3f} s, speedup: {tref / times[-1]:.2f}x")
83+
84+
85+
# Plot the results using an horizontal bar chart
86+
import matplotlib.pyplot as plt
87+
88+
labels = ['NumPy', 'blosc2.jit (np.ndarray)', 'blosc2.jit (blosc2.NDArray)', 'numba.jit (first run)', 'numba.jit (cached)']
89+
# Reverse the labels and times arrays
90+
labels_rev = labels[::-1]
91+
times_rev = times[::-1]
92+
93+
# Create position indices for the reversed data
94+
x = np.arange(len(labels_rev))
95+
96+
fig, ax = plt.subplots(figsize=(10, 6))
97+
98+
# Define colors for different categories
99+
colors = ['#FF9999', '#66B2FF', '#66B2FF', '#99CC99', '#99CC99'] # Red for NumPy, Blue for blosc2, Green for numba
100+
# Note: colors are in reverse order to match the reversed data
101+
colors_rev = colors[::-1]
102+
103+
bars = ax.barh(x, times_rev, height=0.35, color=colors_rev, label='Time (s)')
104+
105+
# Add speedup annotations at the end of each bar
106+
# NumPy is our reference (the first element in original array, last in reversed)
107+
numpy_time = tref # Reference time for NumPy
108+
for i, (bar, time) in enumerate(zip(bars, times_rev)):
109+
# Skip the NumPy bar since it's our reference
110+
if i < len(times_rev) - 1: # Skip the last bar (NumPy)
111+
speedup = numpy_time / time
112+
ax.annotate(f'({speedup:.1f}x)',
113+
(bar.get_width() + 0.05, bar.get_y() + bar.get_height()/2),
114+
va='center')
115+
116+
ax.set_xlabel('Time (s)')
117+
ax.set_title("""Compute: np.sum(((a**3 + np.sin(a * 2)) < np.cumsum(c)) & (b > 0), axis=1)
118+
(Execution time for different decorators)""")
119+
ax.set_yticks(x)
120+
ax.set_yticklabels(labels_rev)
121+
122+
# Create custom legend with only one entry per category
123+
from matplotlib.patches import Patch
124+
legend_elements = [
125+
Patch(facecolor='#FF9999', label='NumPy'),
126+
Patch(facecolor='#66B2FF', label='blosc2.jit'),
127+
Patch(facecolor='#99CC99', label='numba.jit')
128+
]
129+
ax.legend(handles=legend_elements, loc='best')
130+
131+
plt.tight_layout()
132+
plt.savefig('jit_benchmark_comparison.png', dpi=300, bbox_inches='tight')
133+
plt.show()

0 commit comments

Comments
 (0)