Skip to content

Commit dcc0f4a

Browse files
committed
fix gfile compatibility bug
1 parent 8e8e3c7 commit dcc0f4a

File tree

11 files changed

+70
-68
lines changed

11 files changed

+70
-68
lines changed

easy_rec/python/core/metrics.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77

88
import numpy as np
99
import tensorflow as tf
10-
if tf.__version__.startswith('1.'):
11-
from tensorflow.python.platform import gfile
12-
else:
13-
import tensorflow.io.gfile as gfile
1410
from sklearn import metrics as sklearn_metrics
1511
from tensorflow.python.ops import array_ops
1612
from tensorflow.python.ops import math_ops
@@ -22,6 +18,11 @@
2218
from easy_rec.python.utils.io_util import save_data_to_json_path
2319
from easy_rec.python.utils.shape_utils import get_shape_list
2420

21+
if tf.__version__.startswith('1.'):
22+
from tensorflow.python.platform import gfile
23+
else:
24+
import tensorflow.io.gfile as gfile
25+
2526
if tf.__version__ >= '2.0':
2627
tf = tf.compat.v1
2728

easy_rec/python/core/sampler.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
import numpy as np
1414
import six
1515
import tensorflow as tf
16-
if tf.__version__.startswith('1.'):
17-
from tensorflow.python.platform import gfile
18-
else:
19-
import tensorflow.io.gfile as gfile
16+
2017
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
2118
from easy_rec.python.utils import ds_util
2219
from easy_rec.python.utils.config_util import process_multi_file_input_path
2320
from easy_rec.python.utils.tf_utils import get_tf_type
2421

22+
if tf.__version__.startswith('1.'):
23+
from tensorflow.python.platform import gfile
24+
else:
25+
import tensorflow.io.gfile as gfile
26+
2527

2628
# patch graph-learn string_attrs for utf-8
2729
@property

easy_rec/python/export.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55

66
import tensorflow as tf
77
from tensorflow.python.lib.io import file_io
8-
if tf.__version__.startswith('1.'):
9-
from tensorflow.python.platform import gfile
10-
else:
11-
import tensorflow.io.gfile as gfile
8+
129
from easy_rec.python.main import export
1310
from easy_rec.python.protos.train_pb2 import DistributionStrategy
1411
from easy_rec.python.utils import config_util
1512
from easy_rec.python.utils import estimator_utils
1613

14+
if tf.__version__.startswith('1.'):
15+
from tensorflow.python.platform import gfile
16+
else:
17+
import tensorflow.io.gfile as gfile
18+
1719
if tf.__version__ >= '2.0':
1820
tf = tf.compat.v1
1921

easy_rec/python/input/criteo_input.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import logging
44

55
import tensorflow as tf
6+
7+
from easy_rec.python.input.criteo_binary_reader import BinaryDataset
8+
from easy_rec.python.input.input import Input
9+
610
if tf.__version__.startswith('1.'):
711
from tensorflow.python.platform import gfile
812
else:
913
import tensorflow.io.gfile as gfile
1014

11-
from easy_rec.python.input.criteo_binary_reader import BinaryDataset
12-
from easy_rec.python.input.input import Input
13-
1415
if tf.__version__ >= '2.0':
1516
tf = tf.compat.v1
1617

easy_rec/python/input/datahub_input.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
import traceback
66

77
import tensorflow as tf
8-
if tf.__version__.startswith('1.'):
9-
from tensorflow.python.platform import gfile
10-
else:
11-
import tensorflow.io.gfile as gfile
128
from tensorflow.python.framework import dtypes
9+
1310
from easy_rec.python.input.input import Input
1411
from easy_rec.python.utils import odps_util
1512
from easy_rec.python.utils.config_util import parse_time
1613

14+
if tf.__version__.startswith('1.'):
15+
from tensorflow.python.platform import gfile
16+
else:
17+
import tensorflow.io.gfile as gfile
18+
1719
try:
1820
import common_io
1921
except Exception:

easy_rec/python/input/kafka_input.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
import six
88
import tensorflow as tf
9+
10+
from easy_rec.python.input.input import Input
11+
from easy_rec.python.input.kafka_dataset import KafkaDataset
12+
from easy_rec.python.utils.config_util import parse_time
13+
914
if tf.__version__.startswith('1.'):
1015
from tensorflow.python.platform import gfile
1116
else:
1217
import tensorflow.io.gfile as gfile
13-
from easy_rec.python.input.input import Input
14-
from easy_rec.python.input.kafka_dataset import KafkaDataset
15-
from easy_rec.python.utils.config_util import parse_time
1618

