@@ -111,35 +111,47 @@ def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None):
111111 ckpt_var_name_to_shape = reader .get_variable_to_shape_map ()
112112 ckpt_var_names = set (reader .get_variable_to_shape_map ().keys ())
113113
114+ if tf .distribute .get_replica_context ():
115+ replica_id = tf .keras .backend .get_value (
116+ tf .distribute .get_replica_context ().replica_id_in_sync_group )
117+ else :
118+ replica_id = 0
119+
114120 for i , v in enumerate (model_vars ):
115- if not v .op .name .startswith (var_scope ):
121+ var_op_name = v .op .name
122+
123+ if replica_id >= 1 :
124+ var_op_name = '' .join (var_op_name .rsplit (f'/replica_{ replica_id } ' , 1 ))
125+
126+ if not var_op_name .startswith (var_scope ):
116127 logging .info ('skip {} -- does not match scope {}' .format (
117- v .op .name , var_scope ))
118- ckpt_var = ckpt_scope + v .op .name [len (var_scope ):]
128+ var_op_name , var_scope ))
129+ ckpt_var = ckpt_scope + var_op_name [len (var_scope ):]
130+
119131 if (ckpt_var not in ckpt_var_names and
120- v . op . name .endswith ('/ExponentialMovingAverage' )):
121- ckpt_var = ckpt_scope + v . op . name [:- len ('/ExponentialMovingAverage' )]
132+ var_op_name .endswith ('/ExponentialMovingAverage' )):
133+ ckpt_var = ckpt_scope + var_op_name [:- len ('/ExponentialMovingAverage' )]
122134
123135 if ckpt_var not in ckpt_var_names :
124136 if 'Momentum' in ckpt_var or 'RMSProp' in ckpt_var :
125137 # Skip optimizer variables.
126138 continue
127139 if skip_mismatch :
128- logging .info ('skip {} ({}) -- not in ckpt' .format (v . op . name , ckpt_var ))
140+ logging .info ('skip {} ({}) -- not in ckpt' .format (var_op_name , ckpt_var ))
129141 continue
130142 raise ValueError ('{} is not in ckpt {}' .format (v .op , ckpt_path ))
131143
132144 if v .shape != ckpt_var_name_to_shape [ckpt_var ]:
133145 if skip_mismatch :
134146 logging .info ('skip {} ({} vs {}) -- shape mismatch' .format (
135- v . op . name , v .shape , ckpt_var_name_to_shape [ckpt_var ]))
147+ var_op_name , v .shape , ckpt_var_name_to_shape [ckpt_var ]))
136148 continue
137149 raise ValueError ('shape mismatch {} ({} vs {})' .format (
138- v . op . name , v .shape , ckpt_var_name_to_shape [ckpt_var ]))
150+ var_op_name , v .shape , ckpt_var_name_to_shape [ckpt_var ]))
139151
140152 if i < 5 :
141153 # Log the first few elements for sanity check.
142- logging .info ('Init {} from ckpt var {}' .format (v . op . name , ckpt_var ))
154+ logging .info ('Init {} from ckpt var {}' .format (var_op_name , ckpt_var ))
143155 var_map [ckpt_var ] = v
144156
145157 return var_map
0 commit comments