11import hashlib
2+ import os
23import platform
34import subprocess
45import traceback
@@ -66,8 +67,9 @@ class DeprecatedKwargs(TypedDict):
6667 decimal : NotRequired [Optional [int ]]
6768
6869
69- # TODO: avoid unnecessary imports in enable_determinism
70- def enable_determinism (mode : Literal ["seed_only" , "full" ]):
70+ def enable_determinism (
71+ mode : Literal ["seed_only" , "full" ], weight_formats : Sequence [SupportedWeightsFormat ]
72+ ):
7173 """Seed and configure ML frameworks for maximum reproducibility.
7274 May degrade performance. Only recommended for testing reproducibility!
7375
@@ -93,39 +95,46 @@ def enable_determinism(mode: Literal["seed_only", "full"]):
9395 except Exception as e :
9496 logger .debug (str (e ))
9597
96- try :
98+ if "pytorch_state_dict" in weight_formats or "torchscript" in weight_formats :
9799 try :
98- import torch
99- except ImportError :
100- pass
101- else :
102- _ = torch .manual_seed (0 )
103- torch .use_deterministic_algorithms (mode == "full" )
104- except Exception as e :
105- logger .debug (str (e ))
100+ try :
101+ import torch
102+ except ImportError :
103+ pass
104+ else :
105+ _ = torch .manual_seed (0 )
106+ torch .use_deterministic_algorithms (mode == "full" )
107+ except Exception as e :
108+ logger .debug (str (e ))
106109
107- try :
110+ if (
111+ "tensorflow_saved_model_bundle" in weight_formats
112+ or "keras_hdf5" in weight_formats
113+ ):
108114 try :
109- import keras
110- except ImportError :
111- pass
112- else :
113- keras .utils .set_random_seed (0 )
114- except Exception as e :
115- logger .debug (str (e ))
116-
117- try :
115+ os .environ ["TF_ENABLE_ONEDNN_OPTS" ] = "0"
116+ try :
117+ import tensorflow as tf
118+ except ImportError :
119+ pass
120+ else :
121+ tf .random .set_seed (0 )
122+ if mode == "full" :
123+ tf .config .experimental .enable_op_determinism ()
124+ # TODO: find possibility to switch it off again??
125+ except Exception as e :
126+ logger .debug (str (e ))
127+
128+ if "keras_hdf5" in weight_formats :
118129 try :
119- import tensorflow as tf
120- except ImportError :
121- pass
122- else :
123- tf .random .set_seed (0 )
124- if mode == "full" :
125- tf .config .experimental .enable_op_determinism ()
126- # TODO: find possibility to switch it off again??
127- except Exception as e :
128- logger .debug (str (e ))
130+ try :
131+ import keras
132+ except ImportError :
133+ pass
134+ else :
135+ keras .utils .set_random_seed (0 )
136+ except Exception as e :
137+ logger .debug (str (e ))
129138
130139
131140def test_model (
@@ -390,7 +399,7 @@ def load_description_and_test(
390399 else :
391400 weight_formats = [weight_format ]
392401
393- enable_determinism (determinism )
402+ enable_determinism (determinism , weight_formats = weight_formats )
394403 for w in weight_formats :
395404 _test_model_inference (rd , w , devices , ** deprecated )
396405 if not isinstance (rd , v0_4 .ModelDescr ):
@@ -589,12 +598,14 @@ def get_ns(n: int):
589598
590599 resized_test_inputs = Sample (
591600 members = {
592- t .id : test_inputs .members [t .id ].resize_to (
593- {
594- aid : s
595- for (tid , aid ), s in input_target_sizes .items ()
596- if tid == t .id
597- },
601+ t .id : (
602+ test_inputs .members [t .id ].resize_to (
603+ {
604+ aid : s
605+ for (tid , aid ), s in input_target_sizes .items ()
606+ if tid == t .id
607+ },
608+ )
598609 )
599610 for t in model .inputs
600611 },
0 commit comments