1719
try:
1820
from kafka import KafkaConsumer, TopicPartition

easy_rec/python/input/odps_rtp_input_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tensorflow as tf
77

88
from easy_rec.python.input.odps_rtp_input import OdpsRTPInput
9+
910
if tf.__version__.startswith('1.'):
1011
from tensorflow.python.platform import gfile
1112
else:

easy_rec/python/input/parquet_input.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
import time
77

88
import tensorflow as tf
9-
if tf.__version__.startswith('1.'):
10-
from tensorflow.python.platform import gfile
11-
else:
12-
import tensorflow.io.gfile as gfile
139
from tensorflow.python.ops import array_ops
1410

1511
from easy_rec.python.compat import queues
1612
from easy_rec.python.input import load_parquet
1713
from easy_rec.python.input.input import Input
1814

15+
if tf.__version__.startswith('1.'):
16+
from tensorflow.python.platform import gfile
17+
else:
18+
import tensorflow.io.gfile as gfile
19+
1920

2021
class ParquetInput(Input):
2122

easy_rec/python/main.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313

1414
import six
1515
import tensorflow as tf
16-
if tf.__version__.startswith('1.'):
17-
from tensorflow.python.platform import gfile
18-
else:
19-
import tensorflow.io.gfile as gfile
2016
from tensorflow.core.protobuf import saved_model_pb2
2117

2218
import easy_rec
@@ -68,18 +64,6 @@
6864
BestExporter = exporter.BestExporter
6965

7066

