@@ -106,7 +106,9 @@ def __init__(
         self,
         *,
         distribution: (
-            keras.distribution.DataParallel | keras.distribution.ModelParallel
+            keras.distribution.DataParallel
+            | keras.distribution.ModelParallel
+            | None
         ) = None,
         model_dir: str | None = None,
         train_steps: int = 0,
@@ -128,10 +130,7 @@ def __init__(
         # This should be set before any layers are constructed and this is a
         # fallback in case the trainer binary doesn't already do this.
         if (
-            isinstance(
-                distribution,
-                (keras.distribution.DataParallel, keras.distribution.ModelParallel),
-            )
+            distribution is not None
             and keras.distribution.distribution() != distribution
         ):
             if hasattr(distribution, "_auto_shard_dataset"):
@@ -175,6 +174,7 @@ def __init__(
                 ),
             ]
         else:
+            self._checkpoint_manager = None
             self._train_callbacks = [
                 keras.callbacks.TensorBoard(
                     log_dir=os.path.join(model_dir, core.LOG_DIR),
@@ -199,13 +199,13 @@ def __init__(
         ]

     def _maybe_get_model_kws(
-        self, task: KerasTask, dataset: keras.Model
+        self, task: KerasTask, dataset: tf.data.Dataset
     ) -> Mapping[str, Any]:
         kws = {}
         if py_utils.has_argument(task.create_model, "input_shapes"):
-            batch = next(iter(dataset))
-            x, *_ = keras.utils.unpack_x_y_sample_weight(batch)
-            kws["input_shapes"]: keras.tree.map_structure(core.get_shape, x)  # pylint: disable=undefined-variable
+            batch_spec = dataset.element_spec
+            x, *_ = keras.utils.unpack_x_y_sample_weight(batch_spec)
+            kws["input_shapes"] = keras.tree.map_structure(core.get_shape, x)

         return kws
0 commit comments