1010
1111import numpy as np
1212from time import time
13+ import numba
1314
1415import blosc2
1516
16- N = 5_000 # working size of ~200 MB
17+ N = 30_000 # working size is N * N * 4 * 2 bytes = 7.2 GB
1718
1819# Create some sample data
1920t0 = time ()
2930print (f"Time to create data (NDArray): { time () - t0 :.3f} s" )
3031#print("a.chunks: ", a.chunks, "a.blocks: ", a.blocks)
3132
32- # Compare with NumPy
33+ # Take NumPy as reference
3334def expr_numpy (a , b , c ):
3435 # return np.cumsum(((na**3 + np.sin(na * 2)) < nc) & (nb > 0), axis=0)
3536 # The next is equally illustrative, but can achieve better speedups
@@ -40,18 +41,93 @@ def expr_jit(a, b, c):
4041 # return np.cumsum(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=0)
4142 return np .sum (((a ** 3 + np .sin (a * 2 )) < np .cumulative_sum (c )) & (b > 0 ), axis = 1 )
4243
44+ @numba .jit
45+ def expr_numba (a , b , c ):
46+ # numba fails with the next with:
47+ # """No implementation of function Function(<function cumsum at 0x101a30720>) found for signature:
48+ # >>> cumsum(array(bool, 2d, C), axis=Literal[int](0))"""
49+ # return np.cumsum(((a**3 + np.sin(a * 2)) < c) & (b > 0), axis=0)
50+ # The np.cumulative_sum() is not supported yet by numba
51+ # return np.sum(((a**3 + np.sin(a * 2)) < np.cumulative_sum(c)) & (b > 0), axis=1)
52+ return np .sum (((a ** 3 + np .sin (a * 2 )) < np .cumsum (c )) & (b > 0 ), axis = 1 )
53+
54+ times = []
4355# Call the NumPy function natively on NumPy containers
4456t0 = time ()
4557result = expr_numpy (a , b , c )
4658tref = time () - t0
59+ times .append (tref )
4760print (f"Time for native NumPy: { tref :.3f} s" )
4861
49- # Call the function with the jit decorator, using NumPy containers
62+ # Call the function with the blosc2. jit decorator, using NumPy containers
5063t0 = time ()
5164result = expr_jit (na , nb , nc )
52- print (f"Time for blosc2.jit (np.ndarray): { time () - t0 :.3f} s, speedup: { tref / (time () - t0 ):.2f} x" )
65+ times .append (time () - t0 )
66+ print (f"Time for blosc2.jit (np.ndarray): { times [- 1 ]:.3f} s, speedup: { tref / times [- 1 ]:.2f} x" )
5367
54- # Call the function with the jit decorator, using Blosc2 containers
68+ # Call the function with the blosc2. jit decorator, using Blosc2 containers
5569t0 = time ()
5670result = expr_jit (a , b , c )
57- print (f"Time for blosc2.jit (NDArray): { time () - t0 :.3f} s, speedup: { tref / (time () - t0 ):.2f} x" )
71+ times .append (time () - t0 )
72+ print (f"Time for blosc2.jit (blosc2.NDArray): { times [- 1 ]:.3f} s, speedup: { tref / times [- 1 ]:.2f} x" )
73+
74+ # Call the function with the jit decorator, using NumPy containers
75+ t0 = time ()
76+ result = expr_numba (na , nb , nc )
77+ times .append (time () - t0 )
78+ print (f"Time for numba.jit (np.ndarray, first run): { times [- 1 ]:.3f} s, speedup: { tref / times [- 1 ]:.2f} x" )
79+ t0 = time ()
80+ result = expr_numba (na , nb , nc )
81+ times .append (time () - t0 )
82+ print (f"Time for numba.jit (np.ndarray): { times [- 1 ]:.3f} s, speedup: { tref / times [- 1 ]:.2f} x" )
83+
84+
85+ # Plot the results using an horizontal bar chart
86+ import matplotlib .pyplot as plt
87+
88+ labels = ['NumPy' , 'blosc2.jit (np.ndarray)' , 'blosc2.jit (blosc2.NDArray)' , 'numba.jit (first run)' , 'numba.jit (cached)' ]
89+ # Reverse the labels and times arrays
90+ labels_rev = labels [::- 1 ]
91+ times_rev = times [::- 1 ]
92+
93+ # Create position indices for the reversed data
94+ x = np .arange (len (labels_rev ))
95+
96+ fig , ax = plt .subplots (figsize = (10 , 6 ))
97+
98+ # Define colors for different categories
99+ colors = ['#FF9999' , '#66B2FF' , '#66B2FF' , '#99CC99' , '#99CC99' ] # Red for NumPy, Blue for blosc2, Green for numba
100+ # Note: colors are in reverse order to match the reversed data
101+ colors_rev = colors [::- 1 ]
102+
103+ bars = ax .barh (x , times_rev , height = 0.35 , color = colors_rev , label = 'Time (s)' )
104+
105+ # Add speedup annotations at the end of each bar
106+ # NumPy is our reference (the first element in original array, last in reversed)
107+ numpy_time = tref # Reference time for NumPy
108+ for i , (bar , time ) in enumerate (zip (bars , times_rev )):
109+ # Skip the NumPy bar since it's our reference
110+ if i < len (times_rev ) - 1 : # Skip the last bar (NumPy)
111+ speedup = numpy_time / time
112+ ax .annotate (f'({ speedup :.1f} x)' ,
113+ (bar .get_width () + 0.05 , bar .get_y () + bar .get_height ()/ 2 ),
114+ va = 'center' )
115+
116+ ax .set_xlabel ('Time (s)' )
117+ ax .set_title ("""Compute: np.sum(((a**3 + np.sin(a * 2)) < np.cumsum(c)) & (b > 0), axis=1)
118+ (Execution time for different decorators)""" )
119+ ax .set_yticks (x )
120+ ax .set_yticklabels (labels_rev )
121+
122+ # Create custom legend with only one entry per category
123+ from matplotlib .patches import Patch
124+ legend_elements = [
125+ Patch (facecolor = '#FF9999' , label = 'NumPy' ),
126+ Patch (facecolor = '#66B2FF' , label = 'blosc2.jit' ),
127+ Patch (facecolor = '#99CC99' , label = 'numba.jit' )
128+ ]
129+ ax .legend (handles = legend_elements , loc = 'best' )
130+
131+ plt .tight_layout ()
132+ plt .savefig ('jit_benchmark_comparison.png' , dpi = 300 , bbox_inches = 'tight' )
133+ plt .show ()
0 commit comments