Skip to content

Commit 2829100

Browse files
committed
test: add jaxley benchmark
1 parent 99d2b63 commit 2829100

File tree

9 files changed

+447
-24
lines changed

9 files changed

+447
-24
lines changed

builddeps/requirements_lock_3_11.txt

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ decorator==5.2.1 \
382382
--hash=sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360 \
383383
--hash=sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a
384384
# via gcsfs
385+
diffrax==0.7.0 \
386+
--hash=sha256:aa9645c40552f11a2d32042ef6b9fcd53c1f0f6bdbe32d37cb788669ca9910be \
387+
--hash=sha256:f3bcc578cd92a9ca86fc6f5a54c1de76c1ba62f74de69b56002624bf205316f1
388+
# via jaxley-mech
385389
dinosaur==1.3.5 \
386390
--hash=sha256:aa3830f66a7ceb5cb900689d9b0717100eea74ae4d04f206a9fa20408cee3dc9 \
387391
--hash=sha256:fd75996d62104d5c602a4f2643a1154268e6cd6ed9fd1c295aab679c6fba60b3
@@ -405,6 +409,14 @@ einops==0.8.1 \
405409
--hash=sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737 \
406410
--hash=sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84
407411
# via jax-md
412+
equinox==0.13.2 \
413+
--hash=sha256:509ad744ff99b7c684d45230d6890f9e78eac1a556d7a06db1eff664a3cac74f \
414+
--hash=sha256:bc1ee687e4841945d8b776664403839639a05e2f2c02c1da353ff3386e0e43b0
415+
# via
416+
# diffrax
417+
# jaxley-mech
418+
# lineax
419+
# optimistix
408420
etils[epath]==1.13.0 \
409421
--hash=sha256:a5b60c71f95bcd2d43d4e9fb3dc3879120c1f60472bb5ce19f7a860b1d44f607 \
410422
--hash=sha256:d9cd4f40fbe77ad6613b7348a18132cc511237b6c076dbb89105c0b520a4c6bb
@@ -910,21 +922,34 @@ jaraco-functools==4.4.0 \
910922
--hash=sha256:9eec1e36f45c818d9bf307c8948eb03b2b56cd44087b3cdc989abca1f20b9176 \
911923
--hash=sha256:da21933b0417b89515562656547a77b4931f98176eb173644c0d35032a33d6bb
912924
# via cheroot
925+
jax[cpu]==0.8.2 \
926+
--hash=sha256:1a685ded06a8223a7b52e45e668e406049dbbead02873f2b5a4d881ba7b421ae \
927+
--hash=sha256:d0478c5dc74406441efcd25731166a65ee782f13c352fa72dc7d734351909355
928+
# via
929+
# -r builddeps/requirements.in
930+
# jaxley-mech
913931
jax[cuda12]==0.8.2 \
914932
--hash=sha256:1a685ded06a8223a7b52e45e668e406049dbbead02873f2b5a4d881ba7b421ae \
915933
--hash=sha256:d0478c5dc74406441efcd25731166a65ee782f13c352fa72dc7d734351909355
916934
# via
917935
# -r builddeps/requirements.in
918936
# chex
937+
# diffrax
919938
# dinosaur
920939
# e3nn-jax
940+
# equinox
921941
# flax
922942
# jax-md
943+
# jaxley
944+
# jaxley-mech
923945
# jraph
946+
# lineax
924947
# neuralgcm
925948
# optax
949+
# optimistix
926950
# orbax-checkpoint
927951
# tree-math
952+
# tridiax
928953
jax-cuda12-pjrt==0.8.2 \
929954
--hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \
930955
--hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977
@@ -947,6 +972,16 @@ jax-md==0.2.27 \
947972
--hash=sha256:3506cf7c07b84d6c9cf09243097bef465c81122a23ca8cc78a3627c8b9d97322 \
948973
--hash=sha256:efbefa5089a995a5c02405a4c930ba42f8eaf9322482998b5a422e45f631a0ab
949974
# via -r builddeps/test-requirements.txt
975+
jaxley==0.13.0 \
976+
--hash=sha256:0d9247b340b402f974aad827e0cd79e32c5cd66d7295d95514792a108e15f00b \
977+
--hash=sha256:277f135714f1370b7246754d64687357ec443e3a944f1a96633dfd4eaaafcc3e
978+
# via
979+
# -r builddeps/test-requirements.txt
980+
# jaxley-mech
981+
jaxley-mech==0.3.1 \
982+
--hash=sha256:bd46cb2f02d1f76af56406ef83c464b6f9fc9742625cd88371a1923e14f601e8 \
983+
--hash=sha256:cc5eda21c8521e32795526f9f85ca52941899449b0a491d3ffdb321f3f0c8cbd
984+
# via -r builddeps/test-requirements.txt
950985
jaxlib==0.8.2 \
951986
--hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \
952987
--hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \
@@ -980,6 +1015,14 @@ jaxlib==0.8.2 \
9801015
# jraph
9811016
# neuralgcm
9821017
# optax
1018+
jaxtyping==0.3.5 \
1019+
--hash=sha256:8150ad5b72b62fa63f573d492a79e9e455f070abe3b260f7dc15270b3eb9bba6 \
1020+
--hash=sha256:862c39fa2e526274e82dc96ee8dbe9369dadb651ab1e05d95bd685acb4e2ef02
1021+
# via
1022+
# diffrax
1023+
# equinox
1024+
# lineax
1025+
# optimistix
9831026
jmp==0.0.4 \
9841027
--hash=sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730 \
9851028
--hash=sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d
@@ -1097,6 +1140,12 @@ kiwisolver==1.4.9 \
10971140
--hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \
10981141
--hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220
10991142
# via matplotlib
1143+
lineax==0.0.8 \
1144+
--hash=sha256:1bd21d6c41afda233769d1c1096329ee75181825c9136be65c92b41f6daa1ddb \
1145+
--hash=sha256:bb2778066b8882acc88ff569d8e415bc5aa387f751b14ae262c9f9607d7f25bb
1146+
# via
1147+
# diffrax
1148+
# optimistix
11001149
locket==1.0.0 \
11011150
--hash=sha256:5c0d4c052a8bbbf750e056a8e65ccd309086f4f0f18a2eac306a8dfa4112a632 \
11021151
--hash=sha256:b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3
@@ -1252,7 +1301,10 @@ matplotlib==3.10.8 \
12521301
--hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \
12531302
--hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \
12541303
--hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7
1255-
# via pymatgen
1304+
# via
1305+
# jaxley
1306+
# jaxley-mech
1307+
# pymatgen
12561308
mdurl==0.1.2 \
12571309
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
12581310
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
@@ -1546,7 +1598,9 @@ nest-asyncio==1.6.0 \
15461598
networkx==3.6.1 \
15471599
--hash=sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509 \
15481600
--hash=sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762
1549-
# via pymatgen
1601+
# via
1602+
# jaxley
1603+
# pymatgen
15501604
neuralgcm==1.2.2 \
15511605
--hash=sha256:24edbbb5d21e2d17a7475738f84602885eb011af3a23c33df293b2c5d10ac11c \
15521606
--hash=sha256:795297260a5aff05708e855fe8cb27db7cc0f514e9c34e373e4ba378732327e5
@@ -1640,6 +1694,8 @@ numpy==2.1.3 \
16401694
# flax
16411695
# jax
16421696
# jax-md
1697+
# jaxley
1698+
# jaxley-mech
16431699
# jaxlib
16441700
# jmp
16451701
# jraph
@@ -1657,6 +1713,7 @@ numpy==2.1.3 \
16571713
# spglib
16581714
# tensorstore
16591715
# treescope
1716+
# tridiax
16601717
# xarray
16611718
# xarray-tensorstore
16621719
# zarr
@@ -1742,6 +1799,10 @@ optax==0.2.6 \
17421799
# flax
17431800
# jax-md
17441801
# neuralgcm
1802+
optimistix==0.0.11 \
1803+
--hash=sha256:acb4fb23b598db03e376900fcb61aee8dd511de41411e849661c0ffe9e4cd1c6 \
1804+
--hash=sha256:cfce0de98e7e9fdbcc2ce6d49a9f82cd3166fd0eee29c0c7a1983f8aefd37757
1805+
# via diffrax
17451806
orbax-checkpoint==0.11.31 \
17461807
--hash=sha256:b00e39cd61cbd6c7c78b091ccac0ed1bbf3cf7788e761618e7070761195bfcc0 \
17471808
--hash=sha256:f021193a619782655798bc4a285f40612f6fe647ddeb303d1f49cdbc5645e319
@@ -1906,6 +1967,8 @@ pandas==2.3.3 \
19061967
--hash=sha256:f8bfc0e12dc78f777f323f55c58649591b2cd0c43534e8355c51d3fede5f4dee
19071968
# via
19081969
# dinosaur
1970+
# jaxley
1971+
# jaxley-mech
19091972
# neuralgcm
19101973
# pymatgen
19111974
# xarray
@@ -2743,18 +2806,26 @@ treescope==0.1.10 \
27432806
--hash=sha256:20f74656f34ab2d8716715013e8163a0da79bdc2554c16d5023172c50d27ea95 \
27442807
--hash=sha256:dde52f5314f4c29d22157a6fe4d3bd103f9cae02791c9e672eefa32c9aa1da51
27452808
# via flax
2809+
tridiax==0.2.1 \
2810+
--hash=sha256:311b0ed41671303197e219019fb9d22d6b31c841ddf5fdd1ec2601e09ed4e750 \
2811+
--hash=sha256:95a8c6d003cdd694487c99e5ba2c43d4fb4dfbe3a3df96e9ac2c80c1c4aaecd1
2812+
# via jaxley
27462813
typing-extensions==4.15.0 \
27472814
--hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \
27482815
--hash=sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548
27492816
# via
27502817
# aiosignal
27512818
# chex
2819+
# diffrax
2820+
# equinox
27522821
# etils
27532822
# flax
27542823
# flexcache
27552824
# flexparser
27562825
# grpcio
2826+
# lineax
27572827
# numcodecs
2828+
# optimistix
27582829
# orbax-checkpoint
27592830
# pint
27602831
# spglib
@@ -2771,6 +2842,13 @@ urllib3==2.6.2 \
27712842
--hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \
27722843
--hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd
27732844
# via requests
2845+
wadler-lindig==0.1.7 \
2846+
--hash=sha256:81d14d3fe77d441acf3ebd7f4aefac20c74128bf460e84b512806dccf7b2cd55 \
2847+
--hash=sha256:e3ec83835570fd0a9509f969162aeb9c65618f998b1f42918cfc8d45122fe953
2848+
# via
2849+
# diffrax
2850+
# equinox
2851+
# jaxtyping
27742852
werkzeug==3.1.4 \
27752853
--hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \
27762854
--hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e

