diff --git a/.gitignore b/.gitignore index 36293a4..709699b 100644 --- a/.gitignore +++ b/.gitignore @@ -212,3 +212,5 @@ cython_debug/ /data/ /logs/ requirements_*.txt + +.specstory/ diff --git a/MODULE.bazel b/MODULE.bazel index 9a5854a..7345325 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -61,9 +61,15 @@ pip.parse( python_version = "3.10", requirements_lock = "//third_party:requirements_3_10_tpu_lock.txt", ) +pip.parse( + hub_name = "ml_infra_mps_3_10", + python_version = "3.10", + requirements_lock = "//third_party:requirements_3_10_mps_lock.txt", +) use_repo( pip, ml_infra_cpu_3_10 = "ml_infra_cpu_3_10", ml_infra_cuda_3_10 = "ml_infra_cuda_3_10", + ml_infra_mps_3_10 = "ml_infra_mps_3_10", ml_infra_tpu_3_10 = "ml_infra_tpu_3_10", ) diff --git a/src/core/BUILD b/src/core/BUILD index 1f39da2..623a926 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -8,7 +8,7 @@ ml_py_library( deps = [ "fiddle", "optax", - ":data", + ":datamodule", ":model", ], ) @@ -24,7 +24,7 @@ ml_py_library( deps = [ "clu", "jax", - ":data", + ":datamodule", ":model", "//src/utilities:logging", ], @@ -36,8 +36,8 @@ ml_py_library( deps = [ "chex", "flax", + "jax", "jaxtyping", - ":train_state", ], ) @@ -60,7 +60,7 @@ ml_py_library( "flax", "jax", "jaxtyping", - ":data", + ":datamodule", ":model", ":train_state", "//src/utilities:logging", diff --git a/src/core/config.py b/src/core/config.py index b9bd73f..e7544c1 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -4,7 +4,7 @@ import fiddle as fdl import optax -from src.core import data as _data +from src.core import datamodule as _datamodule from src.core import model as _model @@ -20,7 +20,7 @@ class DataConfig: drop_remainder (bool): Whether to drop the last incomplete batch. 
""" - module: fdl.Partial[_data.DataModule] + module: fdl.Partial[_datamodule.DataModule] batch_size: int = 32 num_workers: int = 4 deterministic: bool = True @@ -45,6 +45,8 @@ class TrainerConfig: Attributes: num_train_steps (int): Total number of training steps. + checkpoint_every_n_steps (Optional[int]): Frequency of checkpointing. + If `None`, defaults to `eval_every_n_steps`. log_every_n_steps (int): Frequency of logging training metrics. eval_every_n_steps (int): Frequency of evaluation during training. checkpoint_dir (Optional[str]): Directory of checkpoint to resume from. @@ -53,6 +55,7 @@ class TrainerConfig: """ num_train_steps: int = 10_000 + checkpoint_every_n_steps: typing.Optional[int] = None log_every_n_steps: int = 50 eval_every_n_steps: int = 1_000 checkpoint_dir: typing.Optional[str] = None diff --git a/src/core/evaluate.py b/src/core/evaluate.py index 85c80d1..9e7b4ff 100644 --- a/src/core/evaluate.py +++ b/src/core/evaluate.py @@ -1,5 +1,6 @@ import collections import functools +import traceback import typing from clu import metric_writers @@ -7,14 +8,14 @@ import jax import jaxtyping -from src.core import data as _data +from src.core import datamodule as _datamodule from src.core import model as _model from src.utilities import logging def run( - model: _model.Model, - datamodule: _data.DataModule, + datamodule: _datamodule.DataModule, + evaluation_step: typing.Callable[..., _model.StepOutputs], params: jaxtyping.PyTree, writer: metric_writers.MetricWriter, work_dir: str, @@ -24,8 +25,8 @@ def run( """Runs evaluation loop with the given model and datamodule. Args: - model (Model): The model to evaluate. datamodule (DataModule): The datamodule providing the evaluation data. + evaluation_step (Callable): The pmapped evaluation step function. params (PyTree): The model parameters to use for evaluation. writer (MetricWriter): The metric writer for logging evaluation metrics. work_dir (str): The working directory for saving outputs. 
@@ -36,11 +37,11 @@ def run( Integer status code (0 for success). """ _status = 0 - logging.rank_zero_debug(f"running {model.__class__.__name__} eval...") - eval_rng = jax.random.fold_in(rng, jax.process_index()) - p_evaluation_step = functools.partial(model.evaluation_step, rng=eval_rng) + logging.rank_zero_info("Compiling evaluation step...") + p_evaluation_step = functools.partial(evaluation_step, rng=rng) p_evaluation_step = jax.pmap(p_evaluation_step, axis_name="batch") + logging.rank_zero_info("Compiling evaluation step...DONE!") hooks = [] if jax.process_index() == 0: @@ -69,7 +70,7 @@ def run( batch, ) with jax.profiler.StepTraceAnnotation( - name="train", + name="evaluation", step_num=step, ): outputs = p_evaluation_step( @@ -85,38 +86,52 @@ def run( # logging at the end of batch if outputs.scalars is not None: - _scalars = {} - for k, v in outputs.scalars.items(): - eval_metrics[k].append(jax.device_get(v).mean()) - _scalars[ - f"eval/{k.replace('_', ' ')}" - ] = jax.device_get(v).mean() writer.write_scalars( - step=step + 1, - scalars=_scalars, + step=step, + scalars={ + f"eval/{k}_step": sum(v) / len(v) + for k, v in outputs.scalars.items() + }, ) if outputs.images is not None: writer.write_images( - step=step + 1, - images=outputs.images, + step=step, + images={ + f"eval/{k}_step": v + for k, v in outputs.images.items() + }, ) + if outputs.histograms is not None: + writer.write_histograms( + step=step, + arrays={ + f"eval/{k}_step": v + for k, v in outputs.histograms.items() + }, + ) + writer.flush() # logging at the end of evaluation logging.rank_zero_info("Evaluation done.") scalar_output = { - f"eval/{k.replace('_', ' ')}": sum(v) / len(v) + f"eval/{k.replace('_', ' ')}_epoch": sum(v) / len(v) for k, v in eval_metrics.items() } writer.write_scalars( step=step, scalars=scalar_output, ) + writer.flush() + except Exception as e: logging.rank_zero_error( "Exception occurred during evaluation: %s", e ) + error_trace = traceback.format_exc() + 
logging.rank_zero_error("Stack trace:\n%s", error_trace) _status = 1 finally: + writer.close() logging.rank_zero_info( "Evaluation done. Exit with code %d.", _status, diff --git a/src/core/model.py b/src/core/model.py index cf09dfa..0f1cd3b 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -2,23 +2,27 @@ import typing import chex -from flax import struct +from flax.core import frozen_dict +import jax import jaxtyping -from src.core import train_state as _train_state - @chex.dataclass class StepOutputs: """A base container for outputs from a single step. Attributes: + output (Optional[jax.Array]): The main output of the model. scalars (Optional[Dict[str, Any]]): A dictionary of scalar metrics. images (Optional[Dict[str, Any]]): A dictionary of image outputs. + histograms (Optional[Dict[str, Array]]): A dictionary of array to + plot as histograms. """ + output: typing.Optional[jax.Array] = None scalars: typing.Optional[typing.Dict[str, typing.Any]] = None images: typing.Optional[typing.Dict[str, typing.Any]] = None + histograms: typing.Optional[typing.Dict[str, jax.Array]] = None class Model(abc.ABC): @@ -51,67 +55,45 @@ def init( pass @abc.abstractmethod - def training_step( + def compute_loss( self, *, - state: _train_state.TrainState, - batch: typing.Any, - rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]], + rngs: typing.Any, + deterministic: bool = False, + params: frozen_dict.FrozenDict, **kwargs, - ) -> typing.Tuple[struct.PyTreeNode, StepOutputs]: - r"""Performs a single training step. + ) -> typing.Tuple[jax.Array, StepOutputs]: + """Computes the loss given parameters and model inputs. Args: - state (TrainState): The current training state. - batch (Any): A batch of data. - rngs (Union[Any, Dict[str, Any]]): Random generators. - **kwargs: Additional keyword arguments. + deterministic (bool): Whether to run the model in deterministic + mode (e.g., disable dropout). Default is `False`. + params (FrozenDict): The model parameters. 
+ **kwargs: Keyword arguments consumed by the model. Returns: - A tuple containing the updated state and step outputs. + A dictionary containing the loss and other outputs. """ - pass + raise NotImplementedError @abc.abstractmethod - def evaluation_step( + def forward( self, *, - params: jaxtyping.PyTree, - batch: typing.Any, - rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]], + rngs: typing.Any, + deterministic: bool = True, + params: frozen_dict.FrozenDict, **kwargs, ) -> StepOutputs: - r"""Performs a single evaluation step. - - Args: - params (PyTree): The model parameters. - batch (Any): A batch of data. - rngs (Union[Any, Dict[str, Any]]): Random generators. - **kwargs: Additional keyword arguments. - - Returns: - The step outputs containing evaluation metrics. - """ - pass - - @abc.abstractmethod - def predict_step( - self, - *, - params: jaxtyping.PyTree, - batch: typing.Any, - rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]], - **kwargs, - ) -> typing.Any: - r"""Performs a single prediction step during inference. + """Forward pass the model and returns the output tree structure. Args: - params (PyTree): The model parameters. - batch (Any): A batch of data. - rngs (Union[Any, Dict[str, Any]]): Random generators. - **kwargs: Additional keyword arguments. + deterministic (bool): Whether to run the model in deterministic + mode (e.g., disable dropout). Default is `True`. + params (FrozenDict): The model parameters. + **kwargs: Keyword arguments consumed by the model. Returns: - The model's predictions. + The model outputs. 
""" - pass + raise NotImplementedError diff --git a/src/core/train.py b/src/core/train.py index 75c9b6a..68fe193 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -1,35 +1,23 @@ import collections import functools +import os +import traceback import typing -from clu import checkpoint from clu import metric_writers from clu import periodic_actions from flax import jax_utils +from flax.training import checkpoints import jax import jaxtyping -from src.core import data as _data +from src.core import datamodule as _data from src.core import model as _model from src.core import train_state as _train_state from src.utilities import logging - -def _create_step_fn( - model: _model.Model, - rng: typing.Any, -) -> typing.Tuple[jax.Array, typing.Callable, typing.Callable]: - """Creates the step functions for training and evaluation.""" - # create training step function - rng, train_rng = jax.random.split(rng, num=2) - p_training_step = functools.partial(model.training_step, rngs=train_rng) - p_training_step = jax.pmap(p_training_step, axis_name="batch") - - rng, eval_rng = jax.random.split(rng, num=2) - p_evaluation_step = functools.partial(model.evaluation_step, rngs=eval_rng) - p_evaluation_step = jax.pmap(p_evaluation_step, axis_name="batch") - - return rng, p_training_step, p_evaluation_step +EVAL_STEP_OUTPUT = _model.StepOutputs +TRAIN_STEP_OUTPUT = typing.Tuple[_train_state.TrainState, _model.StepOutputs] def _shard(tree: jaxtyping.PyTree) -> jaxtyping.PyTree: @@ -53,11 +41,11 @@ def _shard(tree: jaxtyping.PyTree) -> jaxtyping.PyTree: def run( - model: _model.Model, state: _train_state.TrainState, datamodule: _data.DataModule, + training_step: typing.Callable[..., TRAIN_STEP_OUTPUT], + evaluation_step: typing.Callable[..., EVAL_STEP_OUTPUT], num_train_steps: int, - checkpoint_manager: checkpoint.Checkpoint, writer: metric_writers.MetricWriter, work_dir: str, rng: typing.Any, @@ -69,11 +57,10 @@ def run( """Runs training and evaluation loop with given model and 
dataloaders. Args: - model (Model): The model to run. - train_dataloader (Any): The training dataloaders. - eval_dataloader (Any): The evaluation dataloaders. + datamodule (DataModule): The data module for loading data. + training_step (Callable): The training step function. + evaluation_step (Callable): The evaluation step function. num_train_steps (int): Number of training steps. - checkpoint_manager (Checkpoint): The checkpoint manager. writer (MetricWriter): The metric writer for logging. work_dir (str): The working directory for saving checkpoints and logs. rng (Any): The random number generator. @@ -87,14 +74,21 @@ def run( Integer status code. """ _status = 0 - logging.rank_zero_debug(f"running {model.__class__.__name__} fit stage...") if checkpoint_every_n_steps is None: checkpoint_every_n_steps = eval_every_n_steps - rng, p_training_step, p_evaluation_step = _create_step_fn( - model=model, - rng=rng, - ) + + logging.rank_zero_info("Compiling training step function...") + rng, train_rng = jax.random.split(rng, num=2) + p_training_step = functools.partial(training_step, rngs=train_rng) + p_training_step = jax.pmap(p_training_step, axis_name="batch") + logging.rank_zero_info("Compiling training step function... DONE!") + + logging.rank_zero_info("Compiling evaluation step function...") + rng, eval_rng = jax.random.split(rng, num=2) + p_evaluation_step = functools.partial(evaluation_step, rngs=eval_rng) + p_evaluation_step = jax.pmap(p_evaluation_step, axis_name="batch") + logging.rank_zero_info("Compiling evaluation step function... 
DONE!") hooks = [] report_progress = periodic_actions.ReportProgress( @@ -110,22 +104,74 @@ def run( num_profile_steps=5, ) ) - step, epoch = state.step, 0 + step = state.step state = jax_utils.replicate(state) logging.rank_zero_info("Training...") with metric_writers.ensure_flushes(writer): try: train_metrics = collections.defaultdict(list) while True: - for batch in datamodule.train_dataloader(): - batch = _shard(batch) + for train_batch in datamodule.train_dataloader(): + # evaluation and sanity check running + if ( + step % eval_every_n_steps == 0 + or step == num_train_steps + ): + logging.rank_zero_info("Running evaluation...") + eval_metrics = collections.defaultdict(list) + outputs = None + for eval_batch in datamodule.eval_dataloader(): + eval_batch = _shard(eval_batch) + outputs = p_evaluation_step( + params=state.params, + batch=eval_batch, + ) + if not isinstance(outputs, _model.StepOutputs): + raise RuntimeError( + "FATAL: Output from `evaluation_step` is " + "not a `StepOutputs` object." 
+ ) + if outputs.scalars is not None: + for k, v in outputs.scalars.items(): + eval_metrics[k].append( + jax.device_get(v).mean() + ) + logging.rank_zero_info("Evaluation done.") + + if isinstance(outputs, _model.StepOutputs): + writer.write_scalars( + step=step, + scalars={ + f"eval/{k}": sum(v) / len(v) + for k, v in eval_metrics.items() + }, + ) + if outputs.images is not None: + writer.write_images( + step=step, + images={ + f"eval/{k}": v + for k, v in outputs.images.items() + }, + ) + if outputs.histograms is not None: + writer.write_histograms( + step=step, + arrays={ + f"eval/{k}": v + for k, v in outputs.histograms.items() + }, + ) + writer.flush() + + train_batch = _shard(train_batch) with jax.profiler.StepTraceAnnotation( name="train", step_num=step, ): state, outputs = p_training_step( state=state, - batch=batch, + batch=train_batch, ) if not isinstance(outputs, _model.StepOutputs): raise RuntimeError( @@ -135,95 +181,82 @@ def run( if outputs.scalars is not None: for k, v in outputs.scalars.items(): train_metrics[k].append(jax.device_get(v).mean()) - step += 1 for hook in hooks: hook(step) if step % log_every_n_steps == 0: if outputs.scalars is not None: - scalar_output = { - f"train/{k.replace('_', ' ')}_step": sum(v) - / len(v) - for k, v in outputs.scalars.items() - } writer.write_scalars( step=step, - scalars=scalar_output, + scalars={ + f"train/{k}_step": sum(v) / len(v) + for k, v in outputs.scalars.items() + }, ) if outputs.images is not None: writer.write_images( step=step, - images=outputs.images, - ) - - # evaluation - if ( - step % eval_every_n_steps == 0 - or step == num_train_steps - ): - logging.rank_zero_info("Running evaluation...") - eval_metrics = collections.defaultdict(list) - for batch in datamodule.eval_dataloader(): - batch = _shard(batch) - outputs = p_evaluation_step( - params=state.params, - batch=batch, + images={ + f"train/{k}": v + for k, v in outputs.images.items() + }, ) - if not isinstance(outputs, _model.StepOutputs): - 
raise RuntimeError( - "FATAL: Output from `evaluation_step` is " - "not a `StepOutputs` object." - ) - if outputs.scalars is not None: - for k, v in outputs.scalars.items(): - eval_metrics[k].append( - jax.device_get(v).mean() - ) - logging.rank_zero_info("Evaluation done.") - writer.write_scalars( - step=step, - scalars={ - f"eval/{k.replace('_', ' ')}": sum(v) / len(v) - for k, v in eval_metrics.items() - }, - ) - if outputs.images is not None: - writer.write_images( + if outputs.histograms is not None: + writer.write_histograms( step=step, - images=outputs.images, + arrays={ + f"train/{k}": v + for k, v in outputs.histograms.items() + }, ) + writer.flush() + step += 1 # checkpointing if step % checkpoint_every_n_steps == 0: logging.rank_zero_info("Checkpointing...") - # TODO (juanwulu): resolve the error (no __enter__) - with report_progress.timed("checkpoint"): - filepath = checkpoint_manager.save( - state=jax_utils.unreplicate(state) + if jax.process_index() == 0: + with report_progress.timed("checkpoint"): + filepath = checkpoints.save_checkpoint( + ckpt_dir=os.path.join( + work_dir, + "checkpoints", + ), + target=jax_utils.unreplicate(state), + keep=3, + overwrite=True, + prefix="ckpt-", + step=step, + ) + logging.rank_zero_info( + "Checkpoint saved to %s", + filepath, ) - logging.rank_zero_info( - "Checkpoint saved to %s", - filepath, - ) # logging on the end of epoch - logging.rank_zero_info("Epoch %d done.", epoch) scalar_output = { - f"train/{k.replace('_', ' ')}_epoch": sum(v) / len(v) + f"train/{k}_epoch": sum(v) / len(v) for k, v in train_metrics.items() } writer.write_scalars( - step=epoch, + step=step, scalars=scalar_output, ) - epoch += 1 + writer.flush() + + # break outer loop if reach max steps + if step >= num_train_steps: + break except Exception as e: logging.rank_zero_error( "Exception occurred during training: %s", e ) + error_trace = traceback.format_exc() + logging.rank_zero_error(error_trace) _status = 1 finally: state = 
jax_utils.unreplicate(state) + writer.close() logging.rank_zero_info( "Training finished. Final step: %d. Exit with code %d.", state.step, diff --git a/src/data/BUILD b/src/data/BUILD index 908264f..c268513 100644 --- a/src/data/BUILD +++ b/src/data/BUILD @@ -20,6 +20,7 @@ ml_py_test( name = "test_huggingface", srcs = ["test_huggingface.py"], deps = [ + "jax", "numpy", "tensorflow", ":huggingface", diff --git a/src/data/huggingface.py b/src/data/huggingface.py index a4d477a..9992f00 100644 --- a/src/data/huggingface.py +++ b/src/data/huggingface.py @@ -1,6 +1,8 @@ import abc import functools import os +import shutil +import tempfile import typing import datasets @@ -27,25 +29,18 @@ class HuggingFaceDataModule(datamodule.DataModule): - `hf_dataset`: the HuggingFace dataset object. - `feature_key`: the key in the dataset features to use as input. - `target_key`: the key in the dataset features to use as target. - - `output_signature`: a (nested) structure of `tf.TensorSpec` objects. - - `_create_dataset`: method to create a `tf.data.Dataset` from the + - `create_dataset`: method to create a `tf.data.Dataset` from the HuggingFace dataset object. - Attributes: - path (str): The path to the HuggingFace dataset. - revision (str): The revision of the dataset for version control. - Args: batch_size (int): The batch size for data loading. - deterministic (bool): Whether to enforce deterministic loading behavior. + deterministic (bool): Whether enforce deterministic loading behavior. drop_remainder (bool): Whether to drop the last incomplete batch. num_workers (int): Number of shards for distributed loading. - seed (int): Random seed for shuffling. - shuffle_buffer_size (int): Buffer size for shuffling the dataset. transform (Optional[Callable], optional): An optional function to - transform the input features. Defaults to `None`. - target_transform (Optional[Callable], optional): An optional function - to transform the target features. Defaults to `None`. 
+ transform the features. Default is `None`. + shuffle_buffer_size (int): Buffer size for shuffling the dataset. + rng (Any): Random seed for shuffling. Default is `PRNGKey(42)`. """ def __init__( @@ -54,23 +49,17 @@ def __init__( deterministic: bool, drop_remainder: bool, num_workers: int, - seed: int, shuffle_buffer_size: int, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, + rng: typing.Any = jax.random.PRNGKey(42), ) -> None: self._batch_size = batch_size self._deterministic = deterministic self._drop_remainder = drop_remainder self._num_workers = num_workers - self._seed = seed self._shuffle_buffer_size = shuffle_buffer_size - self._rng = random.fold_in( - random.PRNGKey(self._seed), - jax.process_index(), - ) + self._rng = jax.random.fold_in(rng, jax.process_index()) self._transform = transform - self._target_transform = target_transform # ========================================= # Interface @@ -94,28 +83,34 @@ def target_key(self) -> typing.Optional[str]: @property @abc.abstractmethod - def output_signature(self) -> typing.Any: - r"""Any: A (nested) structure of `tf.TensorSpec` objects.""" + def train_dataset(self) -> typing.Iterable: + r"""Iterable: The training dataset split.""" ... + @property @abc.abstractmethod - def _create_dataset( - self, - *, - split: str, - shuffle_seed: typing.Optional[int] = None, - ) -> tf.data.Dataset: - r"""Create an `tf.data.Dataset` from the HuggingFace dataset object. + def eval_dataset(self) -> typing.Iterable: + r"""Iterable: The validation dataset split.""" + ... - Args: - split (str): The dataset split to create. - shuffle_seed (Optional[int], optional): Seed for shuffling. - If `None`, no shuffling is applied. + @property + @abc.abstractmethod + def test_dataset(self) -> typing.Iterable: + r"""Iterable: The test dataset split.""" + ... 
+ + @staticmethod + @abc.abstractmethod + def create_dataset(*args, **kwargs) -> tf.data.Dataset: + r"""Create sharded `tf.data.Dataset` from the HuggingFace dataset. + + The default method is suitable for processing image datasets with + `Pillow` images. Override this method for custom dataset processing. Returns: The created `tf.data.Dataset` instance. """ - pass + ... # ========================================= @property @@ -135,7 +130,7 @@ def drop_remainder(self) -> bool: @property def num_workers(self) -> int: - r"""int: Number of shards for distributed loading.""" + r"""int: Number of workers for distributed loading.""" return self._num_workers @property @@ -155,9 +150,9 @@ def num_test_examples(self) -> int: return len(self.hf_dataset["test"]) # type: ignore @property - def seed(self) -> int: - r"""int: Random seed for shuffling.""" - return self._seed + def rng(self) -> typing.Any: + r"""Any: Random seed for shuffling.""" + return self._rng @property def shuffle_buffer_size(self) -> int: @@ -174,37 +169,14 @@ def transform(self) -> typing.Optional[typing.Callable]: r"""Optional[Callable]: Transformation for the input features.""" return self._transform - @property - def target_transform(self) -> typing.Optional[typing.Callable]: - r"""Optional[Callable]: Transformation for the target features.""" - return self._target_transform - - def train_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the training dataset.""" - self._rng, shuffle_rng = random.split(self._rng, num=2) - ds = self._create_dataset( - split="train", - shuffle_seed=int(shuffle_rng[0]), # type: ignore - ) - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="validation") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda 
x: jnp.asarray(x), data) - - def test_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the test dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - class HuggingFaceImageDataModule(HuggingFaceDataModule): r"""Data module for HuggingFace image datasets. + Attributes: + path (str): The path to the HuggingFace dataset. + revision (str): The revision of the dataset for version control. + Args: batch_size (int): The batch size for data loading. deterministic (bool): Whether the dataloaders are deterministic. @@ -212,9 +184,12 @@ class HuggingFaceImageDataModule(HuggingFaceDataModule): num_workers (int): Number of shards for distributed loading. resize (int): The size to resize images to (square). resample (int): Resampling filter to use for resizing images. + shuffle_buffer_size (int): Buffer size for random shuffling. transform (Optional[Callable], optional): An optional function to - transform the input images. Defaults to `None`. - seed (int, optional): Random seed for shuffling. Defaults to `42`. + transform the input images. Default is `None`. + use_cache (bool, optional): Whether to use cached dataset. + Default is `True`. + rng (Any): Random seed for shuffling. Default is `PRNGKey(42)`. 
""" def __init__( @@ -226,13 +201,24 @@ def __init__( resize: int, resample: int, shuffle_buffer_size: int, - seed: int, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, + use_cache: bool = True, + rng: typing.Any = jax.random.PRNGKey(42), ) -> None: - r"""Instantiates a `HuggingFaceImageDataModule` object.""" self._resize = resize self._resample = resample + if use_cache: + cache_dir = os.path.join( + tempfile.gettempdir(), + "chimera", + "huggingface", + ) + if os.path.exists(cache_dir): + # NOTE: clear the cache directory to avoid corrupted cache + shutil.rmtree(cache_dir) + os.makedirs(cache_dir, exist_ok=True) + else: + cache_dir = None super().__init__( batch_size=batch_size, @@ -240,121 +226,234 @@ def __init__( drop_remainder=drop_remainder, num_workers=num_workers, shuffle_buffer_size=shuffle_buffer_size, - seed=seed, transform=transform, - target_transform=target_transform, + rng=rng, ) - def _create_dataset( - self, + # prepare the dataset splits + pre_transform = functools.partial( + self.pre_transform, + feature_key=self.feature_key, + target_key=self.target_key, + center_crop=True, + resample=self._resample, + resize=self._resize, + ) + self._train_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["train"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=int(self._rng[0]), + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "train_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + self._test_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["test"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + 
drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=None, + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "test_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + if "validation" in self.hf_dataset: + self._eval_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["validation"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=None, + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "val_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + elif "val" in self.hf_dataset: + self._eval_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["val"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=None, + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "val_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + else: + # NOTE: otherwise, use test set as validation set by default + self._eval_dataset = self._test_dataset + + @property + def train_dataset(self) -> tf.data.Dataset: + r"""tf.data.Dataset: The training dataset split.""" + return self._train_dataset + + @property + def eval_dataset(self) -> tf.data.Dataset: + r"""tf.data.Dataset: The validation dataset split.""" + return self._eval_dataset + + @property + def test_dataset(self) -> tf.data.Dataset: + r"""tf.data.Dataset: The test dataset split.""" + return self._test_dataset + + @staticmethod + def create_dataset( *, - split: str, + batch_size: int, + deterministic: bool, + drop_remainder: bool, + 
dataset: tf.data.Dataset, + shuffle_buffer_size: int, shuffle_seed: typing.Optional[int] = None, + transform: typing.Optional[typing.Callable] = None, + cache_dir: typing.Optional[str] = None, ) -> tf.data.Dataset: - r"""Create an `tf.data.Dataset` from the HuggingFace dataset object. + r"""Create sharded `tf.data.Dataset` from the HuggingFace dataset. The default method is suitable for processing image datasets with `Pillow` images. Override this method for custom dataset processing. Args: - split (str): The dataset split to create. + batch_size (int): The batch size for data loading. + deterministic (bool): Whether to enforce deterministic loading. + drop_remainder (bool): Whether to drop the last incomplete batch. + dataset (tf.data.Dataset): The converted HuggingFace dataset. + shuffle_buffer_size (int): Buffer size for random shuffling. shuffle_seed (Optional[int], optional): Seed for shuffling. If `None`, no shuffling is applied. + transform (Optional[Callable], optional): An optional function to + transform the features. Default is `None`. + cache_dir (Optional[str], optional): Directory to cache the dataset. Returns: The created `tf.data.Dataset` instance. """ - _hf_dataset = self.hf_dataset[split] - - def __hf_generator() -> typing.Generator[typing.Any, None, None]: - r"""Default iterator over HuggingFace dataset.""" - for example in _hf_dataset: - image = example[self.feature_key] # type: ignore - target = ( - example[self.target_key] # type: ignore - if self.target_key - else None - ) - if not isinstance(image, Image.Image): - raise ValueError( - "Default iterator expects the image to be a " - f"`PIL.Image.Image` object, but got {type(image)}." 
- ) - image = image.convert("RGB") - - # resize the image - width, height = image.size - scale = self._resize / min(width, height) - new_width, new_height = int(width * scale), int(height * scale) - image = image.resize( - size=(new_width, new_height), - resample=self._resample, - ) - - # center crop - left = (new_width - self._resize) / 2 - top = (new_height - self._resize) / 2 - right = (new_width + self._resize) / 2 - bottom = (new_height + self._resize) / 2 - image = image.crop((left, top, right, bottom)) + if isinstance(transform, typing.Callable): + dataset = dataset.map( + map_func=transform, + deterministic=deterministic, + num_parallel_calls=tf.data.AUTOTUNE, + ) - yield image, target + if shuffle_seed is not None: + dataset = dataset.shuffle( + buffer_size=shuffle_buffer_size, + seed=shuffle_seed, + reshuffle_each_iteration=True, + ) + + if cache_dir is not None: + dataset = dataset.cache(filename=cache_dir) - ds = tf.data.Dataset.from_generator( - __hf_generator, - output_signature=self.output_signature, + dataset = dataset.batch( + batch_size=batch_size, + deterministic=deterministic, + drop_remainder=drop_remainder, + num_parallel_calls=tf.data.AUTOTUNE, ) - def __make_shard_dataset( - shard_index: int, - num_workers: int, - dataset: tf.data.Dataset, - local_seed: typing.Optional[int] = None, - ) -> tf.data.Dataset: - r"""Shards the input TensorFlow dataset for parallel loading.""" - local_ds = dataset.shard(num_shards=num_workers, index=shard_index) - if local_seed is not None: - local_ds = local_ds.shuffle( - buffer_size=self.shuffle_buffer_size, - seed=int(local_seed), # type: ignore - ) - if self.transform is not None: - local_ds = local_ds.map( - map_func=self.transform, - deterministic=self.deterministic, - num_parallel_calls=tf.data.AUTOTUNE, - ) - local_ds = local_ds.batch( - batch_size=self.batch_size, - deterministic=self.deterministic, - drop_remainder=self.drop_remainder, - num_parallel_calls=tf.data.AUTOTUNE, + return 
dataset.prefetch(buffer_size=tf.data.AUTOTUNE) + + @staticmethod + def pre_transform( + example: typing.Dict[str, typing.Any], + feature_key: str, + target_key: typing.Optional[str], + center_crop: bool = True, + resample: typing.Optional[int] = None, + resize: typing.Optional[int] = None, + ) -> typing.Dict[str, typing.Any]: + r"""Pre-transformation function for input images. + + Args: + example (Dict[str, Any]): A dictionary of data from the dataset. + feature_key (str): The name of the input features to use. + target_key (Optional[str]): The name of the target features to use. + center_crop (bool, optional): Whether to apply center cropping + after resizing. Default is `True`. + resample (Optional[int], optional): The resampling filter to use for + resizing images. If `None`, use `PIL.Image.NEAREST`. + resize (Optional[int], optional): The size to resize images to + (square). If `None`, no resizing is applied. Default is `None`. + + Returns: + A dictionary with processed images and targets. + """ + image = example[feature_key] + target = example[target_key] if target_key is not None else None + if not isinstance(image, Image.Image): + raise ValueError( + "Default pre-transformation expects the image to be a " + f"`PIL.Image.Image` object, but got {type(image)}." 
) - return local_ds + image = image.convert("RGB") - if shuffle_seed is not None: - local_seed = random.fold_in( - random.PRNGKey(shuffle_seed), - jax.process_index(), - )[0] - local_seed = int(local_seed) # type: ignore + # resize the image + if resize is not None: + width, height = image.size + scale = resize / min(width, height) + new_width, new_height = int(width * scale), int(height * scale) + image = image.resize( + size=(new_width, new_height), + resample=resample, + ) + + # center crop + if center_crop: + left = (new_width - resize) / 2 + top = (new_height - resize) / 2 + right = (new_width + resize) / 2 + bottom = (new_height + resize) / 2 + image = image.crop((left, top, right, bottom)) + + if target_key is None: + return {"image": image} else: - local_seed = None - - indices = tf.data.Dataset.range(self.num_workers) - out = indices.interleave( - map_func=functools.partial( - __make_shard_dataset, - num_workers=self.num_workers, - dataset=ds, - local_seed=local_seed, - ), - deterministic=self.deterministic, - num_parallel_calls=tf.data.AUTOTUNE, - ) + return {"image": image, "label": target} - return out.prefetch(buffer_size=tf.data.AUTOTUNE) + def train_dataloader(self) -> typing.Generator[PyTree, None, None]: + r"""Generator[PyTree]: Returns an iterable over the training data.""" + for data in self.train_dataset.as_numpy_iterator(): + yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) + + def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: + r"""Generator[PyTree]: Returns an iterable over the validation data.""" + for data in self.eval_dataset.as_numpy_iterator(): + yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) + + def test_dataloader(self) -> typing.Generator[PyTree, None, None]: + r"""Generator[PyTree]: Returns an iterable over the test data.""" + for data in self.test_dataset.as_numpy_iterator(): + yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) # 
============================================================================== @@ -384,11 +483,12 @@ class CIFAR10DataModule(HuggingFaceImageDataModule): image to before cropping. Defaults to `224`. resample (int, optional): Resampling filter to use when resizing images. Defaults to `3` (PIL.Image.BICUBIC). - seed (int, optional): Random seed for shuffling. Defaults to `42`. shuffle_buffer_size (int, optional): Buffer size for random shuffling. Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. """ def __init__( @@ -399,11 +499,10 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="uoft-cs/cifar10", @@ -418,10 +517,9 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, + rng=rng, ) @property @@ -439,14 +537,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "label" - @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) - @property @typing_extensions.override def num_val_examples(self) -> int: @@ -454,13 +544,6 @@ def num_val_examples(self) -> int: # NOTE: using test set as validation set by default return len(self.hf_dataset["test"]) # type: ignore - 
@typing_extensions.override - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - class CIFAR100DataModule(HuggingFaceImageDataModule): r"""CIFAR-100 Image Classification Dataset. @@ -489,6 +572,8 @@ class CIFAR100DataModule(HuggingFaceImageDataModule): Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Defaults to `random.PRNGKey(42)`. """ def __init__( @@ -499,11 +584,10 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="uoft-cs/cifar100", @@ -518,10 +602,9 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, + rng=rng, ) @property @@ -539,14 +622,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "fine_label" - @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) - @property @typing_extensions.override def num_val_examples(self) -> int: @@ -554,13 +629,6 @@ def num_val_examples(self) -> int: # NOTE: using test set as validation set by default return len(self.hf_dataset["test"]) # 
type: ignore - @typing_extensions.override - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - class ImageNet1KDataModule(HuggingFaceImageDataModule): r"""ILSVRC2012 image dataset subset with :math:`1,000` classes. @@ -583,11 +651,12 @@ class ImageNet1KDataModule(HuggingFaceImageDataModule): image to before cropping. Defaults to `224`. resample (int, optional): Resampling filter to use when resizing images. Defaults to `3` (PIL.Image.BICUBIC). - seed (int, optional): Random seed for shuffling. Defaults to `42`. shuffle_buffer_size (int, optional): Buffer size for random shuffling. Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. 
""" def __init__( @@ -598,11 +667,10 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="ILSVRC/imagenet-1k", @@ -617,10 +685,9 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, + rng=rng, ) @property @@ -638,14 +705,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "label" - @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) - class MNISTDataModule(HuggingFaceImageDataModule): r"""MNIST Handwritten Digit Dataset. @@ -667,11 +726,12 @@ class MNISTDataModule(HuggingFaceImageDataModule): image to before cropping. Defaults to `224`. resample (int, optional): Resampling filter to use when resizing images. Defaults to `3` (PIL.Image.BICUBIC). - seed (int, optional): Random seed for shuffling. Defaults to `42`. shuffle_buffer_size (int, optional): Buffer size for random shuffling. Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. 
""" def __init__( @@ -682,11 +742,10 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="ylecun/mnist", @@ -701,10 +760,9 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, + rng=rng, ) @property @@ -722,14 +780,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "label" - @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) - @property @typing_extensions.override def num_val_examples(self) -> int: @@ -737,13 +787,6 @@ def num_val_examples(self) -> int: # NOTE: using test set as validation set by default return len(self.hf_dataset["test"]) # type: ignore - @typing_extensions.override - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - __all__ = [ "HuggingFaceDataModule", diff --git a/src/data/test_huggingface.py b/src/data/test_huggingface.py index 9dc153e..603eef2 100644 --- a/src/data/test_huggingface.py +++ b/src/data/test_huggingface.py @@ -1,6 +1,7 @@ import sys import typing +import jax import numpy as np import pytest import tensorflow as tf @@ -25,9 +26,9 @@ def test_cifar10_datamodule() -> None: dm = 
huggingface.CIFAR10DataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ -36,7 +37,7 @@ def test_cifar10_datamodule() -> None: assert dm.num_train_examples == 50000 assert dm.num_val_examples == 10000 assert dm.num_test_examples == 10000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "test"]) # test training dataloader @@ -62,9 +63,9 @@ def test_cifar100_datamodule() -> None: dm = huggingface.CIFAR100DataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ -73,7 +74,7 @@ def test_cifar100_datamodule() -> None: assert dm.num_train_examples == 50000 assert dm.num_val_examples == 10000 assert dm.num_test_examples == 10000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "test"]) # test training dataloader @@ -99,9 +100,9 @@ def test_imagenet1k_datamodule() -> None: dm = huggingface.ImageNet1KDataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ -110,7 +111,7 @@ def test_imagenet1k_datamodule() -> None: assert dm.num_train_examples == 1_281_167 assert dm.num_val_examples == 50_000 assert dm.num_test_examples == 100_000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "validation", "test"]) # test training dataloader @@ -136,9 +137,9 @@ def test_mnist_datamodule() -> None: dm = huggingface.MNISTDataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ 
-147,7 +148,7 @@ def test_mnist_datamodule() -> None: assert dm.num_train_examples == 60000 assert dm.num_val_examples == 10000 assert dm.num_test_examples == 10000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "test"]) # test training dataloader diff --git a/src/projects/generative/BUILD b/src/projects/generative/BUILD index e5bf7f3..b6d261a 100644 --- a/src/projects/generative/BUILD +++ b/src/projects/generative/BUILD @@ -1,6 +1,6 @@ -load("//learning:defs.bzl", "ml_py_library", "ml_py_test") +load("//third_party:defs.bzl", "ml_py_binary", "ml_py_library") -package(default_visibility = ["//learning:__subpackages__"]) +package(default_visibility = ["//src/projects/generative:__subpackages__"]) ml_py_library( name = "config", @@ -9,35 +9,43 @@ ml_py_library( "fiddle", "optax", ":meanflow", - "//learning/core:config", - "//learning/data:cifar", - "//learning/data:preprocess", + "//src/core:config", + "//src/data:huggingface", + "//src/data:preprocess", ], ) -ml_py_library( - name = "meanflow", - srcs = ["meanflow.py"], +ml_py_binary( + name = "main", + srcs = ["main.py"], deps = [ - "chex", - "flax", + "absl-py", + "clu", + "fiddle", "jax", - "jaxlib", - "jaxtyping", - "typing_extensions", - "//learning/core:mixin", - "//learning/generative/model:refinenet", + "optax", + "tensorflow", + ":config", + "//src/core:config", + "//src/core:evaluate", + "//src/core:train", + "//src/core:train_state", + "//src/utilities:logging", + "//src/utilities:visualization", ], ) -ml_py_test( - name = "test_meanflow", - srcs = ["test_meanflow.py"], +ml_py_library( + name = "meanflow", + srcs = ["meanflow.py"], deps = [ "chex", "flax", "jax", - "jaxlib", - ":meanflow", + "jaxtyping", + "typing_extensions", + "//src/core:model", + "//src/core:train_state", + "//src/projects/generative/model:unet", ], ) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index eb979e6..3da65e4 100644 --- 
a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -1,50 +1,73 @@ import functools +import math import fiddle as fdl import optax -from learning.core import config as _config -from learning.data import cifar -from learning.data import preprocess -from learning.generative import meanflow +from src.core import config as _config +from src.data import huggingface +from src.data import preprocess +from src.projects.generative import meanflow +# ============================================================================== # MeanFlow Models def meanflow_unet_cifar_10() -> _config.ExperimentConfig: return _config.ExperimentConfig( name="meanflow_unet_cifar_10", - data=fdl.Partial( - cifar.CIFAR10DataModule, - preprocess_fn=preprocess.chain( - functools.partial( - preprocess.filter_keys, - keys=["image", "label"], - ), - functools.partial( - preprocess.normalize, - mean=(0.5, 0.5, 0.5), - std=(0.5, 0.5, 0.5), + mode="train", + data=_config.DataConfig( + module=fdl.Partial( + huggingface.CIFAR10DataModule, + resize=32, + transform=preprocess.chain( + functools.partial( + preprocess.filter_keys, + keys=["image", "label"], + ), + functools.partial( + preprocess.normalize, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), ), ), + batch_size=1024, + num_workers=2, + deterministic=True, + drop_remainder=True, ), - model=fdl.Config( + model=fdl.Partial( meanflow.MeanFlowUNetModel, in_channels=3, image_size=32, - latent_channels=16, - num_classes=10, - use_cfg_embedding=False, + features=128, dropout_rate=0.2, + epsilon=1e-6, + skip_scale=math.sqrt(0.5), timestamp_cond="t_and_t_minus_r", - timestamp_sampler="lognormal", + timestamp_sampler="logit-normal", timestamp_sampler_kwargs=dict(mean=-2.0, stddev=2.0), timestamp_overlap_rate=0.25, adaptive_weight_power=0.75, ), - # TODO: implement the warmup in https://arxiv.org/abs/1706.02677 - batch_size=1024, - lr_scheduler=fdl.Config(optax.constant_schedule, value=6e-4), - optimizer=fdl.Partial(optax.adam, 
b1=0.9, b2=0.999), - ema_rate=0.99995, - num_train_steps=800_000, + trainer=_config.TrainerConfig( + num_train_steps=800_000, + log_every_n_steps=50, + checkpoint_every_n_steps=10_000, # save every 10k steps + eval_every_n_steps=1_000, + max_checkpoints_to_keep=3, + profile=False, + ), + optimizer=_config.OptimizerConfig( + lr_schedule=fdl.Config( + optax.warmup_constant_schedule, + init_value=1e-8, + peak_value=6e-4, + warmup_steps=10_000, + ), + optimizer=fdl.Partial(optax.adam, b1=0.9, b2=0.999), + ema_rate=0.9999, + ), + seed=42, ) diff --git a/src/projects/generative/main.py b/src/projects/generative/main.py new file mode 100644 index 0000000..7a3f740 --- /dev/null +++ b/src/projects/generative/main.py @@ -0,0 +1,244 @@ +from datetime import datetime +import functools +import os +import platform +import typing + +from absl import app +from absl import flags +from clu import metric_writers +from clu import platform as clu_platform +from fiddle import absl_flags +import fiddle as fdl +import jax +from jax import numpy as jnp +import jaxtyping +import optax +import tensorflow as tf + +from src.core import config as _config +from src.core import evaluate as _evaluate +from src.core import model as _model +from src.core import train as _train +from src.core import train_state as _train_state +from src.utilities import logging +from src.utilities import visualization + +CONFIG = absl_flags.DEFINE_fiddle_config( + name="experiment", + default=None, + help_string="Experiment configuration.", + required=True, +) +FLAGS = flags.FLAGS +flags.DEFINE_string( + name="work_dir", + default=None, + help="Directory to store the experiment results.", + required=True, +) +PyTree = jaxtyping.PyTree + + +# toggle off GPU/TPU for TensorFlow +tf.config.experimental.set_visible_devices([], "GPU") +tf.config.experimental.set_visible_devices([], "TPU") +assert not tf.config.experimental.get_visible_devices("GPU") + + +def evaluation_step( + rngs: jax.Array, + model: _model.Model, + 
params: PyTree, + batch: PyTree, + **kwargs, +) -> _model.StepOutputs: + r"""Conduct a single evaluation step and compute metrics.""" + local_rng = jax.random.fold_in(rngs, jax.lax.axis_index("batch")) + outputs = model.forward( + rngs=local_rng, + params=params, + deterministic=True, + batch=batch, + **kwargs, + ) + out = outputs.output + assert isinstance(out, jax.Array) + out = jnp.clip(out * 0.5 + 0.5, 0.0, 1.0) + img_grid = visualization.make_grid(out, n_rows=4, n_cols=8, padding=2) + outputs.images = {"sampled images": img_grid} + return outputs + + +def training_step( + rngs: jax.Array, + model: _model.Model, + state: _train_state.TrainState, + batch: typing.Dict[str, typing.Any], + **kwargs, +) -> typing.Tuple[_train_state.TrainState, _model.StepOutputs]: + r"""Conduct a single training step and update train state.""" + local_rng = jax.random.fold_in(rngs, state.step) + local_rng = jax.random.fold_in(local_rng, jax.lax.axis_index("batch")) + + def loss_fn(params: PyTree) -> typing.Tuple[jax.Array, _model.StepOutputs]: + loss, outputs = model.compute_loss( + rngs=local_rng, + params=params, + deterministic=False, + batch=batch, + **kwargs, + ) + return loss, outputs + + grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True) + (_, outputs), grads = grad_fn(state.params) + grads = jax.lax.pmean(grads, axis_name="batch") + new_state = state.apply_gradients(grads=grads) + + return new_state, outputs + + +def main(_: typing.List[str]) -> int: + r"""Main entry point for training and evaluate generative models.""" + del _ # unused. 
+ + # Log the current platform + logging.rank_zero_info("Running on platform: %s", platform.node()) + + # Setup JAX runtime + logging.rank_zero_info("Running on JAX backend: %s", jax.default_backend()) + logging.rank_zero_info( + "Running on JAX process: %d / %d", + jax.process_index() + 1, + jax.process_count(), + ) + logging.rank_zero_info("Running on JAX devices: %r", jax.devices()) + + clu_platform.work_unit().set_task_status( + "process_index: %d, process_count: %d" + % (jax.process_index() + 1, jax.process_count()), + ) + clu_platform.work_unit().create_artifact( + clu_platform.ArtifactType.DIRECTORY, + FLAGS.work_dir, + "Working directory.", + ) + + # Setup Experiment + exp_config = CONFIG.value + if not isinstance(exp_config, _config.ExperimentConfig): + logging.rank_zero_error( + "Expect configuration to be of type `ExperimentConfig`, got %s.", + type(exp_config), + ) + return 1 + logging.rank_zero_info("Experiment Configuration:\n%s", exp_config) + + rng = jax.random.PRNGKey(exp_config.seed) + log_dir = os.path.join( + FLAGS.work_dir, + exp_config.name, + datetime.now().strftime("%Y%m%d_%H%M%S"), + ) + writer = metric_writers.create_default_writer( + logdir=log_dir, + just_logging=(jax.process_index() > 0), + ) + + logging.rank_zero_info("Building dataset...") + rng, data_rng = jax.random.split(rng, num=2) + p_datamodule = fdl.build(exp_config.data.module) + datamodule = p_datamodule( + batch_size=exp_config.data.batch_size, + deterministic=exp_config.data.deterministic, + drop_remainder=exp_config.data.drop_remainder, + num_workers=exp_config.data.num_workers, + rng=data_rng, + ) + logging.rank_zero_info( + "Building dataset %s... 
DONE!", + datamodule.__class__.__name__, + ) + + logging.rank_zero_info("Building model...") + rng, init_rng = jax.random.split(rng, num=2) + p_model = fdl.build(exp_config.model) + model = p_model( + dtype=exp_config.dtype, + param_dtype=exp_config.param_dtype, + precision=exp_config.precision, + ) + params = model.init(batch=None, rngs=init_rng) # NOTE: use dummy batch + logging.rank_zero_info( + "Building model %s... DONE!", + model.__class__.__name__, + ) + + logging.rank_zero_info("Building train state...") + lr_scheduler = fdl.build(exp_config.optimizer.lr_schedule) + p_optimizer = fdl.build(exp_config.optimizer.optimizer) + tx = p_optimizer(learning_rate=lr_scheduler) + if exp_config.optimizer.grad_clip_method == "norm": + tx = optax.chain( + optax.clip_by_global_norm(exp_config.optimizer.grad_clip_value), + tx, + ) + elif exp_config.optimizer.grad_clip_method == "value": + tx = optax.chain( + optax.clip(exp_config.optimizer.grad_clip_value), + tx, + ) + elif exp_config.optimizer.grad_clip_method is not None: + logging.rank_zero_error( + "Unknown grad clip method: %s", + exp_config.optimizer.grad_clip_method, + ) + return 1 + state = _train_state.TrainState.create( + params=params, + tx=tx, + ema_rate=exp_config.optimizer.ema_rate, + ) + logging.rank_zero_info("Building train state... 
DONE!") + + if exp_config.trainer.checkpoint_dir is not None: + logging.rank_zero_error("Resuming from checkpoint not implemented.") + return 1 + + p_training_step = functools.partial(training_step, model=model) + p_evaluation_step = functools.partial(evaluation_step, model=model) + if exp_config.mode == "train": + _train.run( + state=state, + datamodule=datamodule, + training_step=p_training_step, + evaluation_step=p_evaluation_step, + num_train_steps=exp_config.trainer.num_train_steps, + writer=writer, + work_dir=log_dir, + rng=rng, + checkpoint_every_n_steps=exp_config.trainer.checkpoint_every_n_steps, + log_every_n_steps=exp_config.trainer.log_every_n_steps, + eval_every_n_steps=exp_config.trainer.eval_every_n_steps, + profile=exp_config.trainer.profile, + ) + elif exp_config.mode == "evaluate": + _evaluate.run( + datamodule=datamodule, + evaluation_step=p_evaluation_step, + params=params, + writer=writer, + work_dir=log_dir, + rng=rng, + ) + else: + logging.rank_zero_error("Mode %s not implemented.", exp_config.mode) + return 1 + + return 0 + + +if __name__ == "__main__": + jax.config.config_with_absl() + app.run(main=main) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 9201d26..8c17158 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -4,52 +4,37 @@ from flax import linen as nn from flax.core import frozen_dict import jax -import jax.core as jax_core -import jax.numpy as jnp +from jax import numpy as jnp +from jax._src import typing as jax_typing import jaxtyping import typing_extensions -from learning.core import mixin as _mixin -from learning.generative.model import refinenet +from src.core import model as _model +from src.projects.generative.model import unet # Type Aliases PyTree = jaxtyping.PyTree -# ============================================================================== -# Data Structures -# 
============================================================================== -@chex.dataclass -class MeanFlowOutputs: - """Generic output structure from a `MeanFlow` model.""" - - loss: typing.Optional[jax.Array] = None - """jax.Array: The training loss.""" - velocity_loss: typing.Optional[jax.Array] = None - """jax.Array: The velocity loss for monitoring.""" - output: typing.Optional[jax.Array] = None - """jax.Array: The model output.""" - - # ============================================================================== # Helper functions # ============================================================================== def sample_t_r( *, - key: jax.random.KeyArray, - shape: jax_core.Shape, + key: jax.Array, + shape: jax_typing.Shape, dtype: typing.Any, - distribution: typing.Literal["uniform", "lognormal"], + distribution: str, **kwargs, ) -> typing.Tuple[jax.Array, jax.Array]: """Samples begin and end timestamps randomly from a given distribution. Attributes: - key (jax.random.KeyArray): JAX random key. - shape (jax_core.Shape): The shape of the output arrays. + key (jax.Array): JAX random key. + shape (jax.typing.Shape): The shape of the output arrays. dtype (dtype): The dtype of the output arrays. distribution (str): The distribution to sample from. - One of `["uniform", "lognormal"]`. + One of `["uniform", "logit-normal"]`. **kwargs: Additional keyword arguments for the distribution. 
Returns: @@ -75,28 +60,28 @@ def sample_t_r( minval=minval, maxval=maxval, ) - elif distribution == "lognormal": + elif distribution == "logit-normal": - def _lognormal( - key: jax.random.KeyArray, - shape: jax_core.Shape, + def _logit_normal( + key: jax.Array, + shape: jax_typing.Shape, dtype: typing.Any, mean: float, stddev: float, ) -> jax.Array: z = jax.random.normal(key=key, shape=shape, dtype=dtype) - return jnp.exp(mean + stddev * z) + return jax.nn.sigmoid(mean + stddev * z) mean = kwargs.get("mean", -0.4) stddev = kwargs.get("stddev", 1.0) - t = _lognormal( + t = _logit_normal( key=t_key, shape=shape, dtype=dtype, mean=mean, stddev=stddev, ) - r = _lognormal( + r = _logit_normal( key=r_key, shape=shape, dtype=dtype, @@ -106,15 +91,49 @@ def _lognormal( else: raise ValueError( f"Unsupported distribution: {distribution}. " - 'Must be one of ["uniform", "lognormal"].' + 'Must be one of ["uniform", "logit-normal"].' ) - return jnp.clip(t, a_min=0.0, a_max=1.0), jnp.clip(r, a_min=0.0, a_max=1.0) + return jnp.clip(t, 0.0, 1.0), jnp.clip(r, 0.0, 1.0) # ============================================================================== # Helper modules # ============================================================================== +class SinusoidalEmbed(nn.Module): + r"""Sinusoidal positional embeddings. + + Args: + features (int): Dimensionality of the output embeddings. + max_indx (int): Maximum index value. + endpoint (bool): Whether to include the endpoint frequency. + """ + + features: int + max_indx: int = 10_000 + endpoint: bool = False + + def setup(self) -> None: + """Instantiate a `SinusoidalEmbed` module.""" + half_dim = self.features >> 1 + freqs = jnp.arange(0, half_dim, dtype=jnp.float32) + freqs = freqs / (half_dim - (1 if self.endpoint else 0)) + self.freqs = jnp.power(1.0 / self.max_indx, freqs) + + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass and returns the sinusoidal embeddings. 
+ + Args: + inputs (jax.Array): Input indexes of shape `(*, )`. + + Returns: + Sinusoidal embedding array of shape `(..., features)`. + """ + out = jnp.outer(inputs[..., None], self.freqs) + out = jnp.concatenate([jnp.sin(out), jnp.cos(out)], axis=-1) + return out + + class TimestampEmbed(nn.Module): """Encode scalar timestamps to vectors. @@ -277,7 +296,7 @@ def __call__( def _drop_token( cond: jax.Array, dropout_rate: float, - rng: jax.random.KeyArray, + rng: jax.Array, ) -> jax.Array: """Drops class tokens for classifier-free guidance.""" raise NotImplementedError("This method is not yet implemented.") @@ -376,139 +395,144 @@ class MeanFlowUNetModule(nn.Module): """Generative model with a RefineNet backbone trained with `MeanFlow`. Attributes: - in_channels (int): Number of channels in the input images. - image_size (int): Height and width of the input images. - latent_channels (int): Number of channels in the latent feature maps. - num_classes (int): Number of conditioning classes. + features (int): Number of channels in the latent feature maps. + num_groups (int): Number of groups for `GroupNorm` layers. + dropout_rate (float): Dropout rate for the attention blocks. + epsilon (float): Small constant for numerical stability in `GroupNorm`. + skip_scale (float): Scaling factor for skip connections. + deterministic (Optional[bool]): Whether to run deterministically. dtype (dtype): The dtype of the computation (default: float32). param_dtype (dtype): The dtype of the parameters (default: float32). 
""" - in_channels: int - """int: Number of channels in the input images.""" - image_size: int - """int: Height and width of the input images.""" - latent_channels: int - """int: Number of channels in the latent feature maps.""" - num_classes: int - """int: Number of conditioning classes.""" - use_cfg_embedding: bool = False - """bool: Whether to use classifier-free guidance (CFG) embedding.""" - deterministic: typing.Optional[bool] = None - """Optional[bool]: Whether to run deterministically.""" + features: int + num_groups: int = 32 dropout_rate: float = 0.0 - """float: Dropout rate for the classifier-free guidance.""" - dtype: typing.Any = jnp.float32 - """typing.Any: The dtype of the computation.""" - param_dtype: typing.Any = jnp.float32 - """typing.Any: The dtype of the parameters.""" + epsilon: float = 1e-5 + skip_scale: float = 1.0 + deterministic: typing.Optional[bool] = None + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None def setup(self) -> None: - """Instantiate a `MeanFlowUNetModel` module.""" - self.backbone = refinenet.ConditionalRefineNet( - in_channels=self.in_channels, - image_size=self.image_size, - latent_channels=self.latent_channels, - norm_module=ConditionalInstanceNorm, - dtype=self.dtype, - param_dtype=self.param_dtype, - ) - self.r_embed = TimestampEmbed( - features=self.latent_channels, - frequency=256, - max_stamp=10_000, - name="r_embedder", + r"""Instantiate a `MeanFlowUNetModel` module.""" + # backbone U-Net model + self.backbone = unet.ScoreNet( + features=self.features, + num_groups=self.num_groups, + epsilon=self.epsilon, + dropout_rate=self.dropout_rate, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, ) - self.t_embed = TimestampEmbed( - features=self.latent_channels, - frequency=256, - max_stamp=10_000, - name="t_embedder", + + # conditional embeddings + self.time_embed = SinusoidalEmbed(self.features * 2, endpoint=True) + self.cond_in = nn.Dense( + 
features=self.features * 4, + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, dtype=self.dtype, param_dtype=self.param_dtype, + name="cond_fc_1", ) - self.label_embed = ConditionEmbed( - features=self.latent_channels, - num_classes=self.num_classes, - use_cfg_embedding=self.use_cfg_embedding, - deterministic=self.deterministic, - dropout_rate=self.dropout_rate, - name="y_embedder", + self.cond_out = nn.Dense( + features=self.features * 4, + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, dtype=self.dtype, param_dtype=self.param_dtype, + name="cond_fc_2", ) def __call__( self, image: jax.Array, - label: jax.Array, - begin: typing.Optional[jax.Array] = None, - end: typing.Optional[jax.Array] = None, + timestamps: typing.Tuple[jax.Array], deterministic: typing.Optional[bool] = None, ) -> jax.Array: - """Forward pass the `MeanFlowUNetModel`. + r"""Forward pass the `MeanFlowUNetModel`. Args: inputs (jax.Array): Input images of shape `(*, H, W, C)`. - cond (jax.Array): Conditioning labels of shape `(*,)`. begin (jax.Array): Begin timestamp `r` of shape `(*, )`. end (jax.Array): End timestamp `t` of shape `(*, )`. + deterministic (bool, optional): Whether to run deterministically. Returns: - jax.Array: The predicted average velocity of shape `(*, H, W, C)`. + The predicted average velocity of shape `(*, H, W, C)`. 
""" - # sanity check for the input arrays - batch_dims = image.shape[:-3] - dims = chex.Dimensions( - H=self.image_size, - W=self.image_size, - C=self.in_channels, - ) - chex.assert_shape(image, (*batch_dims, *dims["HWC"])) - chex.assert_shape(label, batch_dims) - chex.assert_shape(begin, batch_dims) - chex.assert_shape(end, batch_dims) - m_deterministic = nn.merge_param( "deterministic", self.deterministic, deterministic, ) - y_emb = self.label_embed(label, deterministic=m_deterministic) - r_emb = self.r_embed(begin) - t_emb = self.t_embed(end) - cond = t_emb + r_emb + y_emb - output = self.backbone(inputs=image, cond=cond) + emb = [self.time_embed(time) for time in timestamps] + cond = jnp.concatenate(emb, axis=-1) + cond = self.cond_out(jax.nn.silu(self.cond_in(cond))) + + output = self.backbone( + inputs=image, + cond=cond, + deterministic=m_deterministic, + ) return output -class MeanFlowUNetModel(_mixin.ModelMixin): - """`MeanFlow` generative model with a U-Net backbone.""" +class MeanFlowUNetModel(_model.Model): + r"""`MeanFlow` generative model with a U-Net backbone. - module_class = MeanFlowUNetModule - """Type[nn.Module]: The class of the model module.""" + Args: + in_channels (int): Number of input image channels. + image_size (int): Height and width of the input images. + features (int): Dimensionality of the latent feature map. + dropout_rate (float): Dropout rate for the classifier-free guidance. + epsilon (float): Small constant for numerical stability in `GroupNorm`. + skip_scale (float): Scaling factor for skip connections. + dtype (dtype): The dtype of the computation (default: float32). + param_dtype (dtype): The dtype of the parameters (default: float32). + timestamp_cond (Literal): The type of timestamp conditioning. + One of `["t_and_r", "t_and_t_minus_r", + "t_and_r_and_t_minus_r", "t_minus_r"]`. + timestamp_sampler (str): The distribution to sample timestamps from. + One of `["uniform", "logit-normal"]`. 
+ timestamp_sampler_kwargs (Dict[str, Any]): Additional keyword arguments + for the timestamp sampler. + timestamp_overlap_rate (float): The minimum overlap rate between + begin and end timestamps. + adaptive_weight_power (float): The power for adaptive weight scaling. + """ def __init__( self, in_channels: int, image_size: int, - latent_channels: int, - num_classes: int, - use_cfg_embedding: bool, + features: int, dropout_rate: float, - dtype: typing.Any = jnp.float32, - param_dtype: typing.Any = jnp.float32, + epsilon: float, + skip_scale: float, + dtype: typing.Any = None, + param_dtype: typing.Any = None, + precision: typing.Any = None, timestamp_cond: typing.Literal[ "t_and_r", "t_and_t_minus_r", "t_and_r_and_t_minus_r", "t_minus_r", ] = "t_and_t_minus_r", - timestamp_sampler: str = "lognormal", + timestamp_sampler: str = "logit-normal", timestamp_sampler_kwargs: typing.Dict[str, typing.Any] = { "mean": -0.4, "stddev": 1.0, @@ -519,55 +543,117 @@ def __init__( """Initializes the `MeanFlow` model.""" self.in_channels = in_channels self.image_size = image_size + self.features = features self.timestamp_cond = timestamp_cond self.timestamp_sampler = timestamp_sampler self.timestamp_sampler_kwargs = timestamp_sampler_kwargs self.timestamp_overlap_rate = timestamp_overlap_rate self.adaptive_weight_power = adaptive_weight_power - self._module = MeanFlowUNetModule( - in_channels=in_channels, - image_size=image_size, - latent_channels=latent_channels, - num_classes=num_classes, - use_cfg_embedding=use_cfg_embedding, + self._network = MeanFlowUNetModule( + features=features, + skip_scale=skip_scale, dropout_rate=dropout_rate, + epsilon=epsilon, name="unet", dtype=dtype, param_dtype=param_dtype, + precision=precision, ) + @property + @typing_extensions.override + def network(self) -> MeanFlowUNetModule: + r"""MeanFlowUNetModule: The U-Net neural network module.""" + return self._network + + def init( + self, + *, + batch: typing.Any, + rngs: typing.Any, + **kwargs, + ) 
-> PyTree: + del batch # unused + + # create dummy inputs + if self.timestamp_cond in ["t_and_r", "t_and_t_minus_r"]: + timestamps = ( + jnp.zeros((1,), dtype=jnp.float32), + jnp.zeros((1,), dtype=jnp.float32), + ) + elif self.timestamp_cond == "t_and_r_and_t_minus_r": + timestamps = ( + jnp.zeros((1,), dtype=jnp.float32), + jnp.zeros((1,), dtype=jnp.float32), + jnp.zeros((1,), dtype=jnp.float32), + ) + elif self.timestamp_cond == "t_minus_r": + timestamps = (jnp.zeros((1,), dtype=jnp.float32),) + else: + raise ValueError( + f"Unsupported timestamp conditioning: {self.timestamp_cond}." + ) + + dummy_inputs = { + "image": jnp.zeros( + (1, self.image_size, self.image_size, self.in_channels), + dtype=jnp.float32, + ), + "timestamps": timestamps, + } + variables = self.network.init( + rngs=rngs, + image=dummy_inputs["image"], + timestamps=dummy_inputs["timestamps"], + deterministic=True, + ) + _tabulate_fn = nn.summary.tabulate(self.network, rngs=rngs) + print(_tabulate_fn(**dummy_inputs, deterministic=True)) + + return variables["params"] + @typing_extensions.override def compute_loss( self, *, - rngs: typing.Union[ - jax.random.KeyArray, - typing.Dict[str, jax.random.KeyArray], - ], - image: jax.Array, - label: jax.Array, + rngs: typing.Any, + batch: typing.Dict[str, typing.Any], params: frozen_dict.FrozenDict, deterministic: bool = False, **kwargs, - ) -> MeanFlowOutputs: - """Computes the loss given parameters and model inputs. + ) -> typing.Tuple[jax.Array, _model.StepOutputs]: + r"""Computes the loss given parameters and model inputs. Args: rngs (Union[jax.random.KeyArray, Dict[str, jax.random.KeyArray]]): JAX random key or a dictionary of JAX random keys. - image (jax.Array): The input images of shape `(*, H, W, C)`. - label (jax.Array): The class labels of shape `(*,)`. + batch (Dict[str, Any]): A batch of data containing: + - image (jax.Array): Input images of shape `(*, H, W, C)`. params (frozen_dict.FrozenDict): The model parameters. 
deterministic (bool): Whether to run the model deterministically. **kwargs: additional keyword arguments. Returns: - MeanFlowOutputs: The model outputs. + The computed loss and other outputs. """ + del kwargs # unused + # NOTE: following the notation in Algorithm 1 of the source paper # sample t and r + image = batch["image"] + assert isinstance(image, jax.Array) batch_dims = image.shape[:-3] - rngs, tr_rng, mask_rng, e_rng = jax.random.split(rngs, num=4) + tr_rng, dropout_rng, f_rng, m_rng, e_rng = jax.random.split(rngs, 5) + + # randomly flip image horizontally for data augmentation + flip_mask = jax.random.bernoulli(key=f_rng, p=0.5, shape=batch_dims) + image = jnp.where( + flip_mask[..., None, None, None], + jnp.flip(image, axis=-2), + image, + ) + + # sample begin and end timestamps t, r = sample_t_r( key=tr_rng, shape=batch_dims, @@ -577,17 +663,12 @@ def compute_loss( ) t, r = jnp.maximum(t, r), jnp.minimum(t, r) # ensure a portion of overlap between t and r - r_neq_t_mask = jnp.greater_equal( - jax.random.uniform( - key=mask_rng, - shape=batch_dims, - dtype=image.dtype, - minval=0.0, - maxval=1.0, - ), + # NOTE: the following code randomly mask by uniform samples + r_eq_t_mask = jnp.less( + jax.random.uniform(key=m_rng, shape=batch_dims, dtype=image.dtype), self.timestamp_overlap_rate, ) - r = jnp.where(r_neq_t_mask, t, r) + r = jnp.where(r_eq_t_mask, t, r) # sample e ~ N(0, I) e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) @@ -600,31 +681,62 @@ def compute_loss( v = e - image # applies Jacobian vector product + def u_fn( + z_t: jax.Array, + r_in: jax.Array, + t_in: jax.Array, + ) -> jax.Array: + if self.timestamp_cond == "t_and_r": + timestamps = (r_in, t_in) + elif self.timestamp_cond == "t_and_t_minus_r": + timestamps = (t_in - r_in, t_in) + elif self.timestamp_cond == "t_and_r_and_t_minus_r": + timestamps = (t_in, r_in, t_in - r_in) + elif self.timestamp_cond == "t_minus_r": + timestamps = (t_in - r_in,) + else: + raise 
ValueError( + f"Unsupported timestamp conditioning: {self.timestamp_cond}." + ) + + out = self.network.apply( + variables={"params": params}, + image=z_t, + timestamps=timestamps, + deterministic=deterministic, + rngs={"dropout": dropout_rng}, + ) + assert isinstance(out, jax.Array) + + return out + + # NOTE: following the original meanflow drdt = jnp.zeros_like(r) dtdt = jnp.ones_like(t) - - u, dudt = jax.jvp( - self.u_fn( - label=label, - params=params, - deterministic=deterministic, - ), - (z, r, t), - (v, drdt, dtdt), - ) + u, dudt = jax.jvp(u_fn, (z, r, t), (v, drdt, dtdt)) + u_target = v - (t - r)[..., None, None, None] * dudt + + # NOTE: following the symmetric meanflow + # drdt = jnp.ones_like(r) + # dtdt = jnp.negative(jnp.ones_like(t)) + # u, dudt = jax.jvp(u_fn, (z, r, t), (-v, drdt, dtdt)) + # u_target = jax.lax.stop_gradient( + # v + # - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] + # * dudt + # * 0.5 + # ) # computes the target - u_target = jax.lax.stop_gradient( - v - - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] - * dudt - ) # NOTE: sum over all the pixels, following official implementation - loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) + loss = jnp.sum( + jnp.square(u - jax.lax.stop_gradient(u_target)), + axis=(-1, -2, -3), + ) # applies adaptive weight power if self.adaptive_weight_power > 0.0: - ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) + ada_wt = jnp.power(loss + 1e-3, self.adaptive_weight_power) loss = loss / jax.lax.stop_gradient(ada_wt) loss = jnp.mean(loss) @@ -636,136 +748,65 @@ def compute_loss( ) velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() - return MeanFlowOutputs( - loss=loss, - velocity_loss=velocity_loss, - output=u, + out = _model.StepOutputs( + scalars={"loss": loss, "velocity_loss": velocity_loss}, + histograms={"t": t, "r": r, "t - r": t - r}, ) + return loss, out + @typing_extensions.override def forward( self, *, - rngs: typing.Union[ - 
jax.random.KeyArray, - typing.Dict[str, jax.random.KeyArray], - ], + rngs: jax.Array, params: frozen_dict.FrozenDict, - image: jax.Array, - label: jax.Array, - begin: typing.Optional[jax.Array] = None, - end: typing.Optional[jax.Array] = None, - deterministic: bool = False, + batch: typing.Dict[str, typing.Any], + deterministic: bool = True, **kwargs, - ) -> MeanFlowOutputs: - """Forward sampling with average velocity prediction. + ) -> _model.StepOutputs: + r"""Forward sampling with average velocity prediction. Args: + rngs (jax.Array): Random key for sampling. params (frozen_dict.FrozenDict): The model parameters. - image (jax.Array): Input latent image `z_t` of shape `(*, H, W, C)`. - label (jax.Array): Conditioning labels of shape `(*,)`. - begin (jax.Array): Begin timestamp `r` of shape `(*, )`. - end (jax.Array): End timestamp `t` of shape `(*, )`. + batch (Dict[str, Any]): A batch of data containing: + - image (jax.Array): Input images of shape `(*, H, W, C)`. + shape (jax.typing.Shape): The shape of the output samples. + dtype (Any): The dtype of the output samples. deterministic (bool): Whether to run the model deterministically. **kwargs: Additional keyword arguments. Returns: - MeanFlowOutputs: The model outputs. + The output samples. """ - batch_dims = image.shape[:-3] - dims = chex.Dimensions( - H=self.image_size, - W=self.image_size, - C=self.in_channels, - ) - chex.assert_shape(image, (*batch_dims, *dims["HWC"])) - chex.assert_shape(label, batch_dims) - - if begin is None: - begin = jnp.zeros(batch_dims, dtype=image.dtype) - if end is None: - end = jnp.ones(batch_dims, dtype=image.dtype) - chex.assert_shape(begin, batch_dims) - assert jnp.all(begin >= 0) and jnp.all( - begin <= 1 - ), "Invalid input `r`." - chex.assert_shape(end, batch_dims) - assert jnp.all(end >= 0) and jnp.all(end <= 1), "Invalid input `t`." 
- r, t = jnp.minimum(begin, end), jnp.maximum( - begin, end - ) # ensure r <= t - - sample = jnp.subtract( - image, - jnp.einsum( - "...,...n->...n", - (t - r), - self.u_fn( - label=label, - params=params, - deterministic=deterministic, - )(image, r, t), - ), - ) + del kwargs # unused - return MeanFlowOutputs(output=sample) - - @property - def dummy_input(self) -> PyTree: - """PyTree: A dictionary mapping feature names to example arrays.""" - return { - "image": jnp.zeros( - (1, self.image_size, self.image_size, self.in_channels), - dtype=jnp.float32, - ), - "label": jnp.zeros((1,), dtype=jnp.int32), - "begin": jnp.zeros((1,), dtype=jnp.float32), - "end": jnp.zeros((1,), dtype=jnp.float32), - } + # TODO (juanwulu): unconditional generation + image = batch["image"] + shape, dtype = image.shape, image.dtype - def u_fn( - self, - *, - label: jax.Array, - params: frozen_dict.FrozenDict, - deterministic: bool = True, - ) -> typing.Callable[[jax.Array, jax.Array, jax.Array], jax.Array]: - """Returns the average velocity function `u(z_t, r, t)`.""" + e = jax.random.normal(key=rngs, shape=shape, dtype=dtype) + r = jnp.zeros(e.shape[:-3], dtype=dtype) + t = jnp.ones(e.shape[:-3], dtype=dtype) if self.timestamp_cond == "t_and_r": - return lambda z_t, r, t: self._module.apply( - variables={"params": params}, - image=z_t, - label=label, - begin=r, - end=t, - deterministic=deterministic, - ) + timestamps = (t, r) elif self.timestamp_cond == "t_and_t_minus_r": - return lambda z_t, r, t: self._module.apply( - variables={"params": params}, - image=z_t, - label=label, - begin=t - r, - end=t, - deterministic=deterministic, - ) + timestamps = (t, t - r) elif self.timestamp_cond == "t_and_r_and_t_minus_r": - # TODO: implement this - raise NotImplementedError( - "Conditioning on (t, r, t - r) is not implemented yet." 
- ) + timestamps = (t, r, t - r) elif self.timestamp_cond == "t_minus_r": - return lambda z_t, r, t: self._module.apply( - variables={"params": params}, - image=z_t, - label=label, - begin=t - r, - end=None, - deterministic=deterministic, - ) + timestamps = (t - r,) else: raise ValueError( - f"Unsupported timestamp condition: {self.timestamp_cond}. " - 'Must be one of ["t_and_r", "t_and_t_minus_r", ' - '"t_and_r_and_t_minus_r", "t_minus_r"].' + f"Unsupported timestamp conditioning: {self.timestamp_cond}." ) + + out = e - self.network.apply( + variables={"params": params}, + image=e, + timestamps=timestamps, + deterministic=deterministic, + ) + + return _model.StepOutputs(output=out) diff --git a/src/projects/generative/model/BUILD b/src/projects/generative/model/BUILD index aed17c1..7b57aef 100644 --- a/src/projects/generative/model/BUILD +++ b/src/projects/generative/model/BUILD @@ -1,6 +1,6 @@ -load("//learning:defs.bzl", "ml_py_library", "ml_py_test") +load("//third_party:defs.bzl", "ml_py_library", "ml_py_test") -package(default_visibility = ["//learning/generative:__subpackages__"]) +package(default_visibility = ["//src/projects/generative:__subpackages__"]) ml_py_library( name = "refinenet", @@ -9,7 +9,6 @@ ml_py_library( "chex", "flax", "jax", - "jaxlib", ], ) @@ -20,7 +19,25 @@ ml_py_test( "chex", "flax", "jax", - "jaxlib", ":refinenet", ], ) + +ml_py_library( + name = "unet", + srcs = ["unet.py"], + deps = [ + "chex", + "flax", + "jax", + ], +) + +ml_py_test( + name = "test_unet", + srcs = ["test_unet.py"], + deps = [ + "jax", + ":unet", + ], +) diff --git a/src/projects/generative/model/refinenet.py b/src/projects/generative/model/refinenet.py index 9be1ed9..b66d620 100644 --- a/src/projects/generative/model/refinenet.py +++ b/src/projects/generative/model/refinenet.py @@ -5,6 +5,7 @@ import jax from jax._src import core as jax_core from jax._src import dtypes as jax_dtypes +from jax._src import typing as jax_typing import jax.numpy as jnp @@ -12,23 
+13,25 @@ # Builder functions # ============================================================================== def _uniform_init() -> jax.nn.initializers.Initializer: - """Uniform initializer for convolutional layers.""" + r"""Uniform initializer for convolutional layers.""" def init( - key: jax.random.KeyArray, - shape: jax_core.Shape, - dtype: typing.Any, + key: jax.Array, + shape: jax_typing.Shape, + dtype: typing.Any = jnp.float_, + out_sharding: typing.Any = None, ) -> jax.Array: """Uniform initializer for one-dimensional parameters.""" dim = shape[-1] dtype = jax_dtypes.canonicalize_dtype(dtype) - named_shape = jax_core.as_named_shape(shape) + named_shape = jax_core.canonicalize_shape(shape) return jax.random.uniform( key=key, shape=named_shape, dtype=dtype, minval=-jnp.sqrt(1.0 / dim), maxval=jnp.sqrt(1.0 / dim), + out_sharding=out_sharding, ) return init @@ -42,7 +45,7 @@ def _conv_1x1( dtype: typing.Any = jnp.float32, param_dtype: typing.Any = jnp.float32, ) -> nn.Conv: - """1x1 convolution with stride and padding.""" + r"""1x1 convolution with stride and padding.""" return nn.Conv( features=out_channels, kernel_size=(1, 1), @@ -69,7 +72,7 @@ def _conv_3x3( dtype: typing.Any = jnp.float32, param_dtype: typing.Any = jnp.float32, ) -> nn.Conv: - """3x3 convolution with stride and padding.""" + r"""3x3 convolution with stride and padding.""" return nn.Conv( features=out_channels, kernel_size=(3, 3), @@ -96,7 +99,7 @@ def _dilated_conv_3x3( dtype: typing.Any = jnp.float32, param_dtype: typing.Any = jnp.float32, ) -> nn.Conv: - """3x3 dilated convolution with dilation and padding.""" + r"""3x3 dilated convolution with dilation and padding.""" return nn.Conv( features=out_channels, kernel_size=(3, 3), @@ -120,7 +123,7 @@ def _dilated_conv_3x3( # Layers # ============================================================================== class ConditionalInstanceNorm2dPlus(nn.Module): - """Conditional Instance Normalization with extra affine transformation.""" + 
r"""Conditional Instance Normalization with extra affine transformation.""" features: int """int: Dimensionality of the feature map.""" @@ -148,15 +151,21 @@ def setup(self) -> None: ) def _kernel_init( - key: jax.random.KeyArray, - shape: jax_core.Shape, + key: typing.Any, + shape: jax_typing.Shape, dtype: typing.Any, + out_sharding: typing.Any = None, ) -> jax.Array: dtype = jax_dtypes.canonicalize_dtype(dtype) - named_shape = jax_core.as_named_shape(shape) + named_shape = jax_core.canonicalize_shape(shape) return ( 1.0 - + jax.random.normal(key=key, shape=named_shape, dtype=dtype) + + jax.random.normal( + key=key, + shape=named_shape, + dtype=dtype, + out_sharding=out_sharding, + ) * 0.02 ) @@ -187,7 +196,7 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: cond (jax.Array): Condition feature map of shape `(*, )`. Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. """ batch_dims = inputs.shape[:-3] chex.assert_shape(cond, (*batch_dims,)) @@ -222,7 +231,7 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConvMeanPool(nn.Module): - """Convolution followed by average pooling.""" + r"""Convolution followed by average pooling.""" features: int """int: Number of output channels.""" @@ -248,13 +257,13 @@ def setup(self) -> None: ) def __call__(self, inputs: jax.Array) -> jax.Array: - """Forward pass of the `ConvMeanPool` module. + r"""Forward pass of the `ConvMeanPool` module. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. Returns: - jax.Array: Output feature map of shape `(*, H/2, W/2, C_out)`. + Output feature map of shape `(*, H/2, W/2, C_out)`. 
""" batch_dims = inputs.shape[:-3] if self.adjust_padding: @@ -278,14 +287,14 @@ def __call__(self, inputs: jax.Array) -> jax.Array: # Modules # ============================================================================== class ConditionalResidualBlock(nn.Module): - """Residual block with conditioning feature map.""" + r"""Residual block with conditioning feature map.""" in_channels: int """int: Number of channels of the input feature map.""" out_channels: int """int: Number of channels of the output feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" dilation: typing.Optional[int] = None """Optional[int]: Optional dilations in the convolutional layers.""" resample: typing.Optional[str] = None @@ -298,7 +307,7 @@ class ConditionalResidualBlock(nn.Module): """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a conditional residual block.""" + r"""Instantiate a conditional residual block.""" self.norm_1 = self.norm_module( features=self.in_channels, name="normalize1", @@ -426,14 +435,14 @@ def setup(self) -> None: ) def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: - """Forward pass of the conditional residual block. + r"""Forward pass of the conditional residual block. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, D)`. cond (jax.Array): Condition feature map of shape `(*,)`. Returns: - jax.Array: Output feature map of shape `(*, H, W, D)`. + Output feature map of shape `(*, H, W, D)`. 
""" output = self.norm_1(inputs, cond) output = jax.nn.elu(output) @@ -452,12 +461,12 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConditionalRCUBlock(nn.Module): - """Refinement Convolution Unit (RCU) block with conditioning feature map.""" + r"""Refinement Convolution Unit (RCU) block with conditioning features.""" features: int """int: Dimensionality of the feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" num_blocks: int """int: Number of repeated blocks in the cascade.""" num_stages: int @@ -468,7 +477,7 @@ class ConditionalRCUBlock(nn.Module): """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a `ConditionalRCUBlock` module.""" + r"""Instantiate a `ConditionalRCUBlock` module.""" convs, norms = [], [] for i in range(self.num_blocks): for j in range(self.num_stages): @@ -491,11 +500,19 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.convs: typing.Tuple[nn.Conv, ...] = convs - self.norms: typing.Tuple[ConditionalInstanceNorm2dPlus, ...] = norms + self.convs = convs + self.norms = norms def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: - """Forward pass of the `ConditionalRCUBlock` module.""" + r"""Forward pass of the `ConditionalRCUBlock` module. + + Args: + inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. + cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. + + Returns: + Output feature map of shape `(*, H, W, C)`. 
+ """ _idx: int = 0 output = inputs for _ in range(self.num_blocks): @@ -511,21 +528,21 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConditionalMSFBlock(nn.Module): - """Conditional Multi-Scale Feature block.""" + r"""Conditional Multi-Scale Feature block.""" in_features: typing.Sequence[int] """Sequence[int]: List of input feature map dimensionalities.""" features: int """int: Dimensionality of the output feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" dtype: typing.Any = jnp.float32 """dtype: The data type of the computation (default: float32).""" param_dtype: typing.Any = jnp.float32 """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a `ConditionalMSFBlock` module.""" + r"""Instantiate a `ConditionalMSFBlock` module.""" convs, norms = [], [] for i, in_feature in enumerate(self.in_features): convs.append( @@ -547,25 +564,25 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.convs: typing.Tuple[nn.Conv, ...] = convs - self.norms: typing.Tuple[ConditionalInstanceNorm2dPlus, ...] = norms + self.convs = convs + self.norms = norms def __call__( self, inputs: typing.Sequence[jax.Array], cond: jax.Array, - shape: jax_core.Shape, + shape: jax_typing.Shape, ) -> jax.Array: - """Forward pass of the `ConditionalMSFBlock` module. + r"""Forward pass of the `ConditionalMSFBlock` module. Args: inputs (Sequence[jax.Array]): Sequence of input feature maps to be merged. Each feature map has shape `(*, H_i, W_i, C) cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. - shape (jax_core.Shape): Shape of the output feature map. + shape (jax._src.typing.Shape): Shape of the output feature map. 
Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. """ assert isinstance(inputs, typing.Sequence) and len(inputs) == len( self.in_features @@ -586,12 +603,12 @@ def __call__( class ConditionalCRPBlock(nn.Module): - """Conditional convolutional residual pooling (CRP) block.""" + r"""Conditional convolutional residual pooling (CRP) block.""" features: int """int: Dimensionality of the output feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[Any, Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" num_stages: int """int: Number of stages in the cascade.""" dtype: typing.Any = jnp.float32 @@ -600,7 +617,7 @@ class ConditionalCRPBlock(nn.Module): """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a `ConditionalCRPBlock` module.""" + r"""Instantiate a `ConditionalCRPBlock` module.""" convs, norms = [], [] for i in range(self.num_stages): convs.append( @@ -622,18 +639,18 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.convs: typing.Tuple[nn.Conv, ...] = convs - self.norms: typing.Tuple[ConditionalInstanceNorm2dPlus, ...] = norms + self.convs = convs + self.norms = norms def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: - """Forward pass of the `ConditionalCRPBlock` module. + r"""Forward pass of the `ConditionalCRPBlock` module. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. 
""" output = jax.nn.elu(inputs) path = output @@ -643,7 +660,7 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: inputs=path, window_shape=(5, 5), strides=(1, 1), - padding=((2, 2), (2, 2)), + padding=((2, 2), (2, 2)), # type: ignore ) path = conv(path) output = output + path @@ -651,14 +668,14 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConditionalRefineBlock(nn.Module): - """Refinement block with skip connections and conditioning feature map.""" + r"""Refinement block with skip connections and conditioning feature map.""" in_features: typing.Sequence[int] """Sequence[int]: List of input feature map dimensionalities.""" out_features: int """int: Number of output channels of each convolution.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[Any, Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[[typing.Any], nn.Module] + """Callable[Any, nn.Module]: Normalization module to use.""" is_last_block: bool = False """bool: If True, this is the last refinement block.""" dtype: typing.Any = jnp.float32 @@ -681,7 +698,7 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.adapt_convs: typing.Tuple[ConditionalRCUBlock, ...] = adapt_convs + self.adapt_convs: typing.List[ConditionalRCUBlock] = adapt_convs self.output_convs = ConditionalRCUBlock( features=self.out_features, norm_module=self.norm_module, @@ -715,18 +732,18 @@ def __call__( self, inputs: typing.List[jax.Array], cond: jax.Array, - output_shape: jax_core.Shape, + output_shape: jax_typing.Shape, ) -> jax.Array: - """Forward pass of the refinement block. + r"""Forward pass of the refinement block. Args: inputs (List[jax.Array]): List of input feature maps to be merged. Each feature map has shape `(*, H_i, W_i, C)`. cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. - output_shape (jax_core.Shape): Shape of the output feature map. 
+ output_shape (jax._src.typing.Shape): Shape of the output feature. Returns: - jax.Array: Output feature map of shape `(*, H, W, 128)`. + Output feature map of shape `(*, H, W, 128)`. """ assert ( isinstance(inputs, typing.Sequence) @@ -760,13 +777,11 @@ def __call__( # Models # ============================================================================== class ConditionalRefineNet(nn.Module): - """Multi-path Refinement Network with Conditional Instance Normlization. - - .. note:: + r"""Multi-path Refinement Network with Conditional Instance Normlization. - This module is adapted from the original implementation of - `CondRefineNetDeeperDilated` in the NCSN official repository: - `https://github.com/ermongroup/ncsn/blob/master/models/cond_refinenet_dilated.py` + This module is adapted from the original implementation of + `CondRefineNetDeeperDilated` in the NCSN official repository: + `https://github.com/ermongroup/ncsn/blob/master/models/cond_refinenet_dilated.py` Attributes: in_channels (int): Number of channels of the input feature map. 
@@ -780,12 +795,12 @@ class ConditionalRefineNet(nn.Module): in_channels: int """int: Number of channels of the input feature map.""" - image_size: typing.Literal[28, 32] + image_size: int """int: Size of the input (square) image, either `28` or `32`.""" latent_channels: int """int: Number of channels of the latent feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" dtype: typing.Any = jnp.float32 """dtype: The data type of the computation (default: float32).""" param_dtype: typing.Any = jnp.float32 @@ -793,6 +808,11 @@ class ConditionalRefineNet(nn.Module): def setup(self) -> None: """Instantiate a Refinement Network module.""" + if self.image_size not in [28, 32]: + raise ValueError( + "`image_size` must be either `28` or `32`, " + f"but got {self.image_size}." + ) self.conv_in = _conv_3x3( out_channels=self.latent_channels, @@ -947,14 +967,14 @@ def __call__( cond: jax.Array, **kwargs, # type: ignore[unused-argument] ) -> jax.Array: - """Forward pass of the conditional refinement network. + r"""Forward pass of the conditional refinement network. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. cond (jax.Array): Condition feature map of shape `(*,)`. Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. 
""" batch_dims = inputs.shape[:-3] dims = chex.Dimensions( @@ -1019,11 +1039,11 @@ def __call__( @staticmethod def _forward_cond_res_block( - module: nn.Module, + module: typing.Sequence[nn.Module], inputs: jax.Array, cond: jax.Array, ) -> jax.Array: - """Forward pass through a residual block with conditional inputs.""" + r"""Forward pass through a residual block with conditional inputs.""" for m in module: assert isinstance(m, ConditionalResidualBlock) inputs = m(inputs=inputs, cond=cond) diff --git a/src/projects/generative/model/test_refinenet.py b/src/projects/generative/model/test_refinenet.py index 7c4a7b2..cd78fc1 100644 --- a/src/projects/generative/model/test_refinenet.py +++ b/src/projects/generative/model/test_refinenet.py @@ -5,16 +5,16 @@ import chex from flax import linen as nn import jax -import jax.numpy as jnp +from jax import numpy as jnp import pytest -from learning.generative.model import refinenet +from src.projects.generative.model import refinenet @pytest.mark.parametrize("out_channels", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conv_1x1(out_channels: int, dtype: typing.Any) -> None: - """Test 1x1 convolution builder.""" + r"""Test 1x1 convolution builder.""" layer = refinenet._conv_1x1( out_channels=out_channels, name="conv1", @@ -40,7 +40,7 @@ def test_conv_1x1(out_channels: int, dtype: typing.Any) -> None: @pytest.mark.parametrize("out_channels", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conv_3x3(out_channels: int, dtype: typing.Any) -> None: - """Test 3x3 convolution builder.""" + r"""Test 3x3 convolution builder.""" layer = refinenet._conv_3x3( out_channels=out_channels, name="conv3", @@ -71,7 +71,7 @@ def test_dilated_conv_3x3( dilation: int, dtype: typing.Any, ) -> None: - """Test dilated 3x3 convolution builder.""" + r"""Test dilated 3x3 convolution builder.""" layer = refinenet._dilated_conv_3x3( out_channels=out_channels, dilation=dilation, @@ -106,7 +106,7 
@@ def test_conditional_instance_norm_2d_plus( use_bias: bool, dtype: typing.Any, ) -> None: - """Test `ConditionalInstanceNorm2dPlus` layer.""" + r"""Test `ConditionalInstanceNorm2dPlus` layer.""" layer = refinenet.ConditionalInstanceNorm2dPlus( features=features, num_classes=num_classes, @@ -136,7 +136,7 @@ def test_conditional_instance_norm_2d_plus( chex.assert_type(variables["params"]["embed"]["embedding"], dtype) test_output = layer.apply( variables, - jnp.ones((1, 32, 32, features)), + jnp.ones((1, 32, 32, features), dtype=dtype), jnp.ones((1,), dtype=jnp.int32), ) chex.assert_type(test_output, dtype) @@ -151,7 +151,7 @@ def test_conv_mean_pool( kernel_size: int, dtype: typing.Any, ) -> None: - """Test `ConvMeanPool` layer.""" + r"""Test `ConvMeanPool` layer.""" layer = refinenet.ConvMeanPool( features=features, kernel_size=kernel_size, @@ -188,7 +188,7 @@ def test_conditional_residual_block( resample: typing.Optional[str], dtype: typing.Any, ) -> None: - """Test `ConditionalResidualBlock` module.""" + r"""Test `ConditionalResidualBlock` module.""" if resample not in (None, "down"): with pytest.raises(ValueError): block = refinenet.ConditionalResidualBlock( @@ -357,7 +357,7 @@ def test_conditional_residual_block( @pytest.mark.parametrize("features", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conditional_rcu_block(features: int, dtype: typing.Any) -> None: - """Test the `ConditionalRCUBlock` module.""" + r"""Test the `ConditionalRCUBlock` module.""" block = refinenet.ConditionalRCUBlock( features=features, norm_module=functools.partial( @@ -451,20 +451,20 @@ def test_conditional_msf_block(features: int, dtype: typing.Any) -> None: test_output = block.apply( variables, inputs=[ - jnp.ones((2, 32, 32, 3), dtype=jnp.float32), - jnp.ones((2, 16, 16, 8), dtype=jnp.float32), + jnp.ones((2, 32, 32, 3), dtype=dtype), + jnp.ones((2, 16, 16, 8), dtype=dtype), ], cond=jnp.ones((2,), dtype=jnp.int32), shape=(28, 28), ) - 
chex.assert_type(test_output, jnp.float32) + chex.assert_type(test_output, dtype) chex.assert_shape(test_output, (2, 28, 28, features)) @pytest.mark.parametrize("features", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conditional_crp_block(features: int, dtype: typing.Any) -> None: - """Test the `ConditionalCRPBlock` module.""" + r"""Test the `ConditionalCRPBlock` module.""" block = refinenet.ConditionalCRPBlock( features=features, norm_module=functools.partial( @@ -486,18 +486,15 @@ def test_conditional_crp_block(features: int, dtype: typing.Any) -> None: variables["params"][f"convs.{i:d}"]["kernel"], (3, 3, features, features), ) - chex.assert_type( - variables["params"][f"convs.{i:d}"]["kernel"], - jnp.float32, - ) + chex.assert_type(variables["params"][f"convs.{i:d}"]["kernel"], dtype) assert variables["params"][f"convs.{i:d}"].get("bias") is None test_output = block.apply( variables, - jnp.ones((1, 32, 32, features), dtype=jnp.float32), + jnp.ones((1, 32, 32, features), dtype=dtype), jnp.ones((1,), dtype=jnp.int32), ) - chex.assert_type(test_output, jnp.float32) + chex.assert_type(test_output, dtype) chex.assert_shape(test_output, (1, 32, 32, features)) @@ -509,10 +506,10 @@ def test_conditional_refine_block( dtype: typing.Any, is_last_block: bool, ) -> None: - """Test the `ConditionalRefineBlock` module.""" + r"""Test the `ConditionalRefineBlock` module.""" test_inputs = [ - jnp.ones((2, 32, 32, 3), dtype=jnp.float32), - jnp.ones((2, 16, 16, 8), dtype=jnp.float32), + jnp.ones((2, 32, 32, 3), dtype=dtype), + jnp.ones((2, 16, 16, 8), dtype=dtype), ] block = refinenet.ConditionalRefineBlock( in_features=[3, 8], @@ -534,7 +531,7 @@ def test_conditional_refine_block( _ = block.init( jax.random.PRNGKey(0), inputs=[ - jnp.ones((2, 32, 32, 3), dtype=jnp.float32), + jnp.ones((2, 32, 32, 3), dtype=dtype), ], cond=jnp.ones((2,), dtype=jnp.int32), output_shape=(28, 28), @@ -551,12 +548,13 @@ def test_conditional_refine_block( 
+    r"""Tests the residual block in U-Net models.
block = unet.ResNetBlock(features=64, dtype=dtype, param_dtype=dtype) + test_input = jnp.ones((2, 32, 32, 32), dtype=dtype) + test_cond = jnp.ones((2, 16), dtype=dtype) + params_rng, dropout_rng = jax.random.split(rng, num=2) + variables = block.init( + rngs={"params": params_rng}, + inputs=test_input, + cond=test_cond, + deterministic=False, + ) + outputs = block.apply( + variables=variables, + inputs=test_input, + cond=test_cond, + deterministic=False, + rngs={"dropout": dropout_rng}, + ) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 32, 32, 64) + assert outputs.dtype == dtype + + +@pytest.mark.parametrize("with_conv", [True, False]) +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_downsample_block(with_conv: bool, dtype: typing.Any) -> None: + r"""Tests the downsampling block in U-Net models.""" + rng = jax.random.PRNGKey(42) + + block = unet.DownsampleBlock( + with_conv=with_conv, + dtype=dtype, + param_dtype=dtype, + ) + test_input = jnp.ones((2, 32, 32, 32), dtype=dtype) + variables = block.init( + rngs={"params": rng}, + inputs=test_input, + ) + if with_conv: + assert "conv0" in variables["params"] + kernel = variables["params"]["conv0"]["kernel"] + assert isinstance(kernel, jax.Array) + assert kernel.shape == (3, 3, 32, 32) + bias = variables["params"]["conv0"]["bias"] + assert isinstance(bias, jax.Array) + assert bias.shape == (32,) + + outputs = block.apply(variables=variables, inputs=test_input) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 16, 16, 32) + assert outputs.dtype == dtype + + +@pytest.mark.parametrize("with_conv", [True, False]) +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_upsample_block(with_conv: bool, dtype: typing.Any) -> None: + r"""Tests the upsampling block in U-Net models.""" + rng = jax.random.PRNGKey(42) + + block = unet.UpsampleBlock( + with_conv=with_conv, + dtype=dtype, + param_dtype=dtype, + ) + test_input = jnp.ones((2, 16, 16, 
32), dtype=dtype) + variables = block.init( + rngs={"params": rng}, + inputs=test_input, + ) + if with_conv: + assert "conv0" in variables["params"] + kernel = variables["params"]["conv0"]["kernel"] + assert isinstance(kernel, jax.Array) + assert kernel.shape == (3, 3, 32, 32) + bias = variables["params"]["conv0"]["bias"] + assert isinstance(bias, jax.Array) + assert bias.shape == (32,) + + outputs = block.apply(variables=variables, inputs=test_input) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 32, 32, 32) + assert outputs.dtype == dtype + + +@pytest.mark.parametrize("num_heads", [1, 4]) +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_attn_block(num_heads: int, dtype: typing.Any) -> None: + r"""Tests the attention block in U-Net models.""" + rng = jax.random.PRNGKey(42) + + block = unet.AttnBlock(num_heads=num_heads, dtype=dtype, param_dtype=dtype) + test_input = jnp.ones((2, 16, 16, 32), dtype=dtype) + variables = block.init( + rngs={"params": rng}, + inputs=test_input, + ) + + outputs = block.apply(variables=variables, inputs=test_input) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 16, 16, 32) + assert outputs.dtype == dtype + + +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_score_net(dtype: typing.Any) -> None: + r"""Tests the full U-Net model for score-based generative modeling.""" + rng = jax.random.PRNGKey(42) + + model = unet.ScoreNet( + features=128, + dropout_rate=0.2, + dtype=dtype, + param_dtype=dtype, + ) + test_input = jnp.ones((2, 32, 32, 3), dtype=dtype) + test_cond = jnp.ones((2, 16), dtype=dtype) + params_rng, dropout_rng = jax.random.split(rng, num=2) + variables = model.init( + rngs={"params": params_rng}, + inputs=test_input, + cond=test_cond, + deterministic=True, + ) + outputs = model.apply( + variables=variables, + inputs=test_input, + cond=test_cond, + deterministic=True, + rngs={"dropout": dropout_rng}, + ) + assert isinstance(outputs, 
+    r"""A residual block with two convolutional layers.
+ """ + + features: int + num_groups: int = 32 + epsilon: float = 1e-5 + deterministic: typing.Optional[bool] = None + dropout_rate: float = 0.0 + skip_scale: float = 1.0 + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + def setup(self) -> None: + r"""Instantiates a `ResNetBlock` instance.""" + self.norm_1 = nn.GroupNorm( + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm0", + ) + self.conv_1 = nn.Conv( + features=self.features, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv0", + ) + self.cond_linear = nn.Dense( + features=self.features, + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="cond_in", + ) + self.dropout = nn.Dropout(rate=self.dropout_rate, name="dropout") + + self.norm_2 = nn.GroupNorm( + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm1", + ) + self.conv_2 = nn.Conv( + features=self.features, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv1", + ) + self.conv_shortcut = nn.Dense( + features=self.features, + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv_shortcut", + ) + + def __call__( + self, + 
inputs: jax.Array, + cond: typing.Optional[jax.Array] = None, + deterministic: typing.Optional[bool] = None, + ) -> jax.Array: + r"""Forward pass of the `ResNetBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C_in)`. + cond (Optional[jax.Array], optional): Optional conditioning array + of shape `(*, C_cond)`. + deterministic (bool, optional): If true, the model is run in + deterministic mode (e.g., no dropout). Defaults to `None`. + + Returns: + Output array of shape `(*, H, W, C_out)`, where `C_out` is the + `features` specified during instantiation. + """ + m_deterministic = nn.merge_param( + "deterministic", + self.deterministic, + deterministic, + ) + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + W=inputs.shape[-2], + C=inputs.shape[-1], + ) + + out = self.conv_1(jax.nn.silu(self.norm_1(inputs))) + + if cond is not None: + out = out + self.cond_linear(jax.nn.silu(cond))[..., None, None, :] + out = jax.nn.silu(self.norm_2(out)) + out = self.dropout(out, deterministic=m_deterministic) + out = self.conv_2(out) + + if inputs.shape[-1] != self.features: + shortcut = self.conv_shortcut(inputs) + else: + shortcut = inputs + out = out + shortcut + out = out * self.skip_scale + chex.assert_shape(out, (*batch_dims, *dims["HW"], self.features)) + + return out + + +class DownsampleBlock(nn.Module): + r"""A downsampling block using averaging pooling or strided convolution. + + Args: + with_conv (bool, optional): If true, uses a strided convolution for + downsampling. If `False`, uses average pooling. Default is `True`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + """ + + with_conv: bool = True + dtype: typing.Any = None + param_dtype: typing.Any = None + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass of the `DownsampleBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C)`. 
+ + Returns: + Output array of shape `(*, H / 2, W / 2, C)`. + """ + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + h=inputs.shape[-3] // 2, + W=inputs.shape[-2], + w=inputs.shape[-2] // 2, + C=inputs.shape[-1], + ) + + if self.with_conv: + out = nn.Conv( + features=inputs.shape[-1], + kernel_size=(3, 3), + strides=(2, 2), + padding=((0, 1), (0, 1)), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv0", + )(inputs) + else: + out = nn.avg_pool(inputs, window_shape=(2, 2), strides=(2, 2)) + chex.assert_shape(out, (*batch_dims, *dims["hwC"])) + + return out + + +class UpsampleBlock(nn.Module): + r"""An upsampling block using nearest-neighbor interpolation. + + Args: + with_conv (bool, optional): If true, applies a convolution after + upsampling. Default is `True`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + precision (Any, optional): Numerical precision of the computation. + """ + + with_conv: bool = True + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass of the `UpsampleBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C)` + + Returns: + Output array of shape `(*, H * 2, W * 2, C)`. 
+ """ + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + h=inputs.shape[-3] * 2, + W=inputs.shape[-2], + w=inputs.shape[-2] * 2, + C=inputs.shape[-1], + ) + + out = jax.image.resize( + inputs, + shape=(*batch_dims, *dims["hwC"]), + method="nearest", + antialias=True, + precision=self.precision, + ) + if self.with_conv: + out = nn.Conv( + features=inputs.shape[-1], + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv0", + )(out) + + chex.assert_shape(out, (*batch_dims, *dims["hwC"])) + return out + + +class AttnBlock(nn.Module): + r"""Self-attention block with group normalization in U-Net models. + + Args: + num_heads (int): Number of attention heads. + num_groups (int): Number of groups for `GroupNorm`. + epsilon (float, optional): Small float added to variance to avoid + dividing by zero in `GroupNorm`. Default is :math:`1e-5`. + skip_scale (float, optional): Scaling factor for the residual + connection output. Default is :math:`1.0`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + precision (Any, optional): Numerical precision of the computation. + """ + + num_heads: int + num_groups: int + epsilon: float = 1e-5 + skip_scale: float = 1.0 + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass of the `AttnBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C)`. + + Returns: + Output array of shape `(*, H, W, C)`. 
+ """ + + norm_in = nn.GroupNorm( + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm", + ) + out = norm_in(inputs) + + if self.num_heads == 1: + # scaled dot-product attention + q_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="q_proj", + ) + query = q_proj(out) + k_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="k_proj", + ) + key = k_proj(out) + v_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="v_proj", + ) + value = v_proj(out) + out = nn.dot_product_attention( + query[..., None, :], + key[..., None, :], + value[..., None, :], + broadcast_dropout=False, + dropout_rate=0.0, + dtype=self.dtype, + precision=self.precision, + ) + out_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.zeros, + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="out_proj", + ) + out = out_proj(out[..., 0, :]) + else: + head_dim = inputs.shape[-1] // self.num_heads + if head_dim * self.num_heads != inputs.shape[-1]: + raise ValueError( + f"Number of heads {self.num_heads} not compatible with " + f"input channels {inputs.shape[-1]}." 
+ ) + q_proj = nn.DenseGeneral( + features=(self.num_heads, head_dim), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="q_proj", + ) + query = q_proj(out) + k_proj = nn.DenseGeneral( + features=(self.num_heads, head_dim), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="k_proj", + ) + key = k_proj(out) + v_proj = nn.DenseGeneral( + features=(self.num_heads, head_dim), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="v_proj", + ) + value = v_proj(out) + out = nn.dot_product_attention( + query, + key, + value, + broadcast_dropout=False, + dropout_rate=0.0, + dtype=self.dtype, + precision=self.precision, + ) + out_proj = nn.DenseGeneral( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.zeros, + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="out_proj", + ) + out = out_proj(out) + + chex.assert_equal_shape([out, inputs]) + out = out + inputs + out = out * self.skip_scale + + return out + + +class ScoreNet(nn.Module): + r"""U-Net architecture for score-function estimation. + + This module is adapted from the original implementation of the U-Net + architecture from "Score-Based Generative Modeling through Stochastic + Differential Equations" by Yang Song et al. and the original implementation + is available at `https://github.com/yang-song/score_sde_pytorch`. + + Args: + features (int): Base number of features for the latent representations. 
+ ch_mults (typing.Sequence[int], optional): Sequence of multipliers + for the number of features at each level of the U-Net. + num_groups (int, optional): Number of groups for `GroupNorm`. + num_res_blocks (int, optional): Number of residual blocks per level. + attn_resolutions (typing.Sequence[int], optional): Sequence of + resolutions at which to apply attention mechanisms. + dropout_rate (float, optional): Dropout rate. Default is :math:`0.0`. + epsilon (float, optional): Small float added to variance to avoid + dividing by zero in `GroupNorm`. Default is :math:`1e-5`. + skip_scale (float, optional): Scaling factor for the residual + connection outputs. Default is :math:`1.0`. + deterministic (bool, optional): If true, the model is run in + deterministic mode (e.g., no dropout). Defaults to `None`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + precision (Any, optional): Numerical precision of the computation. + """ + + features: int + ch_mults: typing.Sequence[int] = (1, 2, 2, 2) + num_groups: int = 32 + num_res_blocks: int = 4 + attn_resolutions: typing.Sequence[int] = (16,) + dropout_rate: float = 0.0 + epsilon: float = 1e-5 + skip_scale: float = 1.0 + deterministic: typing.Optional[bool] = None + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + @nn.compact + def __call__( + self, + inputs: jax.Array, + cond: jax.Array, + deterministic: typing.Optional[bool] = None, + ) -> jax.Array: + r"""Forward pass of the `ScoreNet`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C_in)`. + cond (jax.Array): Conditioning array of shape `(*, C_cond)`. + deterministic (bool, optional): If true, the model is run in + deterministic mode (e.g., no dropout). Defaults to `None`. + + Returns: + Output array of shape `(*, H, W, C_out)`, where `C_out` is the + number of channels in the input. 
+ """ + m_deterministic = nn.merge_param( + "deterministic", + self.deterministic, + deterministic, + ) + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + W=inputs.shape[-2], + C=inputs.shape[-1], + ) + skips = [] + + # forward pass the input convolution + conv_in = nn.Conv( + features=self.features, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + name="conv_in", + ) + out = conv_in(inputs) + skips.append(out) + + # forward pass the downsampling path + for level, mult in enumerate(self.ch_mults): + out_ch = self.features * mult + for i in range(self.num_res_blocks): + res_block = ResNetBlock( + features=out_ch, + num_groups=self.num_groups, + dropout_rate=self.dropout_rate, + epsilon=self.epsilon, + skip_scale=self.skip_scale, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"down_resnet_{level + 1:d}_{i + 1:d}", + ) + out = res_block( + inputs=out, + cond=cond, + deterministic=m_deterministic, + ) + if out.shape[-3] in self.attn_resolutions: + block = AttnBlock( + num_heads=1, + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"down_attn_{level + 1:d}_{i + 1:d}", + ) + out = block(out) + skips.append(out) + if level != len(self.ch_mults) - 1: + downsample = DownsampleBlock( + with_conv=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + name=f"downsample_{level + 1:d}", + ) + out = downsample(out) + skips.append(out) + + # forward pass the middle blocks + block = ResNetBlock( + features=out.shape[-1], + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name="mid_resnet_1", + 
) + out = block(out, cond=cond, deterministic=m_deterministic) + block = AttnBlock( + num_heads=1, + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name="mid_attn", + ) + out = block(out) + block = ResNetBlock( + features=out.shape[-1], + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name="mid_resnet_2", + ) + out = block(out, cond=cond, deterministic=m_deterministic) + + # forward pass the upsampling path + for level, mult in reversed(list(enumerate(self.ch_mults))): + out_ch = self.features * mult + for i in range(self.num_res_blocks + 1): + skip = skips.pop() + out = jnp.concatenate([out, skip], axis=-1) + res_block = ResNetBlock( + features=out_ch, + dropout_rate=self.dropout_rate, + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"up_resnet_{level + 1:d}_{i + 1:d}", + ) + out = res_block( + inputs=out, + cond=cond, + deterministic=m_deterministic, + ) + if out.shape[-3] in self.attn_resolutions: + block = AttnBlock( + num_heads=1, + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"up_attn_{level + 1:d}_{i + 1:d}", + ) + out = block(out) + if level != 0: + upsample = UpsampleBlock( + with_conv=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"upsample_{level + 1:d}", + ) + out = upsample(out) + + # forward pass the output convolution + norm_out = nn.GroupNorm( + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm_out", + ) + out = jax.nn.silu(norm_out(out)) + conv_out = nn.Conv( + 
features=dims.C, # type: ignore + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.zeros, + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + name="conv_out", + ) + out = conv_out(out) + chex.assert_shape(out, (*batch_dims, *dims["HWC"])) + + return out diff --git a/src/projects/generative/test_meanflow.py b/src/projects/generative/test_meanflow.py deleted file mode 100644 index 906f2da..0000000 --- a/src/projects/generative/test_meanflow.py +++ /dev/null @@ -1,190 +0,0 @@ -import sys -import typing - -import chex -from flax import linen as nn -import jax -import jax.numpy as jnp -import pytest - -from learning.generative import meanflow - - -@pytest.mark.parametrize("distribution", ["uniform", "normal", "lognormal"]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_sample_t_r(distribution: str, dtype: typing.Any) -> None: - """Test the `sample_t_r` function.""" - key = jax.random.PRNGKey(0) - shape = (2, 3) - - if distribution not in ["uniform", "lognormal"]: - with pytest.raises(ValueError): - meanflow.sample_t_r( - key=key, - shape=shape, - dtype=dtype, - distribution=distribution, - ) - return - - # Test uniform distribution - t, r = meanflow.sample_t_r( - key=key, - shape=shape, - dtype=dtype, - distribution="uniform", - ) - chex.assert_shape(t, shape) - chex.assert_shape(r, shape) - chex.assert_type(t, dtype) - chex.assert_type(r, dtype) - chex.assert_tree_all_finite(t) - chex.assert_tree_all_finite(r) - assert jnp.all(t >= 0) and jnp.all(t <= 1) - assert jnp.all(r >= 0) and jnp.all(r <= 1) - - -@pytest.mark.parametrize("features", [1, 8]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_timestamp_embed(features: int, dtype: typing.Any) -> None: - """Test the `TimestampEmbed` module.""" - embed = meanflow.TimestampEmbed( - features=features, - frequency=256, - name="timestamp_embed", - dtype=dtype, - param_dtype=dtype, - ) - assert isinstance(embed, nn.Module) 
- assert embed.features == features - assert embed.frequency == 256 - assert embed.dtype == dtype - assert embed.param_dtype == dtype - variables = embed.init( - jax.random.PRNGKey(0), - jnp.ones((2,), dtype=jnp.int32), - ) - chex.assert_shape(variables["params"]["fc_in"]["kernel"], (256, features)) - chex.assert_type(variables["params"]["fc_in"]["kernel"], dtype) - chex.assert_shape(variables["params"]["fc_in"]["bias"], (features,)) - chex.assert_type(variables["params"]["fc_in"]["bias"], dtype) - chex.assert_shape( - variables["params"]["fc_out"]["kernel"], - (features, features), - ) - chex.assert_type(variables["params"]["fc_out"]["kernel"], dtype) - chex.assert_shape(variables["params"]["fc_out"]["bias"], (features,)) - chex.assert_type(variables["params"]["fc_out"]["bias"], dtype) - - test_output = embed.apply( - variables, - jnp.array([10, 1000], dtype=jnp.int32), - ) - chex.assert_shape(test_output, (2, features)) - chex.assert_type(test_output, dtype) - chex.assert_tree_all_finite(test_output) - - -@pytest.mark.parametrize("features", [1, 8]) -@pytest.mark.parametrize("use_cfg_embedding", [False, True]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_condition_embed( - features: int, - use_cfg_embedding: bool, - dtype: typing.Any, -) -> None: - """Test the `ConditionEmbed` module.""" - if use_cfg_embedding: - # TODO: implement classifier-free guidance. 
- pytest.skip("Classifier-free guidance not supported yet.") - - embed = meanflow.ConditionEmbed( - features=features, - num_classes=10, - use_cfg_embedding=use_cfg_embedding, - name="condition_embed", - dtype=dtype, - param_dtype=dtype, - ) - assert isinstance(embed, nn.Module) - assert embed.features == features - assert embed.num_classes == 10 - assert embed.use_cfg_embedding == use_cfg_embedding - assert embed.dtype == dtype - assert embed.param_dtype == dtype - variables = embed.init( - jax.random.PRNGKey(0), - jnp.ones((2,), dtype=jnp.int32), - ) - chex.assert_shape( - variables["params"]["embedding_table"]["embedding"], - (10 + int(use_cfg_embedding), features), - ) - chex.assert_type( - variables["params"]["embedding_table"]["embedding"], dtype - ) - - test_output = embed.apply( - variables, - jnp.array([1, 9], dtype=jnp.int32), - ) - chex.assert_shape(test_output, (2, features)) - chex.assert_type(test_output, dtype) - chex.assert_tree_all_finite(test_output) - - -@pytest.mark.parametrize("features", [1, 8]) -@pytest.mark.parametrize("use_bias", [False, True]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_conditional_instance_norm( - features: int, - use_bias: bool, - dtype: typing.Any, -) -> None: - """Test the `ConditionalInstanceNorm` module.""" - cond_features = 4 - cond = jnp.ones((2, cond_features), dtype=dtype) - - norm = meanflow.ConditionalInstanceNorm( - features=features, - use_bias=use_bias, - name="conditional_instance_norm", - dtype=dtype, - param_dtype=dtype, - ) - assert isinstance(norm, nn.Module) - assert norm.features == features - assert norm.use_bias == use_bias - assert norm.dtype == dtype - assert norm.param_dtype == dtype - variables = norm.init( - jax.random.PRNGKey(0), - jnp.ones((2, 16, 16, features), dtype=dtype), - cond, - ) - assert variables["params"].get("instance_norm") is None - if use_bias: - chex.assert_shape( - variables["params"]["embed"]["kernel"], - (cond_features, features * 3), - ) - 
assert variables["params"]["embed"].get("bias") is None - else: - chex.assert_shape( - variables["params"]["embed"]["kernel"], - (cond_features, features * 2), - ) - assert variables["params"]["embed"].get("bias") is None - - test_output = norm.apply( - variables, - jnp.ones((2, 16, 16, features), dtype=dtype), - cond, - ) - chex.assert_shape(test_output, (2, 16, 16, features)) - chex.assert_type(test_output, dtype) - chex.assert_tree_all_finite(test_output) - - -if __name__ == "__main__": - sys.exit(pytest.main(["-xv", __file__])) diff --git a/src/utilities/BUILD b/src/utilities/BUILD index 8c4078b..40126c8 100644 --- a/src/utilities/BUILD +++ b/src/utilities/BUILD @@ -1,6 +1,6 @@ -load("//learning:defs.bzl", "ml_py_library") +load("//third_party:defs.bzl", "ml_py_library") -package(default_visibility = ["//learning:__subpackages__"]) +package(default_visibility = ["//src:__subpackages__"]) ml_py_library( name = "logging", @@ -8,7 +8,6 @@ ml_py_library( deps = [ "absl-py", "jax", - "jaxlib", ], ) @@ -17,6 +16,14 @@ ml_py_library( srcs = ["rank_zero.py"], deps = [ "jax", - "jaxlib", + ], +) + +ml_py_library( + name = "visualization", + srcs = ["visualization.py"], + deps = [ + "jax", + ":logging", ], ) diff --git a/src/utilities/visualization.py b/src/utilities/visualization.py new file mode 100644 index 0000000..15a540c --- /dev/null +++ b/src/utilities/visualization.py @@ -0,0 +1,49 @@ +import jax +from jax import numpy as jnp + +from src.utilities import logging + + +def make_grid( + images: jax.Array, + n_rows: int = 8, + n_cols: int = 8, + padding: int = 2, +) -> jax.Array: + r"""Convert a batch of images into a grid for visualization. + + Args: + images (jax.Array): Batch of images with shape `(B, H, W, C)`. + n_rows (int): Number of rows in grid. Default is :math:`8`. + n_cols (int): Number of columns in grid. Default is :math:`8`. + padding (int, optional): Number of pixels between pair of images. + Default is :math:`2`. 
+ + Returns: + The array containing a grid of input images. + """ + images = jnp.reshape(images, (-1,) + images.shape[-3:]) + _, h, w, c = images.shape + shape = ( + h * n_rows + padding * (n_rows - 1), + w * n_cols + padding * (n_cols - 1), + c, + ) + out = jnp.zeros(shape, dtype=images.dtype) + + for idx, img in enumerate(images): + row = idx // n_cols + col = idx % n_cols + top = row * (h + padding) + left = col * (w + padding) + out = out.at[top : top + h, left : left + w].set(img) + + if idx + 1 >= n_rows * n_cols: + logging.rank_zero_warning( + "Number of images exceed grid capacity; " + + "only the first %d images are used.", + n_rows * n_cols, + ) + break + + return out diff --git a/third_party/BUILD b/third_party/BUILD index f9e5f7c..bdd1155 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -18,6 +18,11 @@ config_setting( define_values = {"ml_platform": "tpu"}, ) +config_setting( + name = "is_mps", + define_values = {"ml_platform": "mps"}, +) + # compile pip requirements into a lock file compile_pip_requirements( name = "requirements_3_10_cpu", @@ -57,3 +62,17 @@ compile_pip_requirements( ], requirements_txt = "requirements_3_10_tpu_lock.txt", ) + +compile_pip_requirements( + name = "requirements_3_10_mps", + timeout = "moderate", + srcs = [ + "requirements.in", + "requirements_mps.in", + ], + extra_args = [ + "--allow-unsafe", + "--resolver=backtracking", + ], + requirements_txt = "requirements_3_10_mps_lock.txt", +) diff --git a/third_party/defs.bzl b/third_party/defs.bzl index 269135e..baa820a 100644 --- a/third_party/defs.bzl +++ b/third_party/defs.bzl @@ -2,6 +2,7 @@ load("@ml_infra_cpu_3_10//:requirements.bzl", cpu_req = "requirement") load("@ml_infra_cuda_3_10//:requirements.bzl", cuda_req = "requirement") +load("@ml_infra_mps_3_10//:requirements.bzl", mps_req = "requirement") load("@ml_infra_tpu_3_10//:requirements.bzl", tpu_req = "requirement") load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") @@ -17,6 +18,7 @@ 
def _select_requirement(name): return select({ "//third_party:is_cpu": [cpu_req(name)], "//third_party:is_cuda": [cuda_req(name)], + "//third_party:is_mps": [mps_req(name)], "//third_party:is_tpu": [tpu_req(name)], }) @@ -36,8 +38,13 @@ def _select_all_requirements(names = []): if "fiddle" in names and "etils" not in names: reqs += _select_requirement("etils") - if "jax" in names and "jaxlib" not in names: - reqs += _select_requirement("jaxlib") + if "jax" in names: + if "jaxlib" not in names: + reqs += _select_requirement("jaxlib") + reqs += select({ + "//third_party:is_mps": [mps_req("jax-metal")], + "//conditions:default": [], + }) return reqs diff --git a/third_party/requirements.in b/third_party/requirements.in index 2305df3..57dd24c 100644 --- a/third_party/requirements.in +++ b/third_party/requirements.in @@ -1,5 +1,5 @@ absl-py==2.3.1 -chex==0.1.91 +chex==0.1.90 clu==0.0.12 datasets==4.4.1 flax==0.10.7 diff --git a/third_party/requirements_mps.in b/third_party/requirements_mps.in new file mode 100644 index 0000000..872ef87 --- /dev/null +++ b/third_party/requirements_mps.in @@ -0,0 +1 @@ +jax-metal==0.1.1