@@ -434,6 +434,8 @@ def is_multi_tier(storage_type):
434434 with ops .control_dependencies (set_attr_ops + [self ._init_op ]):
435435 self ._initializer_op = control_flow_ops .no_op ()
436436
437+ self .create_init_op_for_restore (name , initial_value , invalid_key , rank )
438+
437439 self ._graph_element = self ._handle
438440 self ._cached_value = None
439441 if not context .executing_eagerly ():
@@ -444,6 +446,49 @@ def is_multi_tier(storage_type):
444446 def export (self ):
445447 return gen_kv_variable_ops .kv_resource_export (self ._handle , Tkeys = self ._invalid_key_type )
446448
449+
450+ def create_init_op_for_restore (self , name , initial_value , invalid_key , rank ):
451+ with ops .control_dependencies (None if self ._is_primary else [self ._primary ._init_op_for_restore ]):
452+ self ._initializer_for_restore = gen_kv_variable_ops .initialize_kv_variable_v2_op (
453+ self ._handle ,
454+ self ._primary ._handle ,
455+ variables ._try_guard_against_uninitialized_dependencies (name , initial_value ),
456+ ops .convert_to_tensor (invalid_key ),
457+ initial_num_buckets = config_pb2 .IsSetInitialized .NOT_SET_INITAILIZED ,
458+ slot_num = self ._slot_num ,
459+ shape = initial_value .get_shape ()[rank :],
460+ steps_to_live = self ._steps_to_live ,
461+ emb_index = self ._emb_index , block_num = self .block_num ,
462+ slot_index = self ._slot_index ,
463+ ht_type = self ._ht_type ,
464+ ht_partition_num = self ._ht_partition_num ,
465+ filter_freq = self ._filter_freq ,
466+ l2_weight_threshold = self ._l2_weight_threshold ,
467+ max_element_size = self ._max_element_size ,
468+ false_positive_probability = self ._false_positive_probability ,
469+ counter_type = self ._counter_type ,
470+ max_freq = 99999 ,
471+ layout = self ._layout ,
472+ storage_type = self ._storage_type ,
473+ storage_path = self ._storage_path ,
474+ storage_size = self ._storage_size ,
475+ default_value_dim = self ._default_value_dim ,
476+ default_value_no_permission = self ._default_value_no_permission ,
477+ record_freq = self ._record_freq ,
478+ record_version = self ._record_version ,
479+ embedding_variable_type = config_pb2 .EmbeddingVariableType .IMMUTABLE )
480+ set_attr_ops = []
481+ if self ._is_primary and self ._is_multi_tier :
482+ with ops .control_dependencies ([self ._initializer_for_restore ]):
483+ set_cache_op = gen_kv_variable_ops .kv_resource_init_cache_strategy_op (
484+ self ._handle ,
485+ cache_strategy = self ._storage_cache_strategy ,
486+ Tkeys = self ._invalid_key_type ,
487+ dtype = self ._dtype )
488+ set_attr_ops .append (set_cache_op )
489+ with ops .control_dependencies (set_attr_ops + [self ._initializer_for_restore ]):
490+ self ._init_op_for_restore = control_flow_ops .no_op ()
491+
447492 def need_counts (self ):
448493 return (self ._record_freq or (self ._filter_freq > 0 ) or self ._is_multi_tier )
449494 @property
@@ -482,6 +527,11 @@ def _init_from_proto(self, variable_def, import_scope=None):
482527 cache_op = op
483528 elif self ._initializer_op .type == "InitializeKvVariableOp" :
484529 init_op = self ._initializer_op
530+
531+ self ._init_op_for_restore = g .as_graph_element (
532+ ops .prepend_name_scope (
533+ variable_def .initialize_op_for_restore ,
534+ import_scope = import_scope ))
485535 self ._trainable = getattr (variable_def , "trainable" , True )
486536 if variable_def .snapshot_name :
487537 self ._cached_value = g .as_graph_element (
@@ -842,6 +892,8 @@ def to_proto(self, export_scope=None):
842892 if self ._save_slice_info :
843893 var_def .save_slice_info_def .MergeFrom (
844894 self ._save_slice_info .to_proto (export_scope = export_scope ))
895+ var_def .initialize_op_for_restore = ops .strip_name_scope (
896+ self ._init_op_for_restore .name , export_scope )
845897 return var_def
846898 else :
847899 return None
0 commit comments