11import os
2- import tf2onnx
3- import onnx
42from keras .callbacks import Callback
53
64import logging
@@ -22,22 +20,29 @@ def __init__(
2220 self .metadata = metadata
2321
2422 def on_train_end (self , logs = None ):
25- self .model .load_weights (self .saved_model_path )
26- self .onnx_model_path = self .saved_model_path .replace (".h5" , ".onnx" )
27- tf2onnx .convert .from_keras (self .model , output_path = self .onnx_model_path )
28-
29- if self .metadata and isinstance (self .metadata , dict ):
30- # Load the ONNX model
31- onnx_model = onnx .load (self .onnx_model_path )
32-
33- # Add the metadata dictionary to the model's metadata_props attribute
34- for key , value in self .metadata .items ():
35- meta = onnx_model .metadata_props .add ()
36- meta .key = key
37- meta .value = value
38-
39- # Save the modified ONNX model
40- onnx .save (onnx_model , self .onnx_model_path )
23+ """ Converts the model to onnx format after training is finished. """
24+ try :
25+ import onnx
26+ import tf2onnx
27+ self .model .load_weights (self .saved_model_path )
28+ self .onnx_model_path = self .saved_model_path .replace (".h5" , ".onnx" )
29+ tf2onnx .convert .from_keras (self .model , output_path = self .onnx_model_path )
30+
31+ if self .metadata and isinstance (self .metadata , dict ):
32+ # Load the ONNX model
33+ onnx_model = onnx .load (self .onnx_model_path )
34+
35+ # Add the metadata dictionary to the model's metadata_props attribute
36+ for key , value in self .metadata .items ():
37+ meta = onnx_model .metadata_props .add ()
38+ meta .key = key
39+ meta .value = value
40+
41+ # Save the modified ONNX model
42+ onnx .save (onnx_model , self .onnx_model_path )
43+
44+ except Exception as e :
45+ print (e )
4146
4247
4348class TrainLogger (Callback ):
0 commit comments