
Commit 5e53537

cetagostini authored and jessegrabowski committed
improving benchmark
1 parent a43f1cf commit 5e53537

File tree

1 file changed: +81 -137 lines


doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb

Lines changed: 81 additions & 137 deletions
@@ -2,7 +2,7 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -21,7 +21,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -30,14 +30,14 @@
 "\n",
 "# Set up PyTensor JAX mode\n",
 "jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n",
-"pytensor_jax_mode = Mode(linker=JAXLinker(), optimizer=jax_optimizer)\n",
+"pytensor_jax_mode = \"JAX\"\n",
 "\n",
 "# Try to set up MLX mode\n",
 "try:\n",
 " from pytensor.link.mlx import MLXLinker\n",
 " import mlx.core as mx\n",
 " mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n",
-" pytensor_mlx_mode = Mode(linker=MLXLinker(), optimizer=mlx_optimizer)\n",
+" pytensor_mlx_mode = \"MLX\"\n",
 " MLX_AVAILABLE = True\n",
 "except ImportError:\n",
 " MLX_AVAILABLE = False\n",
@@ -101,29 +101,28 @@
 " A = np.random.randn(size, size).astype(np.float32)\n",
 " B = np.random.randn(size, size).astype(np.float32)\n",
 " C = np.random.randn(size, size).astype(np.float32)\n",
+"\n",
+" pt_A = pt.matrix('A', dtype='float32')\n",
+" pt_B = pt.matrix('B', dtype='float32') \n",
+" pt_C = pt.matrix('C', dtype='float32')\n",
+" result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n",
+"\n",
+"\n",
+" f_jax = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode, trust_input=True)\n",
+" f_mlx = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode, trust_input=True)\n",
 " \n",
 " # === TEST 1: Matrix Multiplication Chain ===\n",
 " # PyTensor + JAX backend\n",
 " @timer_jax\n",
 " def pytensor_jax_matmul():\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32') \n",
-" pt_C = pt.matrix('C', dtype='float32')\n",
-" result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n",
-" f = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode)\n",
-" return f(A, B, C)\n",
+" return f_jax(A, B, C)\n",
 " \n",
 " # PyTensor + MLX backend\n",
 " @timer_mlx\n",
 " def pytensor_mlx_matmul():\n",
 " if not MLX_AVAILABLE:\n",
 " return None, float('inf'), 0\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32')\n",
-" pt_C = pt.matrix('C', dtype='float32')\n",
-" result = pt_A @ pt_B @ pt_C\n",
-" f = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode)\n",
-" return f(A, B, C)\n",
+" return f_mlx(A, B, C)\n",
 " \n",
 " # Run matrix multiplication test\n",
 " _, jax_mean, jax_std = pytensor_jax_matmul()\n",
@@ -145,24 +144,20 @@
 " \n",
 " # === TEST 2: Element-wise Operations ===\n",
 " # PyTensor + JAX\n",
+" result = pt.sin(pt_A) + pt.cos(pt_B)\n",
+" f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
+" f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
+"\n",
 " @timer_jax\n",
 " def pytensor_jax_elemwise():\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32')\n",
-" result = pt.sin(pt_A) + pt.cos(pt_B)\n",
-" f = function([pt_A, pt_B], result, mode=pytensor_jax_mode)\n",
-" return f(A, B)\n",
+" return f_jax(A, B)\n",
 " \n",
 " # PyTensor + MLX\n",
 " @timer_mlx\n",
 " def pytensor_mlx_elemwise():\n",
 " if not MLX_AVAILABLE:\n",
 " return None, float('inf'), 0\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32')\n",
-" result = pt.sin(pt_A) + pt.cos(pt_B)\n",
-" f = function([pt_A, pt_B], result, mode=pytensor_mlx_mode)\n",
-" return f(A, B)\n",
+" return f_mlx(A, B)\n",
 " \n",
 " # Run element-wise test\n",
 " _, jax_mean, jax_std = pytensor_jax_elemwise()\n",
@@ -184,24 +179,19 @@
 " \n",
 " # === TEST 3: Matrix Addition with Broadcasting ===\n",
 " # PyTensor + JAX\n",
+" result = pt_A + pt_B.T\n",
+" f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
+" f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
 " @timer_jax\n",
 " def pytensor_jax_broadcast():\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32')\n",
-" result = pt_A + pt_B.T\n",
-" f = function([pt_A, pt_B], result, mode=pytensor_jax_mode)\n",
-" return f(A, B)\n",
+" return f_jax(A, B)\n",
 " \n",
 " # PyTensor + MLX\n",
 " @timer_mlx\n",
 " def pytensor_mlx_broadcast():\n",
 " if not MLX_AVAILABLE:\n",
 " return None, float('inf'), 0\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32')\n",
-" result = pt_A + pt_B.T\n",
-" f = function([pt_A, pt_B], result, mode=pytensor_mlx_mode)\n",
-" return f(A, B)\n",
+" return f_mlx(A, B)\n",
 " \n",
 " # Run broadcasting test\n",
 " _, jax_mean, jax_std = pytensor_jax_broadcast()\n",
