@@ -13,17 +13,18 @@ class DeepEval():
1313 """
1414 def __init__ (self ,
1515 model_file ,
16- load_prefix = 'load' ) :
17- self .graph = self ._load_graph (model_file , prefix = load_prefix )
16+ load_prefix = 'load' ,
17+ default_tf_graph = False ) :
18+ self .graph = self ._load_graph (model_file , prefix = load_prefix , default_tf_graph = default_tf_graph )
1819 t_mt = self .graph .get_tensor_by_name (os .path .join (load_prefix , 'model_attr/model_type:0' ))
1920 sess = tf .Session (graph = self .graph , config = default_tf_session_config )
2021 [mt ] = sess .run ([t_mt ], feed_dict = {})
2122 self .model_type = mt .decode ('utf-8' )
2223
2324 def _load_graph (self ,
24- frozen_graph_filename ,
25- prefix = 'load' ,
26- default_tf_graph = True ):
25+ frozen_graph_filename ,
26+ prefix = 'load' ,
27+ default_tf_graph = False ):
2728 # We load the protobuf file from the disk and parse it to retrieve the
2829 # unserialized graph_def
2930 with tf .gfile .GFile (frozen_graph_filename , "rb" ) as f :
@@ -102,8 +103,9 @@ def __init__(self,
102103 model_file ,
103104 variable_name ,
104105 variable_dof ,
105- load_prefix = 'load' ) :
106- DeepEval .__init__ (self , model_file , load_prefix = load_prefix )
106+ load_prefix = 'load' ,
107+ default_tf_graph = False ) :
108+ DeepEval .__init__ (self , model_file , load_prefix = load_prefix , default_tf_graph = default_tf_graph )
107109 # self.model_file = model_file
108110 # self.graph = self.load_graph (self.model_file)
109111 self .variable_name = variable_name
0 commit comments