Skip to content

Commit 01c0fbd

Browse files
committed
add default args to enable_determinism
1 parent 0d113c2 commit 01c0fbd

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,25 @@ class DeprecatedKwargs(TypedDict):
6868

6969

7070
def enable_determinism(
71-
mode: Literal["seed_only", "full"], weight_formats: Sequence[SupportedWeightsFormat]
71+
mode: Literal["seed_only", "full"] = "full",
72+
weight_formats: Optional[Sequence[SupportedWeightsFormat]] = None,
7273
):
7374
"""Seed and configure ML frameworks for maximum reproducibility.
7475
May degrade performance. Only recommended for testing reproducibility!
7576
7677
Seed any random generators and (if **mode**=="full") request ML frameworks to use
7778
deterministic algorithms.
79+
80+
Args:
81+
mode: determinism mode
82+
- 'seed_only' -- only set seeds, or
83+
- 'full' determinsm features (might degrade performance or throw exceptions)
84+
weight_formats: Limit deep learning importing deep learning frameworks
85+
based on weight_formats.
86+
E.g. this allows to avoid importing tensorflow when testing with pytorch.
87+
7888
Notes:
79-
- **mode** == "full" might degrade performance and throw exceptions.
89+
- **mode** == "full" might degrade performance or throw exceptions.
8090
- Subsequent inference calls might still differ. Call before each function
8191
(sequence) that is expected to be reproducible.
8292
- Degraded performance: Use for testing reproducibility only!
@@ -95,7 +105,11 @@ def enable_determinism(
95105
except Exception as e:
96106
logger.debug(str(e))
97107

98-
if "pytorch_state_dict" in weight_formats or "torchscript" in weight_formats:
108+
if (
109+
weight_formats is None
110+
or "pytorch_state_dict" in weight_formats
111+
or "torchscript" in weight_formats
112+
):
99113
try:
100114
try:
101115
import torch
@@ -108,7 +122,8 @@ def enable_determinism(
108122
logger.debug(str(e))
109123

110124
if (
111-
"tensorflow_saved_model_bundle" in weight_formats
125+
weight_formats is None
126+
or "tensorflow_saved_model_bundle" in weight_formats
112127
or "keras_hdf5" in weight_formats
113128
):
114129
try:
@@ -125,7 +140,7 @@ def enable_determinism(
125140
except Exception as e:
126141
logger.debug(str(e))
127142

128-
if "keras_hdf5" in weight_formats:
143+
if weight_formats is None or "keras_hdf5" in weight_formats:
129144
try:
130145
try:
131146
import keras

0 commit comments

Comments
 (0)