@@ -225,49 +215,6 @@
 " df = pd.DataFrame(results)\n",
 " return df\n",
 "\n",
-"def verify_computation_correctness():\n",
-" \"\"\"Verify that JAX and MLX backends produce the same results\"\"\"\n",
-" if not MLX_AVAILABLE:\n",
-" print(\"MLX not available, skipping correctness check\")\n",
-" return\n",
-" \n",
-" print(\"Verifying computational correctness...\")\n",
-" \n",
-" # Test with small matrices\n",
-" np.random.seed(42)\n",
-" A = np.random.randn(4, 4).astype(np.float32)\n",
-" B = np.random.randn(4, 4).astype(np.float32)\n",
-" C = np.random.randn(4, 4).astype(np.float32)\n",
-" \n",
-" # Test matrix multiplication\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32')\n",
-" pt_C = pt.matrix('C', dtype='float32')\n",
-" result_expr = pt_A @ pt_B @ pt_C\n",
-" \n",
-" f_jax = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_jax_mode)\n",
-" f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
-" \n",
-" result_jax = f_jax(A, B, C)\n",
-" result_mlx = f_mlx(A, B, C)\n",
-" \n",
-" # Force MLX evaluation\n",
-" mx.eval(result_mlx)\n",
-" \n",
-" # Convert to numpy for comparison\n",
-" if hasattr(result_jax, 'block_until_ready'):\n",
-" result_jax.block_until_ready()\n",
-" \n",
-" diff = np.abs(np.array(result_jax) - np.array(result_mlx)).max()\n",
-" print(f\"Max difference between JAX and MLX results: {diff:.2e}\")\n",
-" \n",
-" if diff < 1e-5:\n",
-" print(\"✅ Results match within tolerance\")\n",
-" else:\n",
-" print(\"❌ Results differ significantly\")\n",
-" \n",
-" return diff\n",
-"\n",
 "def main(N=1000):\n",
 " \"\"\"Main benchmark execution\"\"\"\n",
 " # Display system info\n",
@@ -285,9 +232,6 @@
 " import pandas as pd\n",
 " info_df = pd.DataFrame([system_info])\n",
 " \n",
-" # First verify correctness\n",
-" verify_computation_correctness()\n",
-" \n",
 " # Then run benchmarks\n",
 " results_df = run_benchmark(N=N)\n",
 " \n",
@@ -296,50 +240,50 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 10,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Verifying computational correctness...\n",
-"Max difference between JAX and MLX results: 0.00e+00\n",
-"✅ Results match within tolerance\n",
-"Running benchmarks with N=20 repetitions per test...\n",
-"Testing 128x128 matrices...\n"
+"Running benchmarks with N=100 repetitions per test...\n",
+"Testing 128x128 matrices...\n",
+"Testing 256x256 matrices...\n",
+"Testing 512x512 matrices...\n",
+"Testing 1024x1024 matrices...\n"
 ]
 }
 ],
 "source": [
-"iteration=20\n",
+"iteration=100\n",
 "_, results = main(N=iteration)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 27,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "\n",
-"Benchmark Results over 1000 repetitions:\n",
+"Benchmark Results over 100 repetitions:\n",
 " Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Speedup\n",
-" 128x128 Matrix Chain (A @ B @ C) 0.005700 0.002127 0.001215 0.000497 4.69x\n",
-" 128x128 Element-wise (sin(A) + cos(B)) 0.008280 0.002158 0.000876 0.000451 9.45x\n",
-" 128x128 Broadcasting (A + B.T) 0.008083 0.002485 0.000861 0.000207 9.39x\n",
-" 256x256 Matrix Chain (A @ B @ C) 0.005705 0.002307 0.001085 0.000210 5.26x\n",
-" 256x256 Element-wise (sin(A) + cos(B)) 0.009794 0.001994 0.000998 0.001895 9.82x\n",
-" 256x256 Broadcasting (A + B.T) 0.010467 0.002573 0.001056 0.000578 9.91x\n",
-" 512x512 Matrix Chain (A @ B @ C) 0.006898 0.002576 0.001300 0.000391 5.31x\n",
-" 512x512 Element-wise (sin(A) + cos(B)) 0.010997 0.002435 0.000976 0.000584 11.27x\n",
-" 512x512 Broadcasting (A + B.T) 0.009730 0.002690 0.000968 0.000315 10.05x\n",
-"1024x1024 Matrix Chain (A @ B @ C) 0.010941 0.002035 0.001735 0.000302 6.31x\n",
-"1024x1024 Element-wise (sin(A) + cos(B)) 0.013936 0.003774 0.001103 0.000253 12.64x\n",
-"1024x1024 Broadcasting (A + B.T) 0.011153 0.002297 0.001084 0.000242 10.29x\n"
+" 128x128 Matrix Chain (A @ B @ C) 0.000131 0.000300 0.000283 0.000216 0.46x\n",
+" 128x128 Element-wise (sin(A) + cos(B)) 0.000104 0.000304 0.000209 0.000145 0.50x\n",
+" 128x128 Broadcasting (A + B.T) 0.000037 0.000296 0.000215 0.000153 0.17x\n",
+" 256x256 Matrix Chain (A @ B @ C) 0.000394 0.000372 0.000441 0.000239 0.89x\n",
+" 256x256 Element-wise (sin(A) + cos(B)) 0.000247 0.000389 0.000255 0.000168 0.97x\n",
+" 256x256 Broadcasting (A + B.T) 0.000063 0.000329 0.000217 0.000153 0.29x\n",
+" 512x512 Matrix Chain (A @ B @ C) 0.001004 0.000255 0.000399 0.000188 2.51x\n",
+" 512x512 Element-wise (sin(A) + cos(B)) 0.000664 0.000328 0.000263 0.000163 2.53x\n",
+" 512x512 Broadcasting (A + B.T) 0.000115 0.000339 0.000254 0.000156 0.45x\n",
+"1024x1024 Matrix Chain (A @ B @ C) 0.005281 0.000359 0.000993 0.000342 5.32x\n",
+"1024x1024 Element-wise (sin(A) + cos(B)) 0.002595 0.000359 0.000408 0.000220 6.36x\n",
+"1024x1024 Broadcasting (A + B.T) 0.000501 0.000346 0.000385 0.000155 1.30x\n"
 ]
 }
 ],
