Skip to content

Commit 7199a13

Browse files
Remove tf.compat.v1 imports in model/
Some minor API changes, but not many. We do have to use some carve outs where we manually write out tf.compat.v1 before some API usages, but I expect these to go away as we hack more on the code base and fix more things up. Reviewers: orodley, ondrasej, virajbshah Reviewed By: ondrasej Pull Request: #332
1 parent 5a01de2 commit 7199a13

File tree

9 files changed

+40
-34
lines changed

9 files changed

+40
-34
lines changed

gematria/model/python/inference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from gematria.model.python import model_base
2222
from gematria.model.python import training
2323
from gematria.proto import throughput_pb2
24-
import tensorflow.compat.v1 as tf
2524

2625

2726
def _get_num_instructions_in_block_with_throughput_proto(

gematria/model/python/loss_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections.abc import Sequence
1717

1818
from gematria.model.python import options
19-
import tensorflow.compat.v1 as tf
19+
import tensorflow as tf
2020
import tensorflow_probability as tfp
2121

2222

gematria/model/python/main_function.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ def main(_):
4747
from gematria.proto import throughput_pb2
4848
from gematria.utils.python import timer
4949
import numpy as np
50-
import tensorflow.compat.v1 as tf
51-
import tensorflow as tf2
50+
import tensorflow as tf
5251

5352
_ACTION = flags.DEFINE_enum_class(
5453
'gematria_action',
@@ -685,7 +684,7 @@ def _restore_model_from_checkpoint(
685684
checkpoint_file: str, model: model_base.ModelBase
686685
) -> None:
687686
"""Restores a model from a checkpoint."""
688-
checkpoint = tf2.train.Checkpoint(model)
687+
checkpoint = tf.train.Checkpoint(model)
689688
checkpoint.restore(checkpoint_file)
690689

691690

@@ -720,10 +719,10 @@ def run_gematria_model_from_command_line_flags(
720719
"""
721720
np.set_printoptions(edgeitems=_GEMATRIA_NUMPY_PRINT_EDGEITEMS.value)
722721
if _GEMATRIA_RANDOM_SEED.value >= 0:
723-
tf.random.set_random_seed(_GEMATRIA_RANDOM_SEED.value)
722+
tf.random.set_seed(_GEMATRIA_RANDOM_SEED.value)
724723
random.seed(_GEMATRIA_RANDOM_SEED.value)
725724
is_chief = _GEMATRIA_TRAINING_TASK.value == 0
726-
dev = tf.train.replica_device_setter(
725+
dev = tf.compat.v1.train.replica_device_setter(
727726
ps_tasks=_GEMATRIA_TRAINING_PS_TASKS.value
728727
)
729728
with tf.device(dev):
@@ -812,8 +811,8 @@ def run_gematria_model_from_command_line_flags(
812811
== io_options.ThroughputSelection.RANDOM
813812
)
814813

815-
checkpoint = tf2.train.Checkpoint(model)
816-
checkpoint_manager = tf2.train.CheckpointManager(
814+
checkpoint = tf.train.Checkpoint(model)
815+
checkpoint_manager = tf.train.CheckpointManager(
817816
checkpoint,
818817
_CHECKPOINT_DIR.value,
819818
_GEMATRIA_CHECKPOINT_MAX_TO_KEEP.value,
@@ -822,11 +821,11 @@ def run_gematria_model_from_command_line_flags(
822821
def checkpoint_model():
823822
checkpoint_manager.save()
824823

825-
train_summary_writer = tf2.summary.create_file_writer(
824+
train_summary_writer = tf.summary.create_file_writer(
826825
_GEMATRIA_SUMMARY_DIR.value
827826
)
828827

829-
with train_summary_writer.as_default(), tf2.summary.record_if(
828+
with train_summary_writer.as_default(), tf.summary.record_if(
830829
lambda: tf.math.equal(
831830
model.global_step % _GEMATRIA_SAVE_SUMMARIES_EPOCHS, 0
832831
)

gematria/model/python/model_base.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from gematria.utils.python import timer
4444
import numpy as np
4545
import scipy.stats
46-
import tensorflow.compat.v1 as tf
46+
import tensorflow as tf
4747
import tf_slim.evaluation
4848

4949
# The type used for TensorFlow feed_dict objects. The type we use here is
@@ -75,7 +75,7 @@ class AddBasicBlockError(Exception):
7575
"""The exception raised when adding a block to batch fails."""
7676

7777

78-
class SaveBestCheckpoint(tf.train.SessionRunHook):
78+
class SaveBestCheckpoint(tf.compat.v1.train.SessionRunHook):
7979
"""A run hook that saves top N models based on error values."""
8080

8181
def __init__(
@@ -100,8 +100,8 @@ def __init__(
100100
self._saver = tf.train.Saver(max_to_keep=max_to_keep, name='relative_mae')
101101
self._last_eval = math.inf
102102

103-
def before_run(self, run_context: ...) -> tf.train.SessionRunArgs:
104-
return tf.train.SessionRunArgs(
103+
def before_run(self, run_context: ...) -> tf.compat.v1.train.SessionRunArgs:
104+
return tf.compat.v1.train.SessionRunArgs(
105105
{'loss': self._error_tensor, 'global_step': self._global_step}
106106
)
107107

@@ -328,7 +328,7 @@ def __init__(
328328
)
329329
self._trained_variable_groups = tuple(trained_variable_groups or ())
330330

331-
self._global_step = tf.train.get_or_create_global_step()
331+
self._global_step = tf.compat.v1.train.get_or_create_global_step()
332332

333333
self._decayed_learning_rate = None
334334
self._loss: Optional[loss_utils.LossComputation] = None
@@ -604,61 +604,67 @@ def _create_optimizer(self) -> None:
604604
'must be great than zero.'
605605
)
606606
if self._learning_rate_schedule == options.LearningRateScheduleType.COSINE:
607-
self._decayed_learning_rate = tf.train.cosine_decay(**decay_args)
607+
self._decayed_learning_rate = tf.compat.v1.train.cosine_decay(
608+
**decay_args
609+
)
608610
elif (
609611
self._learning_rate_schedule
610612
== options.LearningRateScheduleType.EXPONENTIAL
611613
):
612-
self._decayed_learning_rate = tf.train.exponential_decay(
614+
self._decayed_learning_rate = tf.compat.v1.train.exponential_decay(
613615
**decay_args, **decay_rate_arg
614616
)
615617
elif (
616618
self._learning_rate_schedule
617619
== options.LearningRateScheduleType.INVERSE_TIME
618620
):
619-
self._decayed_learning_rate = tf.train.inverse_time_decay(
621+
self._decayed_learning_rate = tf.compat.v1.train.inverse_time_decay(
620622
**decay_args, **decay_rate_arg
621623
)
622624
elif (
623625
self._learning_rate_schedule
624626
== options.LearningRateScheduleType.LINEAR_COSINE
625627
):
626-
self._decayed_learning_rate = tf.train.linear_cosine_decay(**decay_args)
628+
self._decayed_learning_rate = tf.compat.v1.train.linear_cosine_decay(
629+
**decay_args
630+
)
627631
elif (
628632
self._learning_rate_schedule
629633
== options.LearningRateScheduleType.NATURAL_EXP
630634
):
631-
self._decayed_learning_rate = tf.train.natural_exp_decay(
635+
self._decayed_learning_rate = tf.compat.v1.train.natural_exp_decay(
632636
**decay_args, **decay_rate_arg
633637
)
634638
elif (
635639
self._learning_rate_schedule
636640
== options.LearningRateScheduleType.NOISY_LINEAR_COSINE
637641
):
638-
self._decayed_learning_rate = tf.train.noisy_linear_cosine_decay(
639-
**decay_args
642+
self._decayed_learning_rate = (
643+
tf.compat.v1.train.noisy_linear_cosine_decay(**decay_args)
640644
)
641645
elif (
642646
self._learning_rate_schedule
643647
== options.LearningRateScheduleType.POLYNOMIAL
644648
):
645-
self._decayed_learning_rate = tf.train.polynomial_decay(**decay_args)
649+
self._decayed_learning_rate = tf.compat.v1.train.polynomial_decay(
650+
**decay_args
651+
)
646652
else:
647653
assert (
648654
self._learning_rate_schedule == options.LearningRateScheduleType.NONE
649655
)
650656
self._decayed_learning_rate = self._learning_rate
651657

652658
if self._optimizer_type == options.OptimizerType.ADAM:
653-
self._optimizer = tf.train.AdamOptimizer(
659+
self._optimizer = tf.compat.v1.train.AdamOptimizer(
654660
learning_rate=self._decayed_learning_rate
655661
)
656662
elif self._optimizer_type == options.OptimizerType.SGD:
657-
self._optimizer = tf.train.GradientDescentOptimizer(
663+
self._optimizer = tf.compat.v1.train.GradientDescentOptimizer(
658664
learning_rate=self._decayed_learning_rate
659665
)
660666
elif self._optimizer_type == options.OptimizerType.RMSPROP:
661-
self._optimizer = tf.train.RMSPropOptimizer(
667+
self._optimizer = tf.compat.v1.train.RMSPropOptimizer(
662668
learning_rate=self._decayed_learning_rate
663669
)
664670
else:
@@ -681,7 +687,7 @@ def _create_optimizer(self) -> None:
681687

682688
def get_monitored_training_session_hooks(
683689
self,
684-
) -> Sequence[tf.train.SessionRunHook]:
690+
) -> Sequence[tf.compat.v1.train.SessionRunHook]:
685691
"""Returns hooks for a MonitoredTrainingSession required by the model."""
686692
hooks = []
687693
if isinstance(self._optimizer, tf.train.SyncReplicasOptimizer):
@@ -1063,7 +1069,9 @@ def run_continuous_evaluation(
10631069
tf_master: str = '',
10641070
eval_interval_seconds: int = 60,
10651071
max_num_evaluations: Optional[int] = None,
1066-
session_hooks: Optional[Sequence[tf.train.SessionRunHook]] = None,
1072+
session_hooks: Optional[
1073+
Sequence[tf.compat.v1.train.SessionRunHook]
1074+
] = None,
10671075
max_blocks_in_batch: Optional[int] = None,
10681076
max_instructions_in_batch: Optional[int] = None,
10691077
) -> None:

gematria/model/python/model_blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Optional
1818

1919
import sonnet as snt
20-
import tensorflow.compat.v1 as tf
20+
import tensorflow as tf
2121
import tf_keras
2222

2323

gematria/model/python/token_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from gematria.model.python import model_base
2626
from gematria.model.python import oov_token_behavior
2727
import numpy as np
28-
import tensorflow.compat.v1 as tf
28+
import tensorflow as tf
2929

3030
_OutOfVocabularyTokenBehavior = oov_token_behavior.OutOfVocabularyTokenBehavior
3131

gematria/model/python/token_model_flags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from absl import flags
2323
from gematria.model.python import oov_token_behavior
2424
from gematria.utils.python import flag_utils
25-
import tensorflow.compat.v1 as tf
25+
import tensorflow as tf
2626

2727
_TOKEN_FILE = flags.DEFINE_string(
2828
'gematria_tokens_file',

gematria/model/python/token_model_flags_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from absl import flags
1616
from absl.testing import flagsaver
1717
from gematria.model.python import token_model_flags
18-
import tensorflow.compat.v1 as tf
18+
import tensorflow as tf
1919

2020
FLAGS = flags.FLAGS
2121

gematria/model/python/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from gematria.basic_block.python import basic_block
2323
from gematria.basic_block.python import throughput
2424
import numpy as np
25-
import tensorflow.compat.v1 as tf
25+
import tensorflow as tf
2626

2727

2828
@dataclasses.dataclass(frozen=True)

0 commit comments

Comments
 (0)