@@ -68,15 +68,25 @@ class DeprecatedKwargs(TypedDict):
6868
6969
7070def enable_determinism (
71- mode : Literal ["seed_only" , "full" ], weight_formats : Sequence [SupportedWeightsFormat ]
71+ mode : Literal ["seed_only" , "full" ] = "full" ,
72+ weight_formats : Optional [Sequence [SupportedWeightsFormat ]] = None ,
7273):
7374 """Seed and configure ML frameworks for maximum reproducibility.
7475 May degrade performance. Only recommended for testing reproducibility!
7576
7677 Seed any random generators and (if **mode**=="full") request ML frameworks to use
7778 deterministic algorithms.
79+
80+ Args:
81+ mode: determinism mode
82+ - 'seed_only' -- only set seeds, or
83+ - 'full' determinsm features (might degrade performance or throw exceptions)
84+ weight_formats: Limit deep learning importing deep learning frameworks
85+ based on weight_formats.
86+ E.g. this allows to avoid importing tensorflow when testing with pytorch.
87+
7888 Notes:
79- - **mode** == "full" might degrade performance and throw exceptions.
89+ - **mode** == "full" might degrade performance or throw exceptions.
8090 - Subsequent inference calls might still differ. Call before each function
8191 (sequence) that is expected to be reproducible.
8292 - Degraded performance: Use for testing reproducibility only!
@@ -95,7 +105,11 @@ def enable_determinism(
95105 except Exception as e :
96106 logger .debug (str (e ))
97107
98- if "pytorch_state_dict" in weight_formats or "torchscript" in weight_formats :
108+ if (
109+ weight_formats is None
110+ or "pytorch_state_dict" in weight_formats
111+ or "torchscript" in weight_formats
112+ ):
99113 try :
100114 try :
101115 import torch
@@ -108,7 +122,8 @@ def enable_determinism(
108122 logger .debug (str (e ))
109123
110124 if (
111- "tensorflow_saved_model_bundle" in weight_formats
125+ weight_formats is None
126+ or "tensorflow_saved_model_bundle" in weight_formats
112127 or "keras_hdf5" in weight_formats
113128 ):
114129 try :
@@ -125,7 +140,7 @@ def enable_determinism(
125140 except Exception as e :
126141 logger .debug (str (e ))
127142
128- if "keras_hdf5" in weight_formats :
143+ if weight_formats is None or "keras_hdf5" in weight_formats :
129144 try :
130145 try :
131146 import keras
0 commit comments