Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/quick_start/local_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip install tensorflow_probability==0.5.0
常见版本对应关系:

| TensorFlow版本 | TensorFlowProbability版本 |
|--------------|-------------------------|
| ------------ | ----------------------- |
| 1.12 | 0.5.0 |
| 1.15 | 0.8.0 |
| 2.5.0 | 0.13.0 |
Expand Down
7 changes: 2 additions & 5 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
Expand All @@ -18,7 +19,6 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
from tensorflow.python.saved_model.path_helpers import get_variables_path
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
else:
from tensorflow.python.saved_model.utils_impl import get_variables_path

Expand Down Expand Up @@ -207,10 +207,7 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
graph = ops.get_default_graph()
importer.import_graph_def(inference_graph, name='')
for name in variables_to_keep:
if tf.__version__ >= '2.0':
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
else:
variable = graph.get_tensor_by_name(name)
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
saver = tf_saver.Saver()
saver.restore(sess, get_variables_path(model_dir))
Expand Down