71-
def is_directory(path):
72-
if tf.__version__.startswith('1.'):
73-
return gfile.IsDirectory(path)
74-
if not gfile.exists(path):
75-
return False
76-
try:
77-
gfile.listdir(path)
78-
return True
79-
except:
80-
return False
81-
82-
8367
def _get_input_fn(data_config,
8468
feature_configs,
8569
data_path=None,
@@ -255,27 +239,27 @@ def _metric_cmp_fn(best_eval_result, current_eval_result):
255239

256240
def _check_model_dir(model_dir, continue_train):
257241
if not continue_train:
258-
if not is_directory(model_dir):
259-
gfile.MakeDirs(model_dir)
242+
if not tf.gfile.IsDirectory(model_dir):
243+
tf.gfile.MakeDirs(model_dir)
260244
else:
261-
assert len(gfile.Glob(model_dir + '/model.ckpt-*.meta')) == 0, \
245+
assert len(tf.gfile.Glob(model_dir + '/model.ckpt-*.meta')) == 0, \
262246
'model_dir[=%s] already exists and not empty(if you ' \
263247
'want to continue train on current model_dir please ' \
264248
'delete dir %s or specify --continue_train[internal use only])' % (
265249
model_dir, model_dir)
266250
else:
267-
if not is_directory(model_dir):
251+
if not tf.gfile.IsDirectory(model_dir):
268252
logging.info('%s does not exists, create it automatically' % model_dir)
269-
gfile.MakeDirs(model_dir)
253+
tf.gfile.MakeDirs(model_dir)
270254

271255

272256
def _get_ckpt_path(pipeline_config, checkpoint_path):
273257
if checkpoint_path != '' and checkpoint_path is not None:
274-
if is_directory(checkpoint_path):
258+
if tf.gfile.IsDirectory(checkpoint_path):
275259
ckpt_path = estimator_utils.latest_checkpoint(checkpoint_path)
276260
else:
277261
ckpt_path = checkpoint_path
278-
elif is_directory(pipeline_config.model_dir):
262+
elif tf.gfile.IsDirectory(pipeline_config.model_dir):
279263
ckpt_path = estimator_utils.latest_checkpoint(pipeline_config.model_dir)
280264
logging.info('checkpoint_path is not specified, '
281265
'will use latest checkpoint %s from %s' %
@@ -299,7 +283,8 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):
299283
Returns:
300284
None, the model will be saved into pipeline_config.model_dir
301285
"""
302-
assert gfile.Exists(pipeline_config_path), 'pipeline_config_path not exists'
286+
assert tf.gfile.Exists(
287+
pipeline_config_path), 'pipeline_config_path not exists'
303288
pipeline_config = config_util.get_configs_from_pipeline_file(
304289
pipeline_config_path)
305290

@@ -338,7 +323,7 @@ def _train_and_evaluate_impl(pipeline_config,
338323
if estimator_utils.is_chief():
339324
_check_model_dir(pipeline_config.model_dir, continue_train)
340325
config_util.save_pipeline_config(pipeline_config, pipeline_config.model_dir)
341-
with gfile.GFile(version_file, 'w') as f:
326+
with tf.gfile.GFile(version_file, 'w') as f:
342327
f.write(easy_rec.__version__ + '\n')
343328

344329
train_steps = None
@@ -524,7 +509,7 @@ def evaluate(pipeline_config,
524509
model_dir = pipeline_config.model_dir
525510
eval_result_file = os.path.join(model_dir, eval_result_filename)
526511
logging.info('save eval result to file %s' % eval_result_file)
527-
with gfile.GFile(eval_result_file, 'w') as ofile:
512+
with tf.gfile.GFile(eval_result_file, 'w') as ofile:
528513
result_to_write = {}
529514
for key in sorted(eval_result):
530515
# skip logging binary data
@@ -577,10 +562,11 @@ def distribute_evaluate(pipeline_config,
577562
return eval_result
578563
model_dir = get_model_dir_path(pipeline_config)
579564
eval_tmp_results_dir = os.path.join(model_dir, 'distribute_eval_tmp_results')
580-
if not is_directory(eval_tmp_results_dir):
565+
if not tf.gfile.IsDirectory(eval_tmp_results_dir):
581566
logging.info('create eval tmp results dir {}'.format(eval_tmp_results_dir))
582-
gfile.MakeDirs(eval_tmp_results_dir)
583-
assert is_directory(eval_tmp_results_dir), 'tmp results dir not create success.'
567+
tf.gfile.MakeDirs(eval_tmp_results_dir)
568+
assert tf.gfile.IsDirectory(
569+
eval_tmp_results_dir), 'tmp results dir not create success.'
584570
os.environ['eval_tmp_results_dir'] = eval_tmp_results_dir
585571

586572
server_target = None
@@ -693,7 +679,7 @@ def distribute_evaluate(pipeline_config,
693679
if cur_job_name == 'master':
694680
print('eval_result = ', eval_result)
695681
logging.info('eval_result = {0}'.format(eval_result))
696-
with gfile.GFile(eval_result_file, 'w') as ofile:
682+
with tf.gfile.GFile(eval_result_file, 'w') as ofile:
697683
result_to_write = {'eval_method': 'distribute'}
698684
for key in sorted(eval_result):
699685
# skip logging binary data
@@ -780,8 +766,8 @@ def export(export_dir,
780766
AssertionError, if:
781767
* pipeline_config_path does not exist
782768
"""
783-
if not gfile.Exists(export_dir):
784-
gfile.MakeDirs(export_dir)
769+
if not tf.gfile.Exists(export_dir):
770+
tf.gfile.MakeDirs(export_dir)
785771

786772
pipeline_config = config_util.get_configs_from_pipeline_file(pipeline_config)
787773
if pipeline_config.fg_json_path:
@@ -844,10 +830,10 @@ def export(export_dir,
844830
]
845831
export_ts = export_ts[-1]
846832
saved_pb_path = os.path.join(final_export_dir, 'saved_model.pb')
847-
with gfile.GFile(saved_pb_path, 'rb') as fin:
833+
with tf.gfile.GFile(saved_pb_path, 'rb') as fin:
848834
saved_model.ParseFromString(fin.read())
849835
saved_model.meta_graphs[0].meta_info_def.meta_graph_version = export_ts
850-
with gfile.GFile(saved_pb_path, 'wb') as fout:
836+
with tf.gfile.GFile(saved_pb_path, 'wb') as fout:
851837
fout.write(saved_model.SerializeToString())
852838

853839
logging.info('model has been exported to %s successfully' % final_export_dir)

easy_rec/python/test/train_eval_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
import six
1313
import tensorflow as tf
1414
from distutils.version import LooseVersion
15-
if tf.__version__.startswith('1.'):
16-
from tensorflow.python.platform import gfile
17-
else:
18-
import tensorflow.io.gfile as gfile
15+
1916
from easy_rec.python.main import predict
2017
from easy_rec.python.utils import config_util
2118
from easy_rec.python.utils import constant
2219
from easy_rec.python.utils import estimator_utils
2320
from easy_rec.python.utils import test_utils
2421

22+
if tf.__version__.startswith('1.'):
23+
from tensorflow.python.platform import gfile
24+
else:
25+
import tensorflow.io.gfile as gfile
26+
2527
try:
2628
import graphlearn as gl
2729
except Exception:

0 commit comments

Comments
 (0)