Skip to content

Commit a1b6f5c

Browse files
committed
add enable_determinism
1 parent b8bad0b commit a1b6f5c

File tree

3 files changed

+79
-1
lines changed

3 files changed

+79
-1
lines changed

bioimageio/core/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333
tensor,
3434
)
3535
from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
36-
from ._resource_tests import load_description_and_test, test_description, test_model
36+
from ._resource_tests import (
37+
enable_determinism,
38+
load_description_and_test,
39+
test_description,
40+
test_model,
41+
)
3742
from ._settings import settings
3843
from .axis import Axis, AxisId
3944
from .block_meta import BlockMeta
@@ -71,6 +76,7 @@
7176
"create_prediction_pipeline",
7277
"digest_spec",
7378
"dump_description",
79+
"enable_determinism",
7480
"io",
7581
"load_dataset_description",
7682
"load_description_and_test",

bioimageio/core/_resource_tests.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,67 @@
3232
from .utils import VERSION
3333

3434

35+
def enable_determinism(mode: Literal["seed_only", "full"]):
36+
"""Seed and configure ML frameworks for maximum reproducibility.
37+
May degrade performance. Only recommended for testing reproducibility!
38+
39+
Seed any random generators and (if **mode**=="full") request ML frameworks to use
40+
deterministic algorithms.
41+
Notes:
42+
- **mode** == "full" might degrade performance and throw exceptions.
43+
- Subsequent inference calls might still differ. Call before each function
44+
(sequence) that is expected to be reproducible.
45+
- Degraded performance: Use for testing reproducibility only!
46+
- Recipes:
47+
- [PyTorch](https://pytorch.org/docs/stable/notes/randomness.html)
48+
- [Keras](https://keras.io/examples/keras_recipes/reproducibility_recipes/)
49+
- [NumPy](https://numpy.org/doc/2.0/reference/random/generated/numpy.random.seed.html)
50+
"""
51+
try:
52+
try:
53+
import numpy.random
54+
except ImportError:
55+
pass
56+
else:
57+
numpy.random.seed(0)
58+
except Exception as e:
59+
logger.debug(str(e))
60+
61+
try:
62+
try:
63+
import torch
64+
except ImportError:
65+
pass
66+
else:
67+
_ = torch.manual_seed(0)
68+
torch.use_deterministic_algorithms(mode == "full")
69+
except Exception as e:
70+
logger.debug(str(e))
71+
72+
try:
73+
try:
74+
import keras
75+
except ImportError:
76+
pass
77+
else:
78+
keras.utils.set_random_seed(0)
79+
except Exception as e:
80+
logger.debug(str(e))
81+
82+
try:
83+
try:
84+
import tensorflow as tf # pyright: ignore[reportMissingImports]
85+
except ImportError:
86+
pass
87+
else:
88+
tf.random.seed(0)
89+
if mode == "full":
90+
tf.config.experimental.enable_op_determinism()
91+
# TODO: find possibility to switch it off again??
92+
except Exception as e:
93+
logger.debug(str(e))
94+
95+
3596
def test_model(
3697
source: Union[v0_5.ModelDescr, PermissiveFileSource],
3798
weight_format: Optional[WeightsFormat] = None,

tests/test_resource_tests.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1+
from typing import Literal
2+
3+
import pytest
4+
15
from bioimageio.spec import InvalidDescr
26

37

8+
@pytest.mark.parametrize("mode", ["seed_only", "full"])
9+
def test_enable_determinism(mode: Literal["seed_only", "full"]):
10+
from bioimageio.core import enable_determinism
11+
12+
enable_determinism(mode)
13+
14+
415
def test_error_for_wrong_shape(stardist_wrong_shape: str):
516
from bioimageio.core._resource_tests import test_model
617

0 commit comments

Comments
 (0)