Skip to content

Commit 2b54892

Browse files
Hilly12recml authors
authored andcommitted
Use the element spec to get the input shapes instead of getting a batch.
PiperOrigin-RevId: 764451548
1 parent 09c8411 commit 2b54892

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

recml/core/training/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ def get_iterators(
174174

175175

176176
def get_shape(
177-
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
177+
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | tf.TensorSpec,
178178
) -> Sequence[int | None]:
179-
"""Gets the shape of a dense / sparse / ragged tensor."""
179+
"""Gets the shape of a dense / sparse / ragged tensor or tensor spec."""
180180
if isinstance(x, tf.SparseTensor):
181181
return [x.shape[0]] + [None for _ in x.shape[1:]]
182182
return x.shape.as_list()

recml/core/training/keras_trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def __init__(
106106
self,
107107
*,
108108
distribution: (
109-
keras.distribution.DataParallel | keras.distribution.ModelParallel
109+
keras.distribution.DataParallel
110+
| keras.distribution.ModelParallel
111+
| None
110112
) = None,
111113
model_dir: str | None = None,
112114
train_steps: int = 0,
@@ -128,10 +130,7 @@ def __init__(
128130
# This should be set before any layers are constructed and this is a
129131
# fallback in case the trainer binary doesn't already do this.
130132
if (
131-
isinstance(
132-
distribution,
133-
(keras.distribution.DataParallel, keras.distribution.ModelParallel),
134-
)
133+
distribution is not None
135134
and keras.distribution.distribution() != distribution
136135
):
137136
if hasattr(distribution, "_auto_shard_dataset"):
@@ -175,6 +174,7 @@ def __init__(
175174
),
176175
]
177176
else:
177+
self._checkpoint_manager = None
178178
self._train_callbacks = [
179179
keras.callbacks.TensorBoard(
180180
log_dir=os.path.join(model_dir, core.LOG_DIR),
@@ -199,13 +199,13 @@ def __init__(
199199
]
200200

201201
def _maybe_get_model_kws(
202-
self, task: KerasTask, dataset: keras.Model
202+
self, task: KerasTask, dataset: tf.data.Dataset
203203
) -> Mapping[str, Any]:
204204
kws = {}
205205
if py_utils.has_argument(task.create_model, "input_shapes"):
206-
batch = next(iter(dataset))
207-
x, *_ = keras.utils.unpack_x_y_sample_weight(batch)
208-
kws["input_shapes"]: keras.tree.map_structure(core.get_shape, x) # pylint: disable=undefined-variable
206+
batch_spec = dataset.element_spec
207+
x, *_ = keras.utils.unpack_x_y_sample_weight(batch_spec)
208+
kws["input_shapes"] = keras.tree.map_structure(core.get_shape, x)
209209

210210
return kws
211211

0 commit comments

Comments
 (0)