Skip to content

Commit 37e0094

Browse files
Tweak conditional import of JAX.
1 parent ee07bc2 commit 37e0094

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

jaxtyping/_decorator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import dataclasses
2121
import functools as ft
22-
import importlib.util
2322
import inspect
2423
import itertools as it
2524
import sys
@@ -228,9 +227,10 @@ def f(...): ...
228227
if _tb_flag:
229228
try:
230229
import jax._src.traceback_util as traceback_util
231-
traceback_util.register_exclusion(__file__)
232-
except:
230+
except Exception:
233231
pass
232+
else:
233+
traceback_util.register_exclusion(__file__)
234234
_tb_flag = False
235235

236236
# First handle the `jaxtyped("context")` usage, which is a special case.

0 commit comments

Comments
 (0)