@@ -41,50 +41,42 @@ def save(
             )
 
     def save_weights(
-        self,
-        filepath,
-        overwrite=True,
-        save_format=None,
-        options=None,
+        self, filepath, overwrite=True, save_format=None, options=None,
     ):
         with file_util.save_file(filepath) as path:
             super().save_weights(filepath=path, overwrite=overwrite, save_format=save_format, options=options)
 
     def load_weights(
-        self,
-        filepath,
-        by_name=False,
-        skip_mismatch=False,
-        options=None,
+        self, filepath, by_name=False, skip_mismatch=False, options=None,
     ):
         with file_util.read_file(filepath) as path:
             super().load_weights(filepath=path, by_name=by_name, skip_mismatch=skip_mismatch, options=options)
 
+    @property
+    def metrics(self):
+        if not hasattr(self, "_tfasr_metrics"):
+            self._tfasr_metrics = {}
+        return list(self._tfasr_metrics.values())
+
     def add_metric(
-        self,
-        metric: tf.keras.metrics.Metric,
+        self, metric: tf.keras.metrics.Metric,
     ):
-        if not hasattr(self, "_metrics"):
-            self._metrics = {}
-        self._metrics[metric.name] = metric
+        if not hasattr(self, "_tfasr_metrics"):
+            self._tfasr_metrics = {}
+        self._tfasr_metrics[metric.name] = metric
 
     def make(self, *args, **kwargs):
         """Custom function for building model (uses self.build so cannot overwrite that function)"""
         raise NotImplementedError()
 
     def compile(
-        self,
-        loss,
-        optimizer,
-        run_eagerly=None,
-        **kwargs,
+        self, loss, optimizer, run_eagerly=None, **kwargs,
     ):
         self.use_loss_scale = False
         if not env_util.has_devices("TPU"):
             optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), "dynamic")
             self.use_loss_scale = True
-        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
-        self.add_metric(loss_metric)
+        self.add_metric(metric=tf.keras.metrics.Mean(name="loss", dtype=tf.float32))
         super().compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs)
 
     # -------------------------------- STEP FUNCTIONS -------------------------------------
@@ -110,8 +102,8 @@ def train_step(self, batch):
         else:
             gradients = tape.gradient(loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
-        self._metrics["loss"].update_state(loss)
-        return {m.name: m.result() for m in self._metrics.values()}
+        self._tfasr_metrics["loss"].update_state(loss)
+        return {m.name: m.result() for m in self.metrics}
 
     def test_step(self, batch):
         """
@@ -125,8 +117,8 @@ def test_step(self, batch):
         inputs, y_true = batch
         y_pred = self(inputs, training=False)
         loss = self.loss(y_true, y_pred)
-        self._metrics["loss"].update_state(loss)
-        return {m.name: m.result() for m in self._metrics.values()}
+        self._tfasr_metrics["loss"].update_state(loss)
+        return {m.name: m.result() for m in self.metrics}
 
     def predict_step(self, batch):
         """
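The rename from `_metrics` to `_tfasr_metrics` plus the new `metrics` property follows the standard Keras custom-training pattern: whatever the `metrics` property returns is what `fit()`/`evaluate()` reset at each epoch boundary, and a prefixed attribute name is less likely to collide with Keras's own internals. Below is a minimal, self-contained sketch of that pattern, not code from this commit; the class name, data, and loss are illustrative, and it assumes TensorFlow 2.x.

```python
import tensorflow as tf


class TrackedLossModel(tf.keras.Model):
    """Toy model (illustrative only) using the metrics-dict pattern from the diff."""

    def __init__(self):
        super().__init__()
        self._tfasr_metrics = {}
        self.dense = tf.keras.layers.Dense(1)

    @property
    def metrics(self):
        # Anything listed here is reset by fit()/evaluate() at each epoch.
        return list(self._tfasr_metrics.values())

    def compile(self, **kwargs):
        # Track a running mean of the loss, mirroring compile() in the diff.
        self._tfasr_metrics["loss"] = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        super().compile(**kwargs)

    def call(self, inputs, training=False):
        return self.dense(inputs)

    def train_step(self, batch):
        x, y_true = batch
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y_true, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self._tfasr_metrics["loss"].update_state(loss)
        return {m.name: m.result() for m in self.metrics}


model = TrackedLossModel()
model.compile(optimizer="adam", loss="mse")
x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))
model.fit(x, y, epochs=2, verbose=0)
print([m.name for m in model.metrics])  # -> ['loss']
```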