Skip to content

Commit 36b2ac9

Browse files
committed
Remove rewrite exclusion in JAX mode
1 parent 3ed2c49 commit 36b2ac9

File tree

1 file changed

+0
-2
lines changed

1 file changed

+0
-2
lines changed

pytensor/compile/mode.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
452452
JAXLinker(),
453453
RewriteDatabaseQuery(
454454
include=["fast_run", "jax"],
455-
# TODO: "local_uint_constant_indices" can be reintroduced once https://github.com/google/jax/issues/16836 is fixed.
456455
exclude=[
457456
"cxx_only",
458457
"BlasOpt",
459458
"fusion",
460459
"inplace",
461-
"local_uint_constant_indices",
462460
],
463461
),
464462
)

0 commit comments

Comments
 (0)