|
1 | 1 | {
|
2 | 2 | "cells": [
|
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 1, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [ |
| 8 | + { |
| 9 | + "name": "stdout", |
| 10 | + "output_type": "stream", |
| 11 | + "text": [ |
| 12 | + "Obtaining file:///Users/carlostrujillo/Documents/GitHub/pytensor\n", |
| 13 | + " Installing build dependencies ... \u001b[?25ldone\n", |
| 14 | + "\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n", |
| 15 | + "\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n", |
| 16 | + "\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n", |
| 17 | + "\u001b[?25hBuilding wheels for collected packages: pytensor\n", |
| 18 | + " Building editable for pytensor (pyproject.toml) ... \u001b[?25ldone\n", |
| 19 | + "\u001b[?25h Created wheel for pytensor: filename=pytensor-2.31.7+80.g06ccf91ba.dirty-0.editable-cp312-cp312-macosx_11_0_arm64.whl size=7323 sha256=c09587a5f3141d49000666d2817c5a01436f13ff5a19aa3deda20f647660afee\n", |
| 20 | + " Stored in directory: /private/var/folders/f0/rbz8xs8s17n3k3f_ccp31bvh0000gn/T/pip-ephem-wheel-cache-i00nb67k/wheels/52/f6/4c/e6784e2203d5405c94db1d544248730e598e4397674416af05\n", |
| 21 | + "Successfully built pytensor\n", |
| 22 | + "Installing collected packages: pytensor\n", |
| 23 | + " Attempting uninstall: pytensor\n", |
| 24 | + " Found existing installation: pytensor 2.31.7+80.g06ccf91ba.dirty\n", |
| 25 | + " Uninstalling pytensor-2.31.7+80.g06ccf91ba.dirty:\n", |
| 26 | + " Successfully uninstalled pytensor-2.31.7+80.g06ccf91ba.dirty\n", |
| 27 | + "Successfully installed pytensor-2.31.7+80.g06ccf91ba.dirty\n", |
| 28 | + "Note: you may need to restart the kernel to use updated packages.\n" |
| 29 | + ] |
| 30 | + } |
| 31 | + ], |
| 32 | + "source": [ |
| 33 | + "%pip install -e ../.. --no-deps" |
| 34 | + ] |
| 35 | + }, |
3 | 36 | {
|
4 | 37 | "cell_type": "code",
|
5 | 38 | "execution_count": 1,
|
|
88 | 121 | " \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n",
|
89 | 122 | " import pandas as pd\n",
|
90 | 123 | " \n",
|
91 |
| - " sizes = [2, 4, 2000, 4000]\n", |
| 124 | + " sizes = [2, 4, 1080, 2080, 3080]\n", |
92 | 125 | " results = []\n",
|
93 | 126 | " \n",
|
94 | 127 | " print(f\"Running benchmarks with N={N} repetitions per test...\")\n",
|
|
277 | 310 | "Running benchmarks with N=150 repetitions per test...\n",
|
278 | 311 | "Testing 2x2 matrices...\n",
|
279 | 312 | "Testing 4x4 matrices...\n",
|
280 |
| - "Testing 2000x2000 matrices...\n", |
281 |
| - "Testing 4000x4000 matrices...\n" |
| 313 | + "Testing 1080x1080 matrices...\n", |
| 314 | + "Testing 2080x2080 matrices...\n", |
| 315 | + "Testing 3080x3080 matrices...\n" |
282 | 316 | ]
|
283 | 317 | }
|
284 | 318 | ],
|
|
299 | 333 | "\n",
|
300 | 334 | "Benchmark Results over 150 repetitions:\n",
|
301 | 335 | " Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance\n",
|
302 |
| - " 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000311 0.000266 -3277.7%\n", |
303 |
| - " 2x2 Element-wise (sin(A) + cos(B)) 0.000008 0.000003 0.000233 0.000105 -2830.3%\n", |
304 |
| - " 2x2 Broadcasting (A + B.T) 0.000007 0.000003 0.000253 0.000151 -3429.1%\n", |
305 |
| - " 4x4 Matrix Chain (A @ B @ C) 0.000011 0.000008 0.000285 0.000111 -2537.7%\n", |
306 |
| - " 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000235 0.000124 -3217.0%\n", |
307 |
| - " 4x4 Broadcasting (A + B.T) 0.000007 0.000002 0.000202 0.000077 -2755.8%\n", |
308 |
| - "2000x2000 Matrix Chain (A @ B @ C) 0.024714 0.000919 0.004166 0.003531 +83.1%\n", |
309 |
| - "2000x2000 Element-wise (sin(A) + cos(B)) 0.009464 0.000417 0.000844 0.000284 +91.1%\n", |
310 |
| - "2000x2000 Broadcasting (A + B.T) 0.000690 0.000022 0.000821 0.000093 -19.0%\n", |
311 |
| - "4000x4000 Matrix Chain (A @ B @ C) 0.196587 0.008780 0.027411 0.001132 +86.1%\n", |
312 |
| - "4000x4000 Element-wise (sin(A) + cos(B)) 0.037744 0.001247 0.003355 0.000467 +91.1%\n", |
313 |
| - "4000x4000 Broadcasting (A + B.T) 0.012233 0.000421 0.003323 0.000370 +72.8%\n" |
| 336 | + " 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000305 0.000299 -3213.5%\n", |
| 337 | + " 2x2 Element-wise (sin(A) + cos(B)) 0.000007 0.000002 0.000352 0.003757 -5078.0%\n", |
| 338 | + " 2x2 Broadcasting (A + B.T) 0.000007 0.000001 0.000188 0.000153 -2721.1%\n", |
| 339 | + " 4x4 Matrix Chain (A @ B @ C) 0.000009 0.000001 0.000209 0.000063 -2126.2%\n", |
| 340 | + " 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000180 0.000066 -2449.5%\n", |
| 341 | + " 4x4 Broadcasting (A + B.T) 0.000007 0.000003 0.000181 0.000065 -2564.1%\n", |
| 342 | + "1080x1080 Matrix Chain (A @ B @ C) 0.005951 0.000356 0.001355 0.000392 +77.2%\n", |
| 343 | + "1080x1080 Element-wise (sin(A) + cos(B)) 0.002820 0.000107 0.000432 0.000207 +84.7%\n", |
| 344 | + "1080x1080 Broadcasting (A + B.T) 0.000212 0.000035 0.000428 0.000206 -102.0%\n", |
| 345 | + "2080x2080 Matrix Chain (A @ B @ C) 0.027609 0.001255 0.004550 0.002528 +83.5%\n", |
| 346 | + "2080x2080 Element-wise (sin(A) + cos(B)) 0.010086 0.000417 0.001175 0.000350 +88.3%\n", |
| 347 | + "2080x2080 Broadcasting (A + B.T) 0.000856 0.000068 0.001124 0.000241 -31.2%\n", |
| 348 | + "3080x3080 Matrix Chain (A @ B @ C) 0.093115 0.003823 0.013649 0.000513 +85.3%\n", |
| 349 | + "3080x3080 Element-wise (sin(A) + cos(B)) 0.022586 0.000756 0.001930 0.000287 +91.5%\n", |
| 350 | + "3080x3080 Broadcasting (A + B.T) 0.002580 0.000161 0.001937 0.000257 +24.9%\n" |
314 | 351 | ]
|
315 | 352 | }
|
316 | 353 | ],
|
|
321 | 358 | },
|
322 | 359 | {
|
323 | 360 | "cell_type": "code",
|
324 |
| - "execution_count": 5, |
| 361 | + "execution_count": null, |
325 | 362 | "metadata": {},
|
326 | 363 | "outputs": [],
|
327 | 364 | "source": [
|
|
0 commit comments