Skip to content

Commit 8d8e16a

Browse files
authored
[Embedding] Fix set initialized flag too early in restore subgraph. (#920)
Signed-off-by: lixy9474 <[email protected]>
1 parent f09e5ec commit 8d8e16a

File tree

9 files changed

+158
-26
lines changed

9 files changed

+158
-26
lines changed

tensorflow/core/framework/embedding/config.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,7 @@ enum ValuePosition {
5656
IN_DRAM = 0;
5757
NOT_IN_DRAM = 1;
5858
}
59+
60+
enum IsSetInitialized {
61+
NOT_SET_INITAILIZED = 0;
62+
}

tensorflow/core/framework/embedding/multi_tier_storage.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ class MultiTierStorage : public Storage<K, V> {
8181
}
8282

8383
void InitCache(embedding::CacheStrategy cache_strategy) override {
84-
cache_ = CacheFactory::Create<K>(cache_strategy, name_);
85-
eviction_manager_ = EvictionManagerCreator::Create<K, V>();
86-
eviction_manager_->AddStorage(this);
87-
cache_thread_pool_ = CacheThreadPoolCreator::Create();
84+
if (cache_ == nullptr) {
85+
cache_ = CacheFactory::Create<K>(cache_strategy, name_);
86+
eviction_manager_ = EvictionManagerCreator::Create<K, V>();
87+
eviction_manager_->AddStorage(this);
88+
cache_thread_pool_ = CacheThreadPoolCreator::Create();
89+
}
8890
}
8991

9092
Status BatchCommit(const std::vector<K>& keys,

tensorflow/core/framework/variable.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ message VariableDef {
7474

7575
// EmbeddingVariable
7676
bool is_embedding_var = 91;
77+
78+
string initialize_op_for_restore = 92;
7779
}
7880

7981
message SaveSliceInfoDef {

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ limitations under the License.
4343

4444
namespace tensorflow {
4545

46-
namespace {
47-
const int64 kEmbeddingVarUseDB = -214;
48-
const int64 kInitializableEmbeddingVarUseDB = -215;
49-
}
50-
5146
Status MoveMatchingFiles(
5247
Env* env,
5348
const tstring& pattern,
@@ -207,6 +202,15 @@ class InitializeKvVariableOp : public OpKernel {
207202
(embedding_var_type ==
208203
embedding::EmbeddingVariableType::IMMUTABLE);
209204

205+
// initial_num_buckets is otherwise unused, so it is repurposed here: a value of NOT_SET_INITAILIZED (0) leaves is_set_initialized_ false; any other value sets it to true.
206+
int64 initial_num_buckets = 0;
207+
OP_REQUIRES_OK(c, c->GetAttr("initial_num_buckets", &initial_num_buckets));
208+
is_set_initialized_ = true;
209+
if (initial_num_buckets ==
210+
embedding::IsSetInitialized::NOT_SET_INITAILIZED) {
211+
is_set_initialized_ = false;
212+
}
213+
210214
int64 storage_type = 0;
211215
OP_REQUIRES_OK(c, c->GetAttr("storage_type", &storage_type));
212216
storage_type_ = static_cast<embedding::StorageType>(storage_type);
@@ -263,15 +267,10 @@ class InitializeKvVariableOp : public OpKernel {
263267
" should be DRAM when layout is 'compact'."));
264268
}
265269

266-
if (steps_to_live_ == kEmbeddingVarUseDB ||
267-
steps_to_live_ == kInitializableEmbeddingVarUseDB) {
268-
LOG(INFO) << "hashmap use db";
269-
//use_db_ = true;
270-
} else {
271-
OP_REQUIRES(c, steps_to_live_ >= 0,
272-
errors::InvalidArgument(
270+
OP_REQUIRES(c, steps_to_live_ >= 0,
271+
errors::InvalidArgument(
273272
"steps_to_live must >= 0, ", std::to_string(steps_to_live_)));
274-
}
273+
275274
OP_REQUIRES_OK(c, c->GetAttr("ht_type", &ht_type_));
276275
if (embedding::StorageType::LEVELDB == storage_type_) {
277276
ht_type_ = "leveldb_kv";
@@ -406,7 +405,7 @@ class InitializeKvVariableOp : public OpKernel {
406405
core::ScopedUnref unref_me(primary_variable);
407406
}
408407
core::ScopedUnref unref_me(ev);
409-
if (steps_to_live_ != kEmbeddingVarUseDB) {
408+
if (is_set_initialized_) {
410409
ev->SetInitialized();
411410
}
412411
}
@@ -436,6 +435,7 @@ class InitializeKvVariableOp : public OpKernel {
436435
bool record_freq_;
437436
bool record_version_;
438437
bool is_inference_;
438+
bool is_set_initialized_;
439439
};
440440

441441
#define REGISTER_KERNELS(ktype, vtype) \

tensorflow/python/ops/embedding_variable_ops_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2751,5 +2751,70 @@ def testCPUFbjOptWithBloomFilter(self):
27512751
self.assertNotEqual(val, 1.0)
27522752
del os.environ["TF_EMBEDDING_FBJ_OPT"]
27532753

2754+
def testSetInitializedWithoutRestore(self):
2755+
print("testSetInitializedWithoutRestore")
2756+
with ops.device("/cpu:0"):
2757+
var = variable_scope.get_embedding_variable("var_1",
2758+
embedding_dim = 3)
2759+
emb = embedding_ops.embedding_lookup(var, math_ops.cast([1], dtypes.int64))
2760+
fun = math_ops.multiply(emb, 2.0, name='multiply')
2761+
loss = math_ops.reduce_sum(fun, name='reduce_sum')
2762+
gs = training_util.get_or_create_global_step()
2763+
opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
2764+
g_v = opt.compute_gradients(loss)
2765+
train_op = opt.apply_gradients(g_v)
2766+
init = variables.global_variables_initializer()
2767+
saver = saver_module.Saver()
2768+
with self.test_session() as sess:
2769+
result = sess.run(var._is_initialized_op)
2770+
self.assertEqual(False, result)
2771+
sess.run([init])
2772+
result = sess.run(var._is_initialized_op)
2773+
self.assertEqual(True, result)
2774+
2775+
def testSetInitializedWithRestore(self):
2776+
print("testSetInitializedWitRestore")
2777+
checkpoint_directory = self.get_temp_dir()
2778+
ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
2779+
with ops.Graph().as_default() as g, ops.device('/cpu:0'):
2780+
var = variable_scope.get_embedding_variable("var_1",
2781+
embedding_dim = 3)
2782+
emb = embedding_ops.embedding_lookup(var, math_ops.cast([1,2 ,3], dtypes.int64))
2783+
fun = math_ops.multiply(emb, 2.0, name='multiply')
2784+
loss = math_ops.reduce_sum(fun, name='reduce_sum')
2785+
gs = training_util.get_or_create_global_step()
2786+
opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
2787+
g_v = opt.compute_gradients(loss)
2788+
train_op = opt.apply_gradients(g_v)
2789+
saver = saver_module.Saver()
2790+
init = variables.global_variables_initializer()
2791+
with self.test_session(graph=g) as sess:
2792+
sess.run([init])
2793+
sess.run(train_op)
2794+
saver.save(sess, ckpt_path)
2795+
2796+
with ops.Graph().as_default() as g, ops.device('/cpu:0'):
2797+
var = variable_scope.get_embedding_variable("var_1",
2798+
embedding_dim = 3)
2799+
emb = embedding_ops.embedding_lookup(var, math_ops.cast([1, 2, 3], dtypes.int64))
2800+
fun = math_ops.multiply(emb, 2.0, name='multiply')
2801+
loss = math_ops.reduce_sum(fun, name='reduce_sum')
2802+
gs = training_util.get_or_create_global_step()
2803+
opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
2804+
g_v = opt.compute_gradients(loss)
2805+
train_op = opt.apply_gradients(g_v)
2806+
saver = saver_module.Saver()
2807+
init = variables.global_variables_initializer()
2808+
with self.test_session(graph=g) as sess:
2809+
result = sess.run(var._is_initialized_op)
2810+
self.assertEqual(False, result)
2811+
sess.run([var._initializer_for_restore])
2812+
result = sess.run(var._is_initialized_op)
2813+
self.assertEqual(False, result)
2814+
2815+
saver.restore(sess, ckpt_path)
2816+
result = sess.run(var._is_initialized_op)
2817+
self.assertEqual(True, result)
2818+
27542819
if __name__ == "__main__":
27552820
googletest.main()

tensorflow/python/ops/kv_variable_ops.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,8 @@ def is_multi_tier(storage_type):
434434
with ops.control_dependencies(set_attr_ops + [self._init_op]):
435435
self._initializer_op = control_flow_ops.no_op()
436436

437+
self.create_init_op_for_restore(name, initial_value, invalid_key, rank)
438+
437439
self._graph_element = self._handle
438440
self._cached_value = None
439441
if not context.executing_eagerly():
@@ -444,6 +446,49 @@ def is_multi_tier(storage_type):
444446
def export(self):
445447
return gen_kv_variable_ops.kv_resource_export(self._handle, Tkeys=self._invalid_key_type)
446448

449+
450+
def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):
451+
with ops.control_dependencies(None if self._is_primary else [self._primary._init_op_for_restore]):
452+
self._initializer_for_restore = gen_kv_variable_ops.initialize_kv_variable_v2_op(
453+
self._handle,
454+
self._primary._handle,
455+
variables._try_guard_against_uninitialized_dependencies(name, initial_value),
456+
ops.convert_to_tensor(invalid_key),
457+
initial_num_buckets=config_pb2.IsSetInitialized.NOT_SET_INITAILIZED,
458+
slot_num = self._slot_num,
459+
shape=initial_value.get_shape()[rank:],
460+
steps_to_live=self._steps_to_live,
461+
emb_index=self._emb_index, block_num=self.block_num,
462+
slot_index=self._slot_index,
463+
ht_type=self._ht_type,
464+
ht_partition_num=self._ht_partition_num,
465+
filter_freq = self._filter_freq,
466+
l2_weight_threshold = self._l2_weight_threshold,
467+
max_element_size = self._max_element_size,
468+
false_positive_probability = self._false_positive_probability,
469+
counter_type = self._counter_type,
470+
max_freq = 99999,
471+
layout = self._layout,
472+
storage_type = self._storage_type,
473+
storage_path = self._storage_path,
474+
storage_size = self._storage_size,
475+
default_value_dim = self._default_value_dim,
476+
default_value_no_permission = self._default_value_no_permission,
477+
record_freq = self._record_freq,
478+
record_version = self._record_version,
479+
embedding_variable_type=config_pb2.EmbeddingVariableType.IMMUTABLE)
480+
set_attr_ops = []
481+
if self._is_primary and self._is_multi_tier:
482+
with ops.control_dependencies([self._initializer_for_restore]):
483+
set_cache_op = gen_kv_variable_ops.kv_resource_init_cache_strategy_op(
484+
self._handle,
485+
cache_strategy=self._storage_cache_strategy,
486+
Tkeys=self._invalid_key_type,
487+
dtype=self._dtype)
488+
set_attr_ops.append(set_cache_op)
489+
with ops.control_dependencies(set_attr_ops + [self._initializer_for_restore]):
490+
self._init_op_for_restore = control_flow_ops.no_op()
491+
447492
def need_counts(self):
448493
return (self._record_freq or (self._filter_freq > 0) or self._is_multi_tier)
449494
@property
@@ -482,6 +527,11 @@ def _init_from_proto(self, variable_def, import_scope=None):
482527
cache_op = op
483528
elif self._initializer_op.type == "InitializeKvVariableOp":
484529
init_op = self._initializer_op
530+
531+
self._init_op_for_restore = g.as_graph_element(
532+
ops.prepend_name_scope(
533+
variable_def.initialize_op_for_restore,
534+
import_scope=import_scope))
485535
self._trainable = getattr(variable_def, "trainable", True)
486536
if variable_def.snapshot_name:
487537
self._cached_value = g.as_graph_element(
@@ -842,6 +892,8 @@ def to_proto(self, export_scope=None):
842892
if self._save_slice_info:
843893
var_def.save_slice_info_def.MergeFrom(
844894
self._save_slice_info.to_proto(export_scope=export_scope))
895+
var_def.initialize_op_for_restore = ops.strip_name_scope(
896+
self._init_op_for_restore.name, export_scope)
845897
return var_def
846898
else:
847899
return None

tensorflow/python/training/optimizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,7 @@ def _get_processor(v):
243243
if v.op.type == "KvVarHandleOp":
244244
from tensorflow.core.framework import attr_value_pb2
245245
from tensorflow.core.framework.embedding import config_pb2
246-
v._init_op._set_attr("embedding_variable_type",
247-
attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
246+
slot_creator._set_init_op_embedding_type_attr(v, config_pb2.EmbeddingVariableType.MUTABLE)
248247
return _DenseResourceVariableProcessor(v)
249248
if isinstance(v, variables.Variable):
250249
return _RefVariableProcessor(v)

tensorflow/python/training/saving/saveable_object_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def restore(self, restored_tensors, unused_restored_shapes):
195195
if self.var._init_data_source is not None:
196196
return self.var.recover_from_init_data_source(self.var._init_data_source, self.partition_id, self.partition_num)
197197
else:
198-
with ops.control_dependencies([self.var._initializer_op]):
198+
with ops.control_dependencies([self.var._init_op_for_restore]):
199199
rank = self.op.initial_value.get_shape().rank - 1
200200
restore_op = gen_kv_variable_ops.kv_resource_import_v3(
201201
restored_tensors[0],

tensorflow/python/training/slot_creator.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
9494
validate_shape=validate_shape,
9595
steps_to_live=primary._steps_to_live,
9696
ht_partition_num=primary._ht_partition_num)
97-
slot._init_op._set_attr("embedding_variable_type",
98-
attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
97+
_set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE)
9998
else:
10099
filter_strategy = None
101100
if primary._filter_freq != 0:
@@ -107,7 +106,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
107106
else:
108107
filter_strategy = variables.CounterFilter(filter_freq=primary._filter_freq)
109108
if slot_config.slot_type is config_pb2.SlotType.EMBEDDING_VARIABLE:
110-
primary._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_config.slot_num))
109+
_set_init_op_slot_num_attr(primary, slot_config.slot_num)
111110
primary._slot_num = slot_config.slot_num
112111
emb_index = primary._emb_index
113112
if primary.block_num > 1:
@@ -132,8 +131,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype, slot_con
132131
l2_weight_threshold=primary._l2_weight_threshold,
133132
filter_strategy=filter_strategy)
134133
)
135-
slot._init_op._set_attr("embedding_variable_type",
136-
attr_value_pb2.AttrValue(i=config_pb2.EmbeddingVariableType.MUTABLE))
134+
_set_init_op_embedding_type_attr(slot, config_pb2.EmbeddingVariableType.MUTABLE)
137135
else:
138136
slot = variable_scope.get_variable(
139137
scope,
@@ -300,3 +298,13 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True, slo
300298
return create_slot(primary, val, name,
301299
colocate_with_primary=colocate_with_primary,
302300
slot_config=slot_config)
301+
302+
def _set_init_op_embedding_type_attr(var, embedding_type):
303+
var._init_op._set_attr("embedding_variable_type",
304+
attr_value_pb2.AttrValue(i=embedding_type))
305+
var._initializer_for_restore._set_attr("embedding_variable_type",
306+
attr_value_pb2.AttrValue(i=embedding_type))
307+
308+
def _set_init_op_slot_num_attr(var, slot_num):
309+
var._init_op._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num))
310+
var._initializer_for_restore._set_attr("slot_num", attr_value_pb2.AttrValue(i=slot_num))

0 commit comments

Comments
 (0)