builddeps/test-requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ protobuf >= 6
55

66
jax-md; sys_platform == 'linux'
77

8+
jaxley; sys_platform == 'linux'
9+
jaxley_mech; sys_platform == 'linux'
10+
811
# maxtext can't be installed concurrently, but installing it fixes
912
# https://github.com/wsmoses/maxtext/archive/bc50722be7d89e4003bd830b80e4ac968be658eb.tar.gz; python_version < "3.12"
1013
# maxtext; python_version < "3.13"

src/enzyme_ad/jax/Implementations/CHLODerivatives.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def IsInf : HLOInst<"IsInfOp">;
8080
def IsNegInf : HLOInst<"IsNegInfOp">;
8181
def IsPosInf : HLOInst<"IsPosInfOp">;
8282
def Lgamma : HLOInst<"LgammaOp">;
83+
def Square : HLOInst<"SquareOp">;
8384

8485
/// CHLO - broadcasting compare operation
8586
def BroadcastCompare : HLOInst<"BroadcastCompareOp">;
@@ -142,3 +143,7 @@ def : HLODerivative<"SinhOp", (Op $x), [(Mul (DiffeRet), (Cosh $x))]>;
142143
def : HLODerivative<"TanOp", (Op $x), [
143144
(Div (DiffeRet), (Mul (Cos $x), (Cos $x)))
144145
]>;
146+
147+
def : HLODerivative<"SquareOp", (Op $x), [
148+
(Mul (DiffeRet), (Mul (HLOConstantFP<"2"> $x), $x))
149+
]>;

test/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,25 @@ py_test(
122122
deps = TEST_DEPS,
123123
)
124124

125+
py_test(
126+
name = "jaxley_test",
127+
timeout = "eternal",
128+
srcs = [
129+
"jaxley_test.py",
130+
"test_utils.py",
131+
"xprof_utils.py",
132+
],
133+
imports = ["."],
134+
tags = ["exclusive"],
135+
deps = TEST_DEPS + select({
136+
"@bazel_tools//src/conditions:linux_x86_64": [
137+
"@pypi_jaxley//:pkg",
138+
"@pypi_jaxley_mech//:pkg",
139+
],
140+
"//conditions:default": [],
141+
}),
142+
)
143+
125144
py_test(
126145
name = "jaxmd",
127146
timeout = "eternal",
@@ -192,6 +211,7 @@ test_suite(
192211
name = "python_tests",
193212
tests = [
194213
":bench_vs_xla",
214+
":jaxley_test",
195215
":jaxmd",
196216
":llama",
197217
":neuralgcm_test",

0 commit comments

Comments
 (0)