Skip to content

Commit 0d113c2

Browse files
committed
avoid unnecessary imports in enable_determinism
1 parent 6ec6451 commit 0d113c2

File tree

1 file changed

+49
-38
lines changed

1 file changed

+49
-38
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import hashlib
2+
import os
23
import platform
34
import subprocess
45
import traceback
@@ -66,8 +67,9 @@ class DeprecatedKwargs(TypedDict):
6667
decimal: NotRequired[Optional[int]]
6768

6869

69-
# TODO: avoid unnecessary imports in enable_determinism
70-
def enable_determinism(mode: Literal["seed_only", "full"]):
70+
def enable_determinism(
71+
mode: Literal["seed_only", "full"], weight_formats: Sequence[SupportedWeightsFormat]
72+
):
7173
"""Seed and configure ML frameworks for maximum reproducibility.
7274
May degrade performance. Only recommended for testing reproducibility!
7375
@@ -93,39 +95,46 @@ def enable_determinism(mode: Literal["seed_only", "full"]):
9395
except Exception as e:
9496
logger.debug(str(e))
9597

96-
try:
98+
if "pytorch_state_dict" in weight_formats or "torchscript" in weight_formats:
9799
try:
98-
import torch
99-
except ImportError:
100-
pass
101-
else:
102-
_ = torch.manual_seed(0)
103-
torch.use_deterministic_algorithms(mode == "full")
104-
except Exception as e:
105-
logger.debug(str(e))
100+
try:
101+
import torch
102+
except ImportError:
103+
pass
104+
else:
105+
_ = torch.manual_seed(0)
106+
torch.use_deterministic_algorithms(mode == "full")
107+
except Exception as e:
108+
logger.debug(str(e))
106109

107-
try:
110+
if (
111+
"tensorflow_saved_model_bundle" in weight_formats
112+
or "keras_hdf5" in weight_formats
113+
):
108114
try:
109-
import keras
110-
except ImportError:
111-
pass
112-
else:
113-
keras.utils.set_random_seed(0)
114-
except Exception as e:
115-
logger.debug(str(e))
116-
117-
try:
115+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
116+
try:
117+
import tensorflow as tf
118+
except ImportError:
119+
pass
120+
else:
121+
tf.random.set_seed(0)
122+
if mode == "full":
123+
tf.config.experimental.enable_op_determinism()
124+
# TODO: find possibility to switch it off again??
125+
except Exception as e:
126+
logger.debug(str(e))
127+
128+
if "keras_hdf5" in weight_formats:
118129
try:
119-
import tensorflow as tf
120-
except ImportError:
121-
pass
122-
else:
123-
tf.random.set_seed(0)
124-
if mode == "full":
125-
tf.config.experimental.enable_op_determinism()
126-
# TODO: find possibility to switch it off again??
127-
except Exception as e:
128-
logger.debug(str(e))
130+
try:
131+
import keras
132+
except ImportError:
133+
pass
134+
else:
135+
keras.utils.set_random_seed(0)
136+
except Exception as e:
137+
logger.debug(str(e))
129138

130139

131140
def test_model(
@@ -390,7 +399,7 @@ def load_description_and_test(
390399
else:
391400
weight_formats = [weight_format]
392401

393-
enable_determinism(determinism)
402+
enable_determinism(determinism, weight_formats=weight_formats)
394403
for w in weight_formats:
395404
_test_model_inference(rd, w, devices, **deprecated)
396405
if not isinstance(rd, v0_4.ModelDescr):
@@ -589,12 +598,14 @@ def get_ns(n: int):
589598

590599
resized_test_inputs = Sample(
591600
members={
592-
t.id: test_inputs.members[t.id].resize_to(
593-
{
594-
aid: s
595-
for (tid, aid), s in input_target_sizes.items()
596-
if tid == t.id
597-
},
601+
t.id: (
602+
test_inputs.members[t.id].resize_to(
603+
{
604+
aid: s
605+
for (tid, aid), s in input_target_sizes.items()
606+
if tid == t.id
607+
},
608+
)
598609
)
599610
for t in model.inputs
600611
},

0 commit comments

Comments
 (0)