11# -*- coding: utf-8 -*-
22import os
3- os .environ ['TF_CPP_MIN_LOG_LEVEL' ] = '3'
43import abc
54import json
65import pickle
98import datetime
109
1110import numpy as np
12- import tensorflow as tf
13- tf .compat .v1 .logging .set_verbosity (tf .compat .v1 .logging .ERROR )
14- from tensorflow .keras .utils import Progbar
15- # what the...
11+ import tensorboard as tb
1612import absl .logging
1713logging .root .removeHandler (absl .logging ._absl_handler )
1814absl .logging ._warn_preinit_stderr = False
@@ -331,18 +327,15 @@ def instantiate(cls, cls_opt, path, data_fields):
331327 return c
332328
333329
334- class TensorboardExtention (object ):
330+ class TensorboardExtension (object ):
335331 @abc .abstractmethod
336332 def get_evaluation_metrics (self ):
337333 raise NotImplementedError
338334
339335 def _get_initial_tensorboard_data (self ):
340336 tb = aux .Option ({'summary_writer' : None ,
341337 'name' : None ,
342- 'metrics' : {},
343- 'feed_dict' : {},
344- 'merged_summary_op' : None ,
345- 'session' : None ,
338+ 'metrics' : [],
346339 'pbar' : None ,
347340 'data_root' : None ,
348341 'step' : 1 })
@@ -352,41 +345,28 @@ def initialize_tensorboard(self, num_steps, name_prefix='', name_postfix='', met
352345 if not self .opt .tensorboard :
353346 if not hasattr (self , '_tb_setted' ):
354347 self .logger .debug ('Cannot find tensorboard configuration.' )
355- self .tb_setted = False
348+ self ._tb_setted = False
356349 return
357350 name = self .opt .tensorboard .name
358351 name = name_prefix + name + name_postfix
359352 dtm = datetime .datetime .now ().strftime ('%Y%m%d-%H.%M' )
360353 template = self .opt .tensorboard .get ('name_template' , '{name}.{dtm}' )
361354 self ._tb = self ._get_initial_tensorboard_data ()
362355 self ._tb .name = template .format (name = name , dtm = dtm )
363- if not os .path .isdir (self .opt .tensorboard .root ):
364- os .makedirs (self .opt .tensorboard .root )
356+ os .makedirs (self .opt .tensorboard .root , exist_ok = True )
365357 tb_dir = os .path .join (self .opt .tensorboard .root , self ._tb .name )
366358 self ._tb .data_root = tb_dir
367- self ._tb .summary_writer = tf .summary .FileWriter (tb_dir )
368- if not metrics :
369- metrics = self .get_evaluation_metrics ()
370- for m in metrics :
371- self ._tb .metrics [m ] = tf .placeholder (tf .float32 )
372- tf .summary .scalar (m , self ._tb .metrics [m ])
373- self ._tb .feed_dict [self ._tb .metrics [m ]] = 0.0
374- self ._tb .merged_summary_op = tf .summary .merge_all ()
375- self ._tb .session = tf .Session ()
376- self ._tb .pbar = Progbar (num_steps , stateful_metrics = self ._tb .metrics , verbose = 0 )
359+ self ._tb .summary_writer = tb .summary .Writer (tb_dir )
360+ self ._tb .metrics = metrics if metrics is not None else self .get_evaluation_metrics ()
377361 self ._tb_setted = True
378362
379363 def update_tensorboard_data (self , metrics ):
380364 if not self .opt .tensorboard :
381365 return
382- metrics = [(m , np .float32 (metrics .get (m , 0.0 )))
383- for m in self ._tb .metrics .keys ()]
384- self ._tb .feed_dict = {self ._tb .metrics [k ]: v
385- for k , v in metrics }
386- summary = self ._tb .session .run (self ._tb .merged_summary_op ,
387- feed_dict = self ._tb .feed_dict )
388- self ._tb .summary_writer .add_summary (summary , self ._tb .step )
389- self ._tb .pbar .update (self ._tb .step , metrics )
366+ for m in self ._tb .metrics :
367+ v = metrics .get (m , 0.0 )
368+ self ._tb .summary_writer .add_scalar (m , v , self ._tb .step )
369+ self ._tb .summary_writer .flush ()
390370 self ._tb .step += 1
391371
392372 def finalize_tensorboard (self ):
@@ -395,6 +375,4 @@ def finalize_tensorboard(self):
395375 with open (os .path .join (self ._tb .data_root , 'opt.json' ), 'w' ) as fout :
396376 fout .write (json .dumps (self .opt , indent = 2 ))
397377 self ._tb .summary_writer .close ()
398- self ._tb .session .close ()
399378 self ._tb = None
400- tf .reset_default_graph ()
0 commit comments