Skip to content

Commit 03e1d28

Browse files
RecML authors authored and qinyiyan committed
Sync copybara
PiperOrigin-RevId: 759293017
1 parent f5e68e2 commit 03e1d28

File tree

12 files changed

+128
-100
lines changed

12 files changed

+128
-100
lines changed

recml/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,30 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""Public API for RecML."""
15+
16+
# pylint: disable=g-importing-member
17+
18+
from recml.core import data
19+
from recml.core import metrics
20+
from recml.core import utils
21+
from recml.core.metrics.base_metrics import Metric
22+
from recml.core.training.core import Experiment
23+
from recml.core.training.core import run_experiment
24+
from recml.core.training.core import Trainer
25+
from recml.core.training.jax_trainer import JaxState
26+
from recml.core.training.jax_trainer import JaxTask
27+
from recml.core.training.jax_trainer import JaxTrainer
28+
from recml.core.training.jax_trainer import KerasState
29+
from recml.core.training.keras_trainer import KerasTask
30+
from recml.core.training.keras_trainer import KerasTrainer
31+
from recml.core.training.optax_factory import AdagradFactory
32+
from recml.core.training.optax_factory import AdamFactory
33+
from recml.core.training.optax_factory import OptimizerFactory
34+
from recml.core.training.partitioning import DataParallelPartitioner
35+
from recml.core.training.partitioning import ModelParallelPartitioner
36+
from recml.core.training.partitioning import NullPartitioner
37+
from recml.core.training.partitioning import Partitioner
38+
from recml.core.utils.types import Factory
39+
from recml.core.utils.types import FactoryProtocol
40+
from recml.core.utils.types import ObjectFactory

recml/core/__init__.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

recml/core/data/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2024 RecML authors <[email protected]>.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Public API for RecML data."""
15+
16+
# pylint: disable=g-importing-member
17+
18+
from recml.core.data.iterator import Iterator
19+
from recml.core.data.iterator import TFDatasetIterator
20+
from recml.core.data.preprocessing import PreprocessingMode
21+
from recml.core.data.tf_dataset_factory import DatasetShardingInfo
22+
from recml.core.data.tf_dataset_factory import TFDatasetFactory
23+
from recml.core.data.tf_dataset_factory import TFDSMetadata

recml/core/data/iterator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import tensorflow as tf
2424

2525

26-
DatasetIterator = clu_data.DatasetIterator
26+
Iterator = clu_data.DatasetIterator
2727

2828

2929
class TFDatasetIterator(clu_data.DatasetIterator):