@@ -367,46 +311,46 @@
 }
 ],
 "source": [
-"# Additional timing analysis - separate compilation vs execution time\n",
-"if MLX_AVAILABLE:\n",
-" print(\"\\n=== Detailed MLX Timing Analysis ===\")\n",
+"# # Additional timing analysis - separate compilation vs execution time\n",
+"# if MLX_AVAILABLE:\n",
+"# print(\"\\n=== Detailed MLX Timing Analysis ===\")\n",
 " \n",
-" # Test with medium-sized matrix\n",
-" np.random.seed(42)\n",
-" A = np.random.randn(512, 512).astype(np.float32)\n",
-" B = np.random.randn(512, 512).astype(np.float32)\n",
-" C = np.random.randn(512, 512).astype(np.float32)\n",
+"# # Test with medium-sized matrix\n",
+"# np.random.seed(42)\n",
+"# A = np.random.randn(512, 512).astype(np.float32)\n",
+"# B = np.random.randn(512, 512).astype(np.float32)\n",
+"# C = np.random.randn(512, 512).astype(np.float32)\n",
 " \n",
-" # Create PyTensor function (compilation time)\n",
-" start = time.perf_counter()\n",
-" pt_A = pt.matrix('A', dtype='float32')\n",
-" pt_B = pt.matrix('B', dtype='float32')\n",
-" pt_C = pt.matrix('C', dtype='float32')\n",
-" result_expr = pt_A @ pt_B @ pt_C\n",
-" f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
-" compilation_time = time.perf_counter() - start\n",
+"# # Create PyTensor function (compilation time)\n",
+"# start = time.perf_counter()\n",
+"# pt_A = pt.matrix('A', dtype='float32')\n",
+"# pt_B = pt.matrix('B', dtype='float32')\n",
+"# pt_C = pt.matrix('C', dtype='float32')\n",
+"# result_expr = pt_A @ pt_B @ pt_C\n",
+"# f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
+"# compilation_time = time.perf_counter() - start\n",
 " \n",
-" # First execution (may include additional compilation/optimization)\n",
-" start = time.perf_counter()\n",
-" result = f_mlx(A, B, C)\n",
-" mx.eval(result) # Force evaluation\n",
-" first_exec_time = time.perf_counter() - start\n",
+"# # First execution (may include additional compilation/optimization)\n",
+"# start = time.perf_counter()\n",
+"# result = f_mlx(A, B, C)\n",
+"# mx.eval(result) # Force evaluation\n",
+"# first_exec_time = time.perf_counter() - start\n",
 " \n",
-" # Subsequent executions (should be faster)\n",
-" exec_times = []\n",
-" for _ in range(1000):\n",
-" start = time.perf_counter()\n",
-" result = f_mlx(A, B, C)\n",
-" mx.eval(result)\n",
-" exec_times.append(time.perf_counter() - start)\n",
+"# # Subsequent executions (should be faster)\n",
+"# exec_times = []\n",
+"# for _ in range(1000):\n",
+"# start = time.perf_counter()\n",
+"# result = f_mlx(A, B, C)\n",
+"# mx.eval(result)\n",
+"# exec_times.append(time.perf_counter() - start)\n",
 " \n",
-" avg_exec_time = np.mean(exec_times)\n",
-" std_exec_time = np.std(exec_times)\n",
+"# avg_exec_time = np.mean(exec_times)\n",
+"# std_exec_time = np.std(exec_times)\n",
 " \n",
-" print(f\"Compilation time: {compilation_time:.4f}s\")\n",
-" print(f\"First execution: {first_exec_time:.4f}s\")\n",
-" print(f\"Average execution (5 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n",
-" print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n"
+"# print(f\"Compilation time: {compilation_time:.4f}s\")\n",
+"# print(f\"First execution: {first_exec_time:.4f}s\")\n",
+"# print(f\"Average execution (5 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n",
+"# print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n"
 ]
 },
 {