2 | 2 | "cells": [
3 | 3 | {
4 | 4 | "cell_type": "code",
5 | | - "execution_count": 1,
| 5 | + "execution_count": 5,
6 | 6 | "metadata": {},
7 | 7 | "outputs": [],
8 | 8 | "source": [

21 | 21 | },
22 | 22 | {
23 | 23 | "cell_type": "code",
24 | | - "execution_count": 2,
| 24 | + "execution_count": 6,
25 | 25 | "metadata": {},
26 | 26 | "outputs": [],
27 | 27 | "source": [

30 | 30 | "\n",
31 | 31 | "# Set up PyTensor JAX mode\n",
32 | 32 | "jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n",
33 | | - "pytensor_jax_mode = Mode(linker=JAXLinker(), optimizer=jax_optimizer)\n",
| 33 | + "pytensor_jax_mode = \"JAX\"\n",
34 | 34 | "\n",
35 | 35 | "# Try to set up MLX mode\n",
36 | 36 | "try:\n",
37 | 37 | " from pytensor.link.mlx import MLXLinker\n",
38 | 38 | " import mlx.core as mx\n",
39 | 39 | " mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n",
40 | | - " pytensor_mlx_mode = Mode(linker=MLXLinker(), optimizer=mlx_optimizer)\n",
| 40 | + " pytensor_mlx_mode = \"MLX\"\n",
41 | 41 | " MLX_AVAILABLE = True\n",
42 | 42 | "except ImportError:\n",
43 | 43 | " MLX_AVAILABLE = False\n",

101 | 101 | " A = np.random.randn(size, size).astype(np.float32)\n",
102 | 102 | " B = np.random.randn(size, size).astype(np.float32)\n",
103 | 103 | " C = np.random.randn(size, size).astype(np.float32)\n",
| 104 | + "\n",
| 105 | + " pt_A = pt.matrix('A', dtype='float32')\n",
| 106 | + " pt_B = pt.matrix('B', dtype='float32') \n",
| 107 | + " pt_C = pt.matrix('C', dtype='float32')\n",
| 108 | + " result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n",
| 109 | + "\n",
| 110 | + "\n",
| 111 | + " f_jax = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode, trust_input=True)\n",
| 112 | + " f_mlx = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode, trust_input=True)\n",
104 | 113 | " \n",
105 | 114 | " # === TEST 1: Matrix Multiplication Chain ===\n",
106 | 115 | " # PyTensor + JAX backend\n",
107 | 116 | " @timer_jax\n",
108 | 117 | " def pytensor_jax_matmul():\n",
109 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
110 | | - " pt_B = pt.matrix('B', dtype='float32') \n",
111 | | - " pt_C = pt.matrix('C', dtype='float32')\n",
112 | | - " result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n",
113 | | - " f = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode)\n",
114 | | - " return f(A, B, C)\n",
| 118 | + " return f_jax(A, B, C)\n",
115 | 119 | " \n",
116 | 120 | " # PyTensor + MLX backend\n",
117 | 121 | " @timer_mlx\n",
118 | 122 | " def pytensor_mlx_matmul():\n",
119 | 123 | " if not MLX_AVAILABLE:\n",
120 | 124 | " return None, float('inf'), 0\n",
121 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
122 | | - " pt_B = pt.matrix('B', dtype='float32')\n",
123 | | - " pt_C = pt.matrix('C', dtype='float32')\n",
124 | | - " result = pt_A @ pt_B @ pt_C\n",
125 | | - " f = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode)\n",
126 | | - " return f(A, B, C)\n",
| 125 | + " return f_mlx(A, B, C)\n",
127 | 126 | " \n",
128 | 127 | " # Run matrix multiplication test\n",
129 | 128 | " _, jax_mean, jax_std = pytensor_jax_matmul()\n",

145 | 144 | " \n",
146 | 145 | " # === TEST 2: Element-wise Operations ===\n",
147 | 146 | " # PyTensor + JAX\n",
| 147 | + " result = pt.sin(pt_A) + pt.cos(pt_B)\n",
| 148 | + " f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
| 149 | + " f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
| 150 | + "\n",
148 | 151 | " @timer_jax\n",
149 | 152 | " def pytensor_jax_elemwise():\n",
150 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
151 | | - " pt_B = pt.matrix('B', dtype='float32')\n",
152 | | - " result = pt.sin(pt_A) + pt.cos(pt_B)\n",
153 | | - " f = function([pt_A, pt_B], result, mode=pytensor_jax_mode)\n",
154 | | - " return f(A, B)\n",
| 153 | + " return f_jax(A, B)\n",
155 | 154 | " \n",
156 | 155 | " # PyTensor + MLX\n",
157 | 156 | " @timer_mlx\n",
158 | 157 | " def pytensor_mlx_elemwise():\n",
159 | 158 | " if not MLX_AVAILABLE:\n",
160 | 159 | " return None, float('inf'), 0\n",
161 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
162 | | - " pt_B = pt.matrix('B', dtype='float32')\n",
163 | | - " result = pt.sin(pt_A) + pt.cos(pt_B)\n",
164 | | - " f = function([pt_A, pt_B], result, mode=pytensor_mlx_mode)\n",
165 | | - " return f(A, B)\n",
| 160 | + " return f_mlx(A, B)\n",
166 | 161 | " \n",
167 | 162 | " # Run element-wise test\n",
168 | 163 | " _, jax_mean, jax_std = pytensor_jax_elemwise()\n",

184 | 179 | " \n",
185 | 180 | " # === TEST 3: Matrix Addition with Broadcasting ===\n",
186 | 181 | " # PyTensor + JAX\n",
| 182 | + " result = pt_A + pt_B.T\n",
| 183 | + " f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
| 184 | + " f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
187 | 185 | " @timer_jax\n",
188 | 186 | " def pytensor_jax_broadcast():\n",
189 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
190 | | - " pt_B = pt.matrix('B', dtype='float32')\n",
191 | | - " result = pt_A + pt_B.T\n",
192 | | - " f = function([pt_A, pt_B], result, mode=pytensor_jax_mode)\n",
193 | | - " return f(A, B)\n",
| 187 | + " return f_jax(A, B)\n",
194 | 188 | " \n",
195 | 189 | " # PyTensor + MLX\n",
196 | 190 | " @timer_mlx\n",
197 | 191 | " def pytensor_mlx_broadcast():\n",
198 | 192 | " if not MLX_AVAILABLE:\n",
199 | 193 | " return None, float('inf'), 0\n",
200 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
201 | | - " pt_B = pt.matrix('B', dtype='float32')\n",
202 | | - " result = pt_A + pt_B.T\n",
203 | | - " f = function([pt_A, pt_B], result, mode=pytensor_mlx_mode)\n",
204 | | - " return f(A, B)\n",
| 194 | + " return f_mlx(A, B)\n",
205 | 195 | " \n",
206 | 196 | " # Run broadcasting test\n",
207 | 197 | " _, jax_mean, jax_std = pytensor_jax_broadcast()\n",

225 | 215 | " df = pd.DataFrame(results)\n",
226 | 216 | " return df\n",
227 | 217 | "\n",
228 | | - "def verify_computation_correctness():\n",
229 | | - " \"\"\"Verify that JAX and MLX backends produce the same results\"\"\"\n",
230 | | - " if not MLX_AVAILABLE:\n",
231 | | - " print(\"MLX not available, skipping correctness check\")\n",
232 | | - " return\n",
233 | | - " \n",
234 | | - " print(\"Verifying computational correctness...\")\n",
235 | | - " \n",
236 | | - " # Test with small matrices\n",
237 | | - " np.random.seed(42)\n",
238 | | - " A = np.random.randn(4, 4).astype(np.float32)\n",
239 | | - " B = np.random.randn(4, 4).astype(np.float32)\n",
240 | | - " C = np.random.randn(4, 4).astype(np.float32)\n",
241 | | - " \n",
242 | | - " # Test matrix multiplication\n",
243 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
244 | | - " pt_B = pt.matrix('B', dtype='float32')\n",
245 | | - " pt_C = pt.matrix('C', dtype='float32')\n",
246 | | - " result_expr = pt_A @ pt_B @ pt_C\n",
247 | | - " \n",
248 | | - " f_jax = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_jax_mode)\n",
249 | | - " f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
250 | | - " \n",
251 | | - " result_jax = f_jax(A, B, C)\n",
252 | | - " result_mlx = f_mlx(A, B, C)\n",
253 | | - " \n",
254 | | - " # Force MLX evaluation\n",
255 | | - " mx.eval(result_mlx)\n",
256 | | - " \n",
257 | | - " # Convert to numpy for comparison\n",
258 | | - " if hasattr(result_jax, 'block_until_ready'):\n",
259 | | - " result_jax.block_until_ready()\n",
260 | | - " \n",
261 | | - " diff = np.abs(np.array(result_jax) - np.array(result_mlx)).max()\n",
262 | | - " print(f\"Max difference between JAX and MLX results: {diff:.2e}\")\n",
263 | | - " \n",
264 | | - " if diff < 1e-5:\n",
265 | | - " print(\"✅ Results match within tolerance\")\n",
266 | | - " else:\n",
267 | | - " print(\"❌ Results differ significantly\")\n",
268 | | - " \n",
269 | | - " return diff\n",
270 | | - "\n",
271 | 218 | "def main(N=1000):\n",
272 | 219 | " \"\"\"Main benchmark execution\"\"\"\n",
273 | 220 | " # Display system info\n",

285 | 232 | " import pandas as pd\n",
286 | 233 | " info_df = pd.DataFrame([system_info])\n",
287 | 234 | " \n",
288 | | - " # First verify correctness\n",
289 | | - " verify_computation_correctness()\n",
290 | | - " \n",
291 | 235 | " # Then run benchmarks\n",
292 | 236 | " results_df = run_benchmark(N=N)\n",
293 | 237 | " \n",

296 | 240 | },
297 | 241 | {
298 | 242 | "cell_type": "code",
299 | | - "execution_count": null,
| 243 | + "execution_count": 10,
300 | 244 | "metadata": {},
301 | 245 | "outputs": [
302 | 246 | {
303 | 247 | "name": "stdout",
304 | 248 | "output_type": "stream",
305 | 249 | "text": [
306 | | - "Verifying computational correctness...\n",
307 | | - "Max difference between JAX and MLX results: 0.00e+00\n",
308 | | - "✅ Results match within tolerance\n",
309 | | - "Running benchmarks with N=20 repetitions per test...\n",
310 | | - "Testing 128x128 matrices...\n"
| 250 | + "Running benchmarks with N=100 repetitions per test...\n",
| 251 | + "Testing 128x128 matrices...\n",
| 252 | + "Testing 256x256 matrices...\n",
| 253 | + "Testing 512x512 matrices...\n",
| 254 | + "Testing 1024x1024 matrices...\n"
311 | 255 | ]
312 | 256 | }
313 | 257 | ],
314 | 258 | "source": [
315 | | - "iteration=20\n",
| 259 | + "iteration=100\n",
316 | 260 | "_, results = main(N=iteration)"
317 | 261 | ]
318 | 262 | },
319 | 263 | {
320 | 264 | "cell_type": "code",
321 | | - "execution_count": 27,
| 265 | + "execution_count": 11,
322 | 266 | "metadata": {},
323 | 267 | "outputs": [
324 | 268 | {
325 | 269 | "name": "stdout",
326 | 270 | "output_type": "stream",
327 | 271 | "text": [
328 | 272 | "\n",
329 | | - "Benchmark Results over 1000 repetitions:\n",
| 273 | + "Benchmark Results over 100 repetitions:\n",
330 | 274 | " Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Speedup\n",
331 | | - " 128x128 Matrix Chain (A @ B @ C) 0.005700 0.002127 0.001215 0.000497 4.69x\n",
332 | | - " 128x128 Element-wise (sin(A) + cos(B)) 0.008280 0.002158 0.000876 0.000451 9.45x\n",
333 | | - " 128x128 Broadcasting (A + B.T) 0.008083 0.002485 0.000861 0.000207 9.39x\n",
334 | | - " 256x256 Matrix Chain (A @ B @ C) 0.005705 0.002307 0.001085 0.000210 5.26x\n",
335 | | - " 256x256 Element-wise (sin(A) + cos(B)) 0.009794 0.001994 0.000998 0.001895 9.82x\n",
336 | | - " 256x256 Broadcasting (A + B.T) 0.010467 0.002573 0.001056 0.000578 9.91x\n",
337 | | - " 512x512 Matrix Chain (A @ B @ C) 0.006898 0.002576 0.001300 0.000391 5.31x\n",
338 | | - " 512x512 Element-wise (sin(A) + cos(B)) 0.010997 0.002435 0.000976 0.000584 11.27x\n",
339 | | - " 512x512 Broadcasting (A + B.T) 0.009730 0.002690 0.000968 0.000315 10.05x\n",
340 | | - "1024x1024 Matrix Chain (A @ B @ C) 0.010941 0.002035 0.001735 0.000302 6.31x\n",
341 | | - "1024x1024 Element-wise (sin(A) + cos(B)) 0.013936 0.003774 0.001103 0.000253 12.64x\n",
342 | | - "1024x1024 Broadcasting (A + B.T) 0.011153 0.002297 0.001084 0.000242 10.29x\n"
| 275 | + " 128x128 Matrix Chain (A @ B @ C) 0.000131 0.000300 0.000283 0.000216 0.46x\n",
| 276 | + " 128x128 Element-wise (sin(A) + cos(B)) 0.000104 0.000304 0.000209 0.000145 0.50x\n",
| 277 | + " 128x128 Broadcasting (A + B.T) 0.000037 0.000296 0.000215 0.000153 0.17x\n",
| 278 | + " 256x256 Matrix Chain (A @ B @ C) 0.000394 0.000372 0.000441 0.000239 0.89x\n",
| 279 | + " 256x256 Element-wise (sin(A) + cos(B)) 0.000247 0.000389 0.000255 0.000168 0.97x\n",
| 280 | + " 256x256 Broadcasting (A + B.T) 0.000063 0.000329 0.000217 0.000153 0.29x\n",
| 281 | + " 512x512 Matrix Chain (A @ B @ C) 0.001004 0.000255 0.000399 0.000188 2.51x\n",
| 282 | + " 512x512 Element-wise (sin(A) + cos(B)) 0.000664 0.000328 0.000263 0.000163 2.53x\n",
| 283 | + " 512x512 Broadcasting (A + B.T) 0.000115 0.000339 0.000254 0.000156 0.45x\n",
| 284 | + "1024x1024 Matrix Chain (A @ B @ C) 0.005281 0.000359 0.000993 0.000342 5.32x\n",
| 285 | + "1024x1024 Element-wise (sin(A) + cos(B)) 0.002595 0.000359 0.000408 0.000220 6.36x\n",
| 286 | + "1024x1024 Broadcasting (A + B.T) 0.000501 0.000346 0.000385 0.000155 1.30x\n"
343 | 287 | ]
344 | 288 | }
345 | 289 | ],

367 | 311 | }
368 | 312 | ],
369 | 313 | "source": [
370 | | - "# Additional timing analysis - separate compilation vs execution time\n",
371 | | - "if MLX_AVAILABLE:\n",
372 | | - " print(\"\\n=== Detailed MLX Timing Analysis ===\")\n",
| 314 | + "# # Additional timing analysis - separate compilation vs execution time\n",
| 315 | + "# if MLX_AVAILABLE:\n",
| 316 | + "# print(\"\\n=== Detailed MLX Timing Analysis ===\")\n",
373 | 317 | " \n",
374 | | - " # Test with medium-sized matrix\n",
375 | | - " np.random.seed(42)\n",
376 | | - " A = np.random.randn(512, 512).astype(np.float32)\n",
377 | | - " B = np.random.randn(512, 512).astype(np.float32)\n",
378 | | - " C = np.random.randn(512, 512).astype(np.float32)\n",
| 318 | + "# # Test with medium-sized matrix\n",
| 319 | + "# np.random.seed(42)\n",
| 320 | + "# A = np.random.randn(512, 512).astype(np.float32)\n",
| 321 | + "# B = np.random.randn(512, 512).astype(np.float32)\n",
| 322 | + "# C = np.random.randn(512, 512).astype(np.float32)\n",
379 | 323 | " \n",
380 | | - " # Create PyTensor function (compilation time)\n",
381 | | - " start = time.perf_counter()\n",
382 | | - " pt_A = pt.matrix('A', dtype='float32')\n",
383 | | - " pt_B = pt.matrix('B', dtype='float32')\n",
384 | | - " pt_C = pt.matrix('C', dtype='float32')\n",
385 | | - " result_expr = pt_A @ pt_B @ pt_C\n",
386 | | - " f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
387 | | - " compilation_time = time.perf_counter() - start\n",
| 324 | + "# # Create PyTensor function (compilation time)\n",
| 325 | + "# start = time.perf_counter()\n",
| 326 | + "# pt_A = pt.matrix('A', dtype='float32')\n",
| 327 | + "# pt_B = pt.matrix('B', dtype='float32')\n",
| 328 | + "# pt_C = pt.matrix('C', dtype='float32')\n",
| 329 | + "# result_expr = pt_A @ pt_B @ pt_C\n",
| 330 | + "# f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
| 331 | + "# compilation_time = time.perf_counter() - start\n",
388 | 332 | " \n",
389 | | - " # First execution (may include additional compilation/optimization)\n",
390 | | - " start = time.perf_counter()\n",
391 | | - " result = f_mlx(A, B, C)\n",
392 | | - " mx.eval(result) # Force evaluation\n",
393 | | - " first_exec_time = time.perf_counter() - start\n",
| 333 | + "# # First execution (may include additional compilation/optimization)\n",
| 334 | + "# start = time.perf_counter()\n",
| 335 | + "# result = f_mlx(A, B, C)\n",
| 336 | + "# mx.eval(result) # Force evaluation\n",
| 337 | + "# first_exec_time = time.perf_counter() - start\n",
394 | 338 | " \n",
395 | | - " # Subsequent executions (should be faster)\n",
396 | | - " exec_times = []\n",
397 | | - " for _ in range(1000):\n",
398 | | - " start = time.perf_counter()\n",
399 | | - " result = f_mlx(A, B, C)\n",
400 | | - " mx.eval(result)\n",
401 | | - " exec_times.append(time.perf_counter() - start)\n",
| 339 | + "# # Subsequent executions (should be faster)\n",
| 340 | + "# exec_times = []\n",
| 341 | + "# for _ in range(1000):\n",
| 342 | + "# start = time.perf_counter()\n",
| 343 | + "# result = f_mlx(A, B, C)\n",
| 344 | + "# mx.eval(result)\n",
| 345 | + "# exec_times.append(time.perf_counter() - start)\n",
402 | 346 | " \n",
403 | | - " avg_exec_time = np.mean(exec_times)\n",
404 | | - " std_exec_time = np.std(exec_times)\n",
| 347 | + "# avg_exec_time = np.mean(exec_times)\n",
| 348 | + "# std_exec_time = np.std(exec_times)\n",
405 | 349 | " \n",
406 | | - " print(f\"Compilation time: {compilation_time:.4f}s\")\n",
407 | | - " print(f\"First execution: {first_exec_time:.4f}s\")\n",
408 | | - " print(f\"Average execution (5 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n",
409 | | - " print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n"
| 350 | + "# print(f\"Compilation time: {compilation_time:.4f}s\")\n",
| 351 | + "# print(f\"First execution: {first_exec_time:.4f}s\")\n",
| 352 | + "# print(f\"Average execution (1000 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n",
| 353 | + "# print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n"
410 | 354 | ]
411 | 355 | },
412 | 356 | {