recml/core/metrics/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Tools for MLRX metrics."""
14+
"""Tools for RecML metrics."""
1515

1616
from collections.abc import Mapping
1717
import concurrent.futures

recml/core/training/core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@
4646
tf.data.Dataset,
4747
tuple[tf.data.Dataset, tf.data.Dataset],
4848
tuple[tf.data.Dataset, Mapping[str, tf.data.Dataset]],
49-
iterator.DatasetIterator,
50-
tuple[iterator.DatasetIterator, iterator.DatasetIterator],
51-
tuple[iterator.DatasetIterator, Mapping[str, iterator.DatasetIterator]],
49+
iterator.Iterator,
50+
tuple[iterator.Iterator, iterator.Iterator],
51+
tuple[iterator.Iterator, Mapping[str, iterator.Iterator]],
5252
)
5353
MetaT = TypeVar("MetaT")
5454
Logs = Any # Any metric logs returned by the training or evaluation task.
@@ -120,9 +120,9 @@ def run_experiment(
120120

121121
def get_iterators(
122122
datasets: DatasetT,
123-
) -> tuple[iterator.DatasetIterator, Mapping[str, iterator.DatasetIterator]]:
123+
) -> tuple[iterator.Iterator, Mapping[str, iterator.Iterator]]:
124124
"""Creates and unpacks the datasets returned by the task."""
125-
if isinstance(datasets, (iterator.DatasetIterator, tf.data.Dataset)):
125+
if isinstance(datasets, (iterator.Iterator, tf.data.Dataset)):
126126
if isinstance(datasets, tf.data.Dataset):
127127
datasets = iterator.TFDatasetIterator(datasets)
128128
return datasets, {}
@@ -133,7 +133,7 @@ def get_iterators(
133133
)
134134

135135
train_dataset, eval_datasets = datasets
136-
if isinstance(train_dataset, (iterator.DatasetIterator, tf.data.Dataset)):
136+
if isinstance(train_dataset, (iterator.Iterator, tf.data.Dataset)):
137137
if isinstance(train_dataset, tf.data.Dataset):
138138
train_dataset = iterator.TFDatasetIterator(train_dataset)
139139
else:
@@ -143,7 +143,7 @@ def get_iterators(
143143
f" {type(train_dataset)}."
144144
)
145145

146-
if isinstance(eval_datasets, (iterator.DatasetIterator, tf.data.Dataset)):
146+
if isinstance(eval_datasets, (iterator.Iterator, tf.data.Dataset)):
147147
if isinstance(eval_datasets, tf.data.Dataset):
148148
eval_datasets = iterator.TFDatasetIterator(eval_datasets)
149149
return train_dataset, {"": eval_datasets}
@@ -162,7 +162,7 @@ def get_iterators(
162162
}
163163

164164
if not all(
165-
isinstance(v, iterator.DatasetIterator) for v in eval_datasets.values()
165+
isinstance(v, iterator.Iterator) for v in eval_datasets.values()
166166
):
167167
raise ValueError(
168168
"Expected all values in the evaluation datasets mapping to be either"

recml/core/training/jax_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,14 @@ class JaxTask(abc.ABC, Generic[StateT]):
298298
def create_datasets(self) -> core.DatasetT:
299299
"""Creates training and evaluation datasets.
300300
301-
Returns:
301+
Returns:
302302
One of the following:
303-
1) A `tf.data.Dataset` or CLU `DatasetIterator instance that will be
303+
1) A `tf.data.Dataset` or `Iterator` instance that will be
304304
used for training.
305-
2) A tuple of `tf.data.Dataset` or CLU `DatasetIterator` instances where
305+
2) A tuple of `tf.data.Dataset` or `Iterator` instances where
306306
the first element is the training dataset and the second element is
307307
the evaluation dataset.
308-
3) A tuple of `tf.data.Dataset` or CLU `DatasetIterator` instances where
308+
3) A tuple of `tf.data.Dataset` or `Iterator` instances where
309309
the first element is the training dataset and the second element is a
310310
dictionary of evaluation datasets keyed by name.
311311
"""
@@ -601,8 +601,8 @@ def _evaluate_n_steps(
601601
def process_task(
602602
self, task: JaxTask, *, training: bool, check_for_checkpoints: bool
603603
) -> tuple[
604-
iterator_lib.DatasetIterator,
605-
Mapping[str, iterator_lib.DatasetIterator],
604+
iterator_lib.Iterator,
605+
Mapping[str, iterator_lib.Iterator],
606606
State,
607607
partitioning.StepFn,
608608
partitioning.StepFn,

recml/core/training/keras_trainer.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import abc
1919
from collections.abc import Mapping
20+
import dataclasses
2021
import gc
2122
import os
2223
import time
@@ -125,7 +126,7 @@ def __init__(
125126
model_dir = "/tmp"
126127

127128
# This should be set before any layers are constructed and this is a
128-
# fallback incase the trainer binary doesn't already do this.
129+
# fallback in case the trainer binary doesn't already do this.
129130
if (
130131
isinstance(
131132
distribution,
@@ -204,7 +205,7 @@ def _maybe_get_model_kws(
204205
if py_utils.has_argument(task.create_model, "input_shapes"):
205206
batch = next(iter(dataset))
206207
x, *_ = keras.utils.unpack_x_y_sample_weight(batch)
207-
kws["input_shapes"]: keras.tree.map_structure(core.get_shape, x)
208+
kws["input_shapes"] = keras.tree.map_structure(core.get_shape, x)
208209

209210
return kws
210211

@@ -232,6 +233,27 @@ def evaluate(self, task: KerasTask) -> core.Logs:
232233
model = task.create_model_for_eval(
233234
**self._maybe_get_model_kws(task, dataset)
234235
)
236+
237+
if keras.backend.backend() == "jax":
238+
[tb_cbk] = [
239+
cbk
240+
for cbk in self._eval_callbacks
241+
if isinstance(cbk, keras_utils.EpochSummaryCallback)
242+
]
243+
epoch_start_time = time.time()
244+
history = model.evaluate(
245+
dataset,
246+
steps=self._steps_per_eval,
247+
callbacks=self._eval_callbacks,
248+
return_dict=True,
249+
)
250+
epoch_dt = time.time() - epoch_start_time
251+
steps_per_second = self._steps_per_eval / epoch_dt
252+
val_logs = {"val_" + k: v for k, v in history.items()}
253+
val_logs["val_steps_per_second"] = steps_per_second
254+
tb_cbk.on_epoch_end(0, val_logs)
255+
return history
256+
235257
return model.evaluate(
236258
dataset,
237259
steps=self._steps_per_eval,

recml/core/utils/keras_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,11 @@ def restore_keras_model(
191191
Args:
192192
model: The Keras model to restore.
193193
checkpoint_dir: The directory containing the Orbax checkpoints.
194-
step: The step to restore the model to. If `None` then the latest checkpoint
195-
will be restored.
194+
step: The checkpoint step to resume training from. If set, it requires a
195+
checkpoint with the same step number to be present in the model directory.
196+
If not set, will resume training from the last checkpoint. Depending on
197+
the value of `max_checkpoints_to_keep`, the model directory only contains
198+
a certain number of the latest checkpoints.
196199
restore_optimizer_vars: Whether to restore the optimizer variables.
197200
restore_steps: Whether to restore the model's steps. If `True` then the
198201
model will continue training from the step the checkpoint was saved at. If

0 commit comments

Comments
 (0)