Skip to content

Commit b2e924d

Browse files
cetagostinijessegrabowski
authored andcommitted
Changes on the branch
1 parent 26a6d14 commit b2e924d

File tree

3 files changed

+55
-18
lines changed

3 files changed

+55
-18
lines changed

doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,38 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stdout",
10+
"output_type": "stream",
11+
"text": [
12+
"Obtaining file:///Users/carlostrujillo/Documents/GitHub/pytensor\n",
13+
" Installing build dependencies ... \u001b[?25ldone\n",
14+
"\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n",
15+
"\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n",
16+
"\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n",
17+
"\u001b[?25hBuilding wheels for collected packages: pytensor\n",
18+
" Building editable for pytensor (pyproject.toml) ... \u001b[?25ldone\n",
19+
"\u001b[?25h Created wheel for pytensor: filename=pytensor-2.31.7+80.g06ccf91ba.dirty-0.editable-cp312-cp312-macosx_11_0_arm64.whl size=7323 sha256=c09587a5f3141d49000666d2817c5a01436f13ff5a19aa3deda20f647660afee\n",
20+
" Stored in directory: /private/var/folders/f0/rbz8xs8s17n3k3f_ccp31bvh0000gn/T/pip-ephem-wheel-cache-i00nb67k/wheels/52/f6/4c/e6784e2203d5405c94db1d544248730e598e4397674416af05\n",
21+
"Successfully built pytensor\n",
22+
"Installing collected packages: pytensor\n",
23+
" Attempting uninstall: pytensor\n",
24+
" Found existing installation: pytensor 2.31.7+80.g06ccf91ba.dirty\n",
25+
" Uninstalling pytensor-2.31.7+80.g06ccf91ba.dirty:\n",
26+
" Successfully uninstalled pytensor-2.31.7+80.g06ccf91ba.dirty\n",
27+
"Successfully installed pytensor-2.31.7+80.g06ccf91ba.dirty\n",
28+
"Note: you may need to restart the kernel to use updated packages.\n"
29+
]
30+
}
31+
],
32+
"source": [
33+
"%pip install -e ../.. --no-deps"
34+
]
35+
},
336
{
437
"cell_type": "code",
538
"execution_count": 1,
@@ -88,7 +121,7 @@
88121
" \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n",
89122
" import pandas as pd\n",
90123
" \n",
91-
" sizes = [2, 4, 2000, 4000]\n",
124+
" sizes = [2, 4, 1080, 2080, 3080]\n",
92125
" results = []\n",
93126
" \n",
94127
" print(f\"Running benchmarks with N={N} repetitions per test...\")\n",
@@ -277,8 +310,9 @@
277310
"Running benchmarks with N=150 repetitions per test...\n",
278311
"Testing 2x2 matrices...\n",
279312
"Testing 4x4 matrices...\n",
280-
"Testing 2000x2000 matrices...\n",
281-
"Testing 4000x4000 matrices...\n"
313+
"Testing 1080x1080 matrices...\n",
314+
"Testing 2080x2080 matrices...\n",
315+
"Testing 3080x3080 matrices...\n"
282316
]
283317
}
284318
],
@@ -299,18 +333,21 @@
299333
"\n",
300334
"Benchmark Results over 150 repetitions:\n",
301335
" Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance\n",
302-
" 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000311 0.000266 -3277.7%\n",
303-
" 2x2 Element-wise (sin(A) + cos(B)) 0.000008 0.000003 0.000233 0.000105 -2830.3%\n",
304-
" 2x2 Broadcasting (A + B.T) 0.000007 0.000003 0.000253 0.000151 -3429.1%\n",
305-
" 4x4 Matrix Chain (A @ B @ C) 0.000011 0.000008 0.000285 0.000111 -2537.7%\n",
306-
" 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000235 0.000124 -3217.0%\n",
307-
" 4x4 Broadcasting (A + B.T) 0.000007 0.000002 0.000202 0.000077 -2755.8%\n",
308-
"2000x2000 Matrix Chain (A @ B @ C) 0.024714 0.000919 0.004166 0.003531 +83.1%\n",
309-
"2000x2000 Element-wise (sin(A) + cos(B)) 0.009464 0.000417 0.000844 0.000284 +91.1%\n",
310-
"2000x2000 Broadcasting (A + B.T) 0.000690 0.000022 0.000821 0.000093 -19.0%\n",
311-
"4000x4000 Matrix Chain (A @ B @ C) 0.196587 0.008780 0.027411 0.001132 +86.1%\n",
312-
"4000x4000 Element-wise (sin(A) + cos(B)) 0.037744 0.001247 0.003355 0.000467 +91.1%\n",
313-
"4000x4000 Broadcasting (A + B.T) 0.012233 0.000421 0.003323 0.000370 +72.8%\n"
336+
" 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000305 0.000299 -3213.5%\n",
337+
" 2x2 Element-wise (sin(A) + cos(B)) 0.000007 0.000002 0.000352 0.003757 -5078.0%\n",
338+
" 2x2 Broadcasting (A + B.T) 0.000007 0.000001 0.000188 0.000153 -2721.1%\n",
339+
" 4x4 Matrix Chain (A @ B @ C) 0.000009 0.000001 0.000209 0.000063 -2126.2%\n",
340+
" 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000180 0.000066 -2449.5%\n",
341+
" 4x4 Broadcasting (A + B.T) 0.000007 0.000003 0.000181 0.000065 -2564.1%\n",
342+
"1080x1080 Matrix Chain (A @ B @ C) 0.005951 0.000356 0.001355 0.000392 +77.2%\n",
343+
"1080x1080 Element-wise (sin(A) + cos(B)) 0.002820 0.000107 0.000432 0.000207 +84.7%\n",
344+
"1080x1080 Broadcasting (A + B.T) 0.000212 0.000035 0.000428 0.000206 -102.0%\n",
345+
"2080x2080 Matrix Chain (A @ B @ C) 0.027609 0.001255 0.004550 0.002528 +83.5%\n",
346+
"2080x2080 Element-wise (sin(A) + cos(B)) 0.010086 0.000417 0.001175 0.000350 +88.3%\n",
347+
"2080x2080 Broadcasting (A + B.T) 0.000856 0.000068 0.001124 0.000241 -31.2%\n",
348+
"3080x3080 Matrix Chain (A @ B @ C) 0.093115 0.003823 0.013649 0.000513 +85.3%\n",
349+
"3080x3080 Element-wise (sin(A) + cos(B)) 0.022586 0.000756 0.001930 0.000287 +91.5%\n",
350+
"3080x3080 Broadcasting (A + B.T) 0.002580 0.000161 0.001937 0.000257 +24.9%\n"
314351
]
315352
}
316353
],
@@ -321,7 +358,7 @@
321358
},
322359
{
323360
"cell_type": "code",
324-
"execution_count": 5,
361+
"execution_count": null,
325362
"metadata": {},
326363
"outputs": [],
327364
"source": [

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from pytensor.link.mlx.dispatch import mlx_funcify
44
from pytensor.tensor.blockwise import Blockwise
5-
from pytensor.tensor.signal.conv import Conv1d
5+
from pytensor.tensor.signal.conv import Convolve1d as Conv1d
66

77

88
def blockwise_conv1d(op, node, **kwargs):

pytensor/link/mlx/dispatch/signal/conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import mlx.core as mx
22

33
from pytensor.link.mlx.dispatch import mlx_funcify
4-
from pytensor.tensor.signal.conv import Conv1d
4+
from pytensor.tensor.signal.conv import Convolve1d as Conv1d
55

66

77
@mlx_funcify.register(Conv1d)

0 commit comments

Comments
 (0)