From ccfa898b6a7d90af36937b678e8da4370ef078fb Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 04:40:32 -0500 Subject: [PATCH 01/67] hotfix: Fixed typo in core module Signed-off-by: Juanwu Lu --- src/core/config.py | 4 ++-- src/core/evaluate.py | 10 +++++----- src/core/train.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/core/config.py b/src/core/config.py index b9bd73f..6084b57 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -4,7 +4,7 @@ import fiddle as fdl import optax -from src.core import data as _data +from src.core import datamodule as _datamodule from src.core import model as _model @@ -20,7 +20,7 @@ class DataConfig: drop_remainder (bool): Whether to drop the last incomplete batch. """ - module: fdl.Partial[_data.DataModule] + module: fdl.Partial[_datamodule.DataModule] batch_size: int = 32 num_workers: int = 4 deterministic: bool = True diff --git a/src/core/evaluate.py b/src/core/evaluate.py index 85c80d1..1163d53 100644 --- a/src/core/evaluate.py +++ b/src/core/evaluate.py @@ -7,14 +7,14 @@ import jax import jaxtyping -from src.core import data as _data +from src.core import datamodule as _datamodule from src.core import model as _model from src.utilities import logging def run( model: _model.Model, - datamodule: _data.DataModule, + datamodule: _datamodule.DataModule, params: jaxtyping.PyTree, writer: metric_writers.MetricWriter, work_dir: str, @@ -88,9 +88,9 @@ def run( _scalars = {} for k, v in outputs.scalars.items(): eval_metrics[k].append(jax.device_get(v).mean()) - _scalars[ - f"eval/{k.replace('_', ' ')}" - ] = jax.device_get(v).mean() + _scalars[f"eval/{k.replace('_', ' ')}"] = ( + jax.device_get(v).mean() + ) writer.write_scalars( step=step + 1, scalars=_scalars, diff --git a/src/core/train.py b/src/core/train.py index 75c9b6a..c0b43f2 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -9,7 +9,7 @@ import jax import jaxtyping -from src.core import data as _data +from src.core import datamodule as _datamodule from src.core import model as _model from src.core import train_state as _train_state from src.utilities import logging @@ -55,7 +55,7 @@ def _shard(tree: jaxtyping.PyTree) -> jaxtyping.PyTree: def run( model: _model.Model, state: _train_state.TrainState, - datamodule: _data.DataModule, + datamodule: _datamodule.DataModule, num_train_steps: int, checkpoint_manager: checkpoint.Checkpoint, writer: metric_writers.MetricWriter, From 0810627ceb28aec01c5ba3313a98a1d6bca66a35 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 04:41:11 -0500 Subject: [PATCH 02/67] feat: Updated implementation fo RefineNet Signed-off-by: Juanwu Lu --- src/core/evaluate.py | 6 +- src/projects/generative/model/BUILD | 4 +- src/projects/generative/model/refinenet.py | 160 ++++++++++-------- .../generative/model/test_refinenet.py | 62 ++++--- 4 files changed, 124 insertions(+), 108 deletions(-) diff --git a/src/core/evaluate.py b/src/core/evaluate.py index 1163d53..44ebf22 100644 --- a/src/core/evaluate.py +++ b/src/core/evaluate.py @@ -88,9 +88,9 @@ def run( _scalars = {} for k, v in outputs.scalars.items(): eval_metrics[k].append(jax.device_get(v).mean()) - _scalars[f"eval/{k.replace('_', ' ')}"] = ( - jax.device_get(v).mean() - ) + _scalars[ + f"eval/{k.replace('_', ' ')}" + ] = jax.device_get(v).mean() writer.write_scalars( step=step + 1, scalars=_scalars, diff --git a/src/projects/generative/model/BUILD b/src/projects/generative/model/BUILD index aed17c1..b8c1565 100644 --- a/src/projects/generative/model/BUILD +++ b/src/projects/generative/model/BUILD @@ -1,4 +1,4 @@ -load("//learning:defs.bzl", "ml_py_library", "ml_py_test") +load("//third_party:defs.bzl", "ml_py_library", "ml_py_test") package(default_visibility = ["//learning/generative:__subpackages__"]) @@ -9,7 +9,6 @@ ml_py_library( "chex", "flax", "jax", - "jaxlib", ], ) @@ -20,7 +19,6 @@ ml_py_test( "chex", "flax", "jax", - "jaxlib", ":refinenet", ], ) diff --git a/src/projects/generative/model/refinenet.py b/src/projects/generative/model/refinenet.py index 9be1ed9..b66d620 100644 --- a/src/projects/generative/model/refinenet.py +++ b/src/projects/generative/model/refinenet.py @@ -5,6 +5,7 @@ import jax from jax._src import core as jax_core from jax._src import dtypes as jax_dtypes +from jax._src import typing as jax_typing import jax.numpy as jnp @@ -12,23 +13,25 @@ # Builder functions # ============================================================================== def _uniform_init() -> jax.nn.initializers.Initializer: - """Uniform initializer for convolutional layers.""" + r"""Uniform initializer for convolutional layers.""" def init( - key: jax.random.KeyArray, - shape: jax_core.Shape, - dtype: typing.Any, + key: jax.Array, + shape: jax_typing.Shape, + dtype: typing.Any = jnp.float_, + out_sharding: typing.Any = None, ) -> jax.Array: """Uniform initializer for one-dimensional parameters.""" dim = shape[-1] dtype = jax_dtypes.canonicalize_dtype(dtype) - named_shape = jax_core.as_named_shape(shape) + named_shape = jax_core.canonicalize_shape(shape) return jax.random.uniform( key=key, shape=named_shape, dtype=dtype, minval=-jnp.sqrt(1.0 / dim), maxval=jnp.sqrt(1.0 / dim), + out_sharding=out_sharding, ) return init @@ -42,7 +45,7 @@ def _conv_1x1( dtype: typing.Any = jnp.float32, param_dtype: typing.Any = jnp.float32, ) -> nn.Conv: - """1x1 convolution with stride and padding.""" + r"""1x1 convolution with stride and padding.""" return nn.Conv( features=out_channels, kernel_size=(1, 1), @@ -69,7 +72,7 @@ def _conv_3x3( dtype: typing.Any = jnp.float32, param_dtype: typing.Any = jnp.float32, ) -> nn.Conv: - """3x3 convolution with stride and padding.""" + r"""3x3 convolution with stride and padding.""" return nn.Conv( features=out_channels, kernel_size=(3, 3), @@ -96,7 +99,7 @@ def _dilated_conv_3x3( dtype: typing.Any = jnp.float32, param_dtype: typing.Any = jnp.float32, ) -> nn.Conv: - """3x3 dilated convolution with dilation and padding.""" + r"""3x3 dilated convolution with dilation and padding.""" return nn.Conv( features=out_channels, kernel_size=(3, 3), @@ -120,7 +123,7 @@ def _dilated_conv_3x3( # Layers # ============================================================================== class ConditionalInstanceNorm2dPlus(nn.Module): - """Conditional Instance Normalization with extra affine transformation.""" + r"""Conditional Instance Normalization with extra affine transformation.""" features: int """int: Dimensionality of the feature map.""" @@ -148,15 +151,21 @@ def setup(self) -> None: ) def _kernel_init( - key: jax.random.KeyArray, - shape: jax_core.Shape, + key: typing.Any, + shape: jax_typing.Shape, dtype: typing.Any, + out_sharding: typing.Any = None, ) -> jax.Array: dtype = jax_dtypes.canonicalize_dtype(dtype) - named_shape = jax_core.as_named_shape(shape) + named_shape = jax_core.canonicalize_shape(shape) return ( 1.0 - + jax.random.normal(key=key, shape=named_shape, dtype=dtype) + + jax.random.normal( + key=key, + shape=named_shape, + dtype=dtype, + out_sharding=out_sharding, + ) * 0.02 ) @@ -187,7 +196,7 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: cond (jax.Array): Condition feature map of shape `(*, )`. Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. """ batch_dims = inputs.shape[:-3] chex.assert_shape(cond, (*batch_dims,)) @@ -222,7 +231,7 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConvMeanPool(nn.Module): - """Convolution followed by average pooling.""" + r"""Convolution followed by average pooling.""" features: int """int: Number of output channels.""" @@ -248,13 +257,13 @@ def setup(self) -> None: ) def __call__(self, inputs: jax.Array) -> jax.Array: - """Forward pass of the `ConvMeanPool` module. + r"""Forward pass of the `ConvMeanPool` module. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. Returns: - jax.Array: Output feature map of shape `(*, H/2, W/2, C_out)`. + Output feature map of shape `(*, H/2, W/2, C_out)`. """ batch_dims = inputs.shape[:-3] if self.adjust_padding: @@ -278,14 +287,14 @@ def __call__(self, inputs: jax.Array) -> jax.Array: # Modules # ============================================================================== class ConditionalResidualBlock(nn.Module): - """Residual block with conditioning feature map.""" + r"""Residual block with conditioning feature map.""" in_channels: int """int: Number of channels of the input feature map.""" out_channels: int """int: Number of channels of the output feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" dilation: typing.Optional[int] = None """Optional[int]: Optional dilations in the convolutional layers.""" resample: typing.Optional[str] = None @@ -298,7 +307,7 @@ class ConditionalResidualBlock(nn.Module): """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a conditional residual block.""" + r"""Instantiate a conditional residual block.""" self.norm_1 = self.norm_module( features=self.in_channels, name="normalize1", @@ -426,14 +435,14 @@ def setup(self) -> None: ) def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: - """Forward pass of the conditional residual block. + r"""Forward pass of the conditional residual block. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, D)`. cond (jax.Array): Condition feature map of shape `(*,)`. Returns: - jax.Array: Output feature map of shape `(*, H, W, D)`. + Output feature map of shape `(*, H, W, D)`. """ output = self.norm_1(inputs, cond) output = jax.nn.elu(output) @@ -452,12 +461,12 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConditionalRCUBlock(nn.Module): - """Refinement Convolution Unit (RCU) block with conditioning feature map.""" + r"""Refinement Convolution Unit (RCU) block with conditioning features.""" features: int """int: Dimensionality of the feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" num_blocks: int """int: Number of repeated blocks in the cascade.""" num_stages: int @@ -468,7 +477,7 @@ class ConditionalRCUBlock(nn.Module): """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a `ConditionalRCUBlock` module.""" + r"""Instantiate a `ConditionalRCUBlock` module.""" convs, norms = [], [] for i in range(self.num_blocks): for j in range(self.num_stages): @@ -491,11 +500,19 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.convs: typing.Tuple[nn.Conv, ...] = convs - self.norms: typing.Tuple[ConditionalInstanceNorm2dPlus, ...] = norms + self.convs = convs + self.norms = norms def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: - """Forward pass of the `ConditionalRCUBlock` module.""" + r"""Forward pass of the `ConditionalRCUBlock` module. + + Args: + inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. + cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. + + Returns: + Output feature map of shape `(*, H, W, C)`. + """ _idx: int = 0 output = inputs for _ in range(self.num_blocks): @@ -511,21 +528,21 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConditionalMSFBlock(nn.Module): - """Conditional Multi-Scale Feature block.""" + r"""Conditional Multi-Scale Feature block.""" in_features: typing.Sequence[int] """Sequence[int]: List of input feature map dimensionalities.""" features: int """int: Dimensionality of the output feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" dtype: typing.Any = jnp.float32 """dtype: The data type of the computation (default: float32).""" param_dtype: typing.Any = jnp.float32 """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a `ConditionalMSFBlock` module.""" + r"""Instantiate a `ConditionalMSFBlock` module.""" convs, norms = [], [] for i, in_feature in enumerate(self.in_features): convs.append( @@ -547,25 +564,25 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.convs: typing.Tuple[nn.Conv, ...] = convs - self.norms: typing.Tuple[ConditionalInstanceNorm2dPlus, ...] = norms + self.convs = convs + self.norms = norms def __call__( self, inputs: typing.Sequence[jax.Array], cond: jax.Array, - shape: jax_core.Shape, + shape: jax_typing.Shape, ) -> jax.Array: - """Forward pass of the `ConditionalMSFBlock` module. + r"""Forward pass of the `ConditionalMSFBlock` module. Args: inputs (Sequence[jax.Array]): Sequence of input feature maps to be merged. Each feature map has shape `(*, H_i, W_i, C) cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. - shape (jax_core.Shape): Shape of the output feature map. + shape (jax._src.typing.Shape): Shape of the output feature map. Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. """ assert isinstance(inputs, typing.Sequence) and len(inputs) == len( self.in_features @@ -586,12 +603,12 @@ def __call__( class ConditionalCRPBlock(nn.Module): - """Conditional convolutional residual pooling (CRP) block.""" + r"""Conditional convolutional residual pooling (CRP) block.""" features: int """int: Dimensionality of the output feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[Any, Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" num_stages: int """int: Number of stages in the cascade.""" dtype: typing.Any = jnp.float32 @@ -600,7 +617,7 @@ class ConditionalCRPBlock(nn.Module): """param_dtype: The data type of the parameters (default: float32).""" def setup(self) -> None: - """Instantiate a `ConditionalCRPBlock` module.""" + r"""Instantiate a `ConditionalCRPBlock` module.""" convs, norms = [], [] for i in range(self.num_stages): convs.append( @@ -622,18 +639,18 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.convs: typing.Tuple[nn.Conv, ...] = convs - self.norms: typing.Tuple[ConditionalInstanceNorm2dPlus, ...] = norms + self.convs = convs + self.norms = norms def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: - """Forward pass of the `ConditionalCRPBlock` module. + r"""Forward pass of the `ConditionalCRPBlock` module. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. """ output = jax.nn.elu(inputs) path = output @@ -643,7 +660,7 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: inputs=path, window_shape=(5, 5), strides=(1, 1), - padding=((2, 2), (2, 2)), + padding=((2, 2), (2, 2)), # type: ignore ) path = conv(path) output = output + path @@ -651,14 +668,14 @@ def __call__(self, inputs: jax.Array, cond: jax.Array) -> jax.Array: class ConditionalRefineBlock(nn.Module): - """Refinement block with skip connections and conditioning feature map.""" + r"""Refinement block with skip connections and conditioning feature map.""" in_features: typing.Sequence[int] """Sequence[int]: List of input feature map dimensionalities.""" out_features: int """int: Number of output channels of each convolution.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[Any, Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[[typing.Any], nn.Module] + """Callable[Any, nn.Module]: Normalization module to use.""" is_last_block: bool = False """bool: If True, this is the last refinement block.""" dtype: typing.Any = jnp.float32 @@ -681,7 +698,7 @@ def setup(self) -> None: param_dtype=self.param_dtype, ) ) - self.adapt_convs: typing.Tuple[ConditionalRCUBlock, ...] = adapt_convs + self.adapt_convs: typing.List[ConditionalRCUBlock] = adapt_convs self.output_convs = ConditionalRCUBlock( features=self.out_features, norm_module=self.norm_module, @@ -715,18 +732,18 @@ def __call__( self, inputs: typing.List[jax.Array], cond: jax.Array, - output_shape: jax_core.Shape, + output_shape: jax_typing.Shape, ) -> jax.Array: - """Forward pass of the refinement block. + r"""Forward pass of the refinement block. Args: inputs (List[jax.Array]): List of input feature maps to be merged. Each feature map has shape `(*, H_i, W_i, C)`. cond (jax.Array): Condition feature map of shape `(*, H, W, d)`. - output_shape (jax_core.Shape): Shape of the output feature map. + output_shape (jax._src.typing.Shape): Shape of the output feature. Returns: - jax.Array: Output feature map of shape `(*, H, W, 128)`. + Output feature map of shape `(*, H, W, 128)`. """ assert ( isinstance(inputs, typing.Sequence) @@ -760,13 +777,11 @@ def __call__( # Models # ============================================================================== class ConditionalRefineNet(nn.Module): - """Multi-path Refinement Network with Conditional Instance Normlization. - - .. note:: + r"""Multi-path Refinement Network with Conditional Instance Normlization. - This module is adapted from the original implementation of - `CondRefineNetDeeperDilated` in the NCSN official repository: - `https://github.com/ermongroup/ncsn/blob/master/models/cond_refinenet_dilated.py` + This module is adapted from the original implementation of + `CondRefineNetDeeperDilated` in the NCSN official repository: + `https://github.com/ermongroup/ncsn/blob/master/models/cond_refinenet_dilated.py` Attributes: in_channels (int): Number of channels of the input feature map. @@ -780,12 +795,12 @@ class ConditionalRefineNet(nn.Module): in_channels: int """int: Number of channels of the input feature map.""" - image_size: typing.Literal[28, 32] + image_size: int """int: Size of the input (square) image, either `28` or `32`.""" latent_channels: int """int: Number of channels of the latent feature map.""" - norm_module: typing.Callable[[typing.Any], typing.Type[nn.Module]] - """Callable[[typing.Any], Type[nn.Module]]: Normalization module to use.""" + norm_module: typing.Callable[..., nn.Module] + """Callable[..., nn.Module]: Normalization module to use.""" dtype: typing.Any = jnp.float32 """dtype: The data type of the computation (default: float32).""" param_dtype: typing.Any = jnp.float32 @@ -793,6 +808,11 @@ class ConditionalRefineNet(nn.Module): def setup(self) -> None: """Instantiate a Refinement Network module.""" + if self.image_size not in [28, 32]: + raise ValueError( + "`image_size` must be either `28` or `32`, " + f"but got {self.image_size}." + ) self.conv_in = _conv_3x3( out_channels=self.latent_channels, @@ -947,14 +967,14 @@ def __call__( cond: jax.Array, **kwargs, # type: ignore[unused-argument] ) -> jax.Array: - """Forward pass of the conditional refinement network. + r"""Forward pass of the conditional refinement network. Args: inputs (jax.Array): Input feature map of shape `(*, H, W, C)`. cond (jax.Array): Condition feature map of shape `(*,)`. Returns: - jax.Array: Output feature map of shape `(*, H, W, C)`. + Output feature map of shape `(*, H, W, C)`. """ batch_dims = inputs.shape[:-3] dims = chex.Dimensions( @@ -1019,11 +1039,11 @@ def __call__( @staticmethod def _forward_cond_res_block( - module: nn.Module, + module: typing.Sequence[nn.Module], inputs: jax.Array, cond: jax.Array, ) -> jax.Array: - """Forward pass through a residual block with conditional inputs.""" + r"""Forward pass through a residual block with conditional inputs.""" for m in module: assert isinstance(m, ConditionalResidualBlock) inputs = m(inputs=inputs, cond=cond) diff --git a/src/projects/generative/model/test_refinenet.py b/src/projects/generative/model/test_refinenet.py index 7c4a7b2..cd78fc1 100644 --- a/src/projects/generative/model/test_refinenet.py +++ b/src/projects/generative/model/test_refinenet.py @@ -5,16 +5,16 @@ import chex from flax import linen as nn import jax -import jax.numpy as jnp +from jax import numpy as jnp import pytest -from learning.generative.model import refinenet +from src.projects.generative.model import refinenet @pytest.mark.parametrize("out_channels", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conv_1x1(out_channels: int, dtype: typing.Any) -> None: - """Test 1x1 convolution builder.""" + r"""Test 1x1 convolution builder.""" layer = refinenet._conv_1x1( out_channels=out_channels, name="conv1", @@ -40,7 +40,7 @@ def test_conv_1x1(out_channels: int, dtype: typing.Any) -> None: @pytest.mark.parametrize("out_channels", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conv_3x3(out_channels: int, dtype: typing.Any) -> None: - """Test 3x3 convolution builder.""" + r"""Test 3x3 convolution builder.""" layer = refinenet._conv_3x3( out_channels=out_channels, name="conv3", @@ -71,7 +71,7 @@ def test_dilated_conv_3x3( dilation: int, dtype: typing.Any, ) -> None: - """Test dilated 3x3 convolution builder.""" + r"""Test dilated 3x3 convolution builder.""" layer = refinenet._dilated_conv_3x3( out_channels=out_channels, dilation=dilation, @@ -106,7 +106,7 @@ def test_conditional_instance_norm_2d_plus( use_bias: bool, dtype: typing.Any, ) -> None: - """Test `ConditionalInstanceNorm2dPlus` layer.""" + r"""Test `ConditionalInstanceNorm2dPlus` layer.""" layer = refinenet.ConditionalInstanceNorm2dPlus( features=features, num_classes=num_classes, @@ -136,7 +136,7 @@ def test_conditional_instance_norm_2d_plus( chex.assert_type(variables["params"]["embed"]["embedding"], dtype) test_output = layer.apply( variables, - jnp.ones((1, 32, 32, features)), + jnp.ones((1, 32, 32, features), dtype=dtype), jnp.ones((1,), dtype=jnp.int32), ) chex.assert_type(test_output, dtype) @@ -151,7 +151,7 @@ def test_conv_mean_pool( kernel_size: int, dtype: typing.Any, ) -> None: - """Test `ConvMeanPool` layer.""" + r"""Test `ConvMeanPool` layer.""" layer = refinenet.ConvMeanPool( features=features, kernel_size=kernel_size, @@ -188,7 +188,7 @@ def test_conditional_residual_block( resample: typing.Optional[str], dtype: typing.Any, ) -> None: - """Test `ConditionalResidualBlock` module.""" + r"""Test `ConditionalResidualBlock` module.""" if resample not in (None, "down"): with pytest.raises(ValueError): block = refinenet.ConditionalResidualBlock( @@ -357,7 +357,7 @@ def test_conditional_residual_block( @pytest.mark.parametrize("features", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conditional_rcu_block(features: int, dtype: typing.Any) -> None: - """Test the `ConditionalRCUBlock` module.""" + r"""Test the `ConditionalRCUBlock` module.""" block = refinenet.ConditionalRCUBlock( features=features, norm_module=functools.partial( @@ -451,20 +451,20 @@ def test_conditional_msf_block(features: int, dtype: typing.Any) -> None: test_output = block.apply( variables, inputs=[ - jnp.ones((2, 32, 32, 3), dtype=jnp.float32), - jnp.ones((2, 16, 16, 8), dtype=jnp.float32), + jnp.ones((2, 32, 32, 3), dtype=dtype), + jnp.ones((2, 16, 16, 8), dtype=dtype), ], cond=jnp.ones((2,), dtype=jnp.int32), shape=(28, 28), ) - chex.assert_type(test_output, jnp.float32) + chex.assert_type(test_output, dtype) chex.assert_shape(test_output, (2, 28, 28, features)) @pytest.mark.parametrize("features", [1, 3]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_conditional_crp_block(features: int, dtype: typing.Any) -> None: - """Test the `ConditionalCRPBlock` module.""" + r"""Test the `ConditionalCRPBlock` module.""" block = refinenet.ConditionalCRPBlock( features=features, norm_module=functools.partial( @@ -486,18 +486,15 @@ def test_conditional_crp_block(features: int, dtype: typing.Any) -> None: variables["params"][f"convs.{i:d}"]["kernel"], (3, 3, features, features), ) - chex.assert_type( - variables["params"][f"convs.{i:d}"]["kernel"], - jnp.float32, - ) + chex.assert_type(variables["params"][f"convs.{i:d}"]["kernel"], dtype) assert variables["params"][f"convs.{i:d}"].get("bias") is None test_output = block.apply( variables, - jnp.ones((1, 32, 32, features), dtype=jnp.float32), + jnp.ones((1, 32, 32, features), dtype=dtype), jnp.ones((1,), dtype=jnp.int32), ) - chex.assert_type(test_output, jnp.float32) + chex.assert_type(test_output, dtype) chex.assert_shape(test_output, (1, 32, 32, features)) @@ -509,10 +506,10 @@ def test_conditional_refine_block( dtype: typing.Any, is_last_block: bool, ) -> None: - """Test the `ConditionalRefineBlock` module.""" + r"""Test the `ConditionalRefineBlock` module.""" test_inputs = [ - jnp.ones((2, 32, 32, 3), dtype=jnp.float32), - jnp.ones((2, 16, 16, 8), dtype=jnp.float32), + jnp.ones((2, 32, 32, 3), dtype=dtype), + jnp.ones((2, 16, 16, 8), dtype=dtype), ] block = refinenet.ConditionalRefineBlock( in_features=[3, 8], @@ -534,7 +531,7 @@ def test_conditional_refine_block( _ = block.init( jax.random.PRNGKey(0), inputs=[ - jnp.ones((2, 32, 32, 3), dtype=jnp.float32), + jnp.ones((2, 32, 32, 3), dtype=dtype), ], cond=jnp.ones((2,), dtype=jnp.int32), output_shape=(28, 28), @@ -551,12 +548,13 @@ def test_conditional_refine_block( cond=jnp.ones((2,), dtype=jnp.int32), output_shape=(28, 28), ) - chex.assert_type(test_output, jnp.float32) + chex.assert_type(test_output, dtype) chex.assert_shape(test_output, (2, 28, 28, features)) -def test_conditional_refinenet() -> None: - """Integrated test for the `ConditionalRefineNet` module.""" +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_conditional_refinenet(dtype: typing.Any) -> None: + r"""Integrated test for the `ConditionalRefineNet` module.""" model = refinenet.ConditionalRefineNet( in_channels=3, image_size=32, @@ -565,17 +563,17 @@ def test_conditional_refinenet() -> None: refinenet.ConditionalInstanceNorm2dPlus, num_classes=10, ), - dtype=jnp.float32, - param_dtype=jnp.float32, + dtype=dtype, + param_dtype=dtype, ) assert isinstance(model, nn.Module) with pytest.raises(AssertionError): _ = model.init( jax.random.PRNGKey(0), - jnp.ones((2, 28, 28, 1), dtype=jnp.float32), + jnp.ones((2, 28, 28, 1), dtype=dtype), jnp.ones((2,), dtype=jnp.int32), ) - test_input = jnp.ones((2, 32, 32, 3), dtype=jnp.float32) + test_input = jnp.ones((2, 32, 32, 3), dtype=dtype) variables = model.init( jax.random.PRNGKey(0), test_input, @@ -586,7 +584,7 @@ def test_conditional_refinenet() -> None: test_input, jnp.ones((2,), dtype=jnp.int32), ) - chex.assert_type(test_output, jnp.float32) + chex.assert_type(test_output, dtype) chex.assert_shape(test_output, (2, 32, 32, 3)) From b60e7469909e427389eac757b848c4d0370d1bbd Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 05:50:04 -0500 Subject: [PATCH 03/67] hotfix: Fixed typo in core module Signed-off-by: Juanwu Lu --- src/core/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/BUILD b/src/core/BUILD index 1f39da2..168245e 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -8,7 +8,7 @@ ml_py_library( deps = [ "fiddle", "optax", - ":data", + ":datamodule", ":model", ], ) @@ -24,7 +24,7 @@ ml_py_library( deps = [ "clu", "jax", - ":data", + ":datamodule", ":model", "//src/utilities:logging", ], @@ -60,7 +60,7 @@ ml_py_library( "flax", "jax", "jaxtyping", - ":data", + ":datamodule", ":model", ":train_state", "//src/utilities:logging", From 95be2cb837835e133c6a8dff2c79623862c94142 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 05:50:43 -0500 Subject: [PATCH 04/67] feat: Updated implementation of HuggingFace datamodule Signed-off-by: Juanwu Lu --- src/data/BUILD | 1 + src/data/huggingface.py | 110 ++++++++++++++++------------------- src/data/test_huggingface.py | 17 +++--- 3 files changed, 61 insertions(+), 67 deletions(-) diff --git a/src/data/BUILD b/src/data/BUILD index 908264f..c268513 100644 --- a/src/data/BUILD +++ b/src/data/BUILD @@ -20,6 +20,7 @@ ml_py_test( name = "test_huggingface", srcs = ["test_huggingface.py"], deps = [ + "jax", "numpy", "tensorflow", ":huggingface", diff --git a/src/data/huggingface.py b/src/data/huggingface.py index a4d477a..a445412 100644 --- a/src/data/huggingface.py +++ b/src/data/huggingface.py @@ -40,12 +40,13 @@ class HuggingFaceDataModule(datamodule.DataModule): deterministic (bool): Whether to enforce deterministic loading behavior. drop_remainder (bool): Whether to drop the last incomplete batch. num_workers (int): Number of shards for distributed loading. - seed (int): Random seed for shuffling. shuffle_buffer_size (int): Buffer size for shuffling the dataset. transform (Optional[Callable], optional): An optional function to - transform the input features. Defaults to `None`. + transform the input features. Default is `None`. target_transform (Optional[Callable], optional): An optional function - to transform the target features. Defaults to `None`. + to transform the target features. Default is `None`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. """ def __init__( @@ -54,21 +55,17 @@ def __init__( deterministic: bool, drop_remainder: bool, num_workers: int, - seed: int, shuffle_buffer_size: int, transform: typing.Optional[typing.Callable] = None, target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._batch_size = batch_size self._deterministic = deterministic self._drop_remainder = drop_remainder self._num_workers = num_workers - self._seed = seed self._shuffle_buffer_size = shuffle_buffer_size - self._rng = random.fold_in( - random.PRNGKey(self._seed), - jax.process_index(), - ) + self._rng = rng self._transform = transform self._target_transform = target_transform @@ -154,11 +151,6 @@ def num_test_examples(self) -> int: r"""int: Number of test examples.""" return len(self.hf_dataset["test"]) # type: ignore - @property - def seed(self) -> int: - r"""int: Random seed for shuffling.""" - return self._seed - @property def shuffle_buffer_size(self) -> int: r"""int: Buffer size for shuffling the dataset.""" @@ -179,6 +171,11 @@ def target_transform(self) -> typing.Optional[typing.Callable]: r"""Optional[Callable]: Transformation for the target features.""" return self._target_transform + @property + def rng(self) -> jax.Array: + r"""jax.Array: Random key for shuffling.""" + return self._rng + def train_dataloader(self) -> typing.Generator[PyTree, None, None]: r"""Returns an iterable over the training dataset.""" self._rng, shuffle_rng = random.split(self._rng, num=2) @@ -214,7 +211,10 @@ class HuggingFaceImageDataModule(HuggingFaceDataModule): resample (int): Resampling filter to use for resizing images. transform (Optional[Callable], optional): An optional function to transform the input images. Defaults to `None`. - seed (int, optional): Random seed for shuffling. Defaults to `42`. + target_transform (Optional[Callable], optional): An optional function + to transform the target features. Defaults to `None`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. """ def __init__( @@ -226,9 +226,9 @@ def __init__( resize: int, resample: int, shuffle_buffer_size: int, - seed: int, transform: typing.Optional[typing.Callable] = None, target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: r"""Instantiates a `HuggingFaceImageDataModule` object.""" self._resize = resize @@ -240,11 +240,24 @@ def __init__( drop_remainder=drop_remainder, num_workers=num_workers, shuffle_buffer_size=shuffle_buffer_size, - seed=seed, transform=transform, target_transform=target_transform, + rng=rng, ) + @property + def image_shape(self) -> typing.Tuple[int, int, int]: + r"""Tuple[int, int, int]: The shape of the images.""" + return (self._resize, self._resize, 3) + + @property + def output_signature(self) -> typing.Dict[str, tf.TensorSpec]: + r"""Dict[str, tf.TensorSpec]: Tensor specifications.""" + return { + "image": tf.TensorSpec(shape=self.image_shape, dtype=tf.uint8), # type: ignore + "label": tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore + } + def _create_dataset( self, *, @@ -298,7 +311,7 @@ def __hf_generator() -> typing.Generator[typing.Any, None, None]: bottom = (new_height + self._resize) / 2 image = image.crop((left, top, right, bottom)) - yield image, target + yield {"image": image, "label": target} ds = tf.data.Dataset.from_generator( __hf_generator, @@ -384,11 +397,12 @@ class CIFAR10DataModule(HuggingFaceImageDataModule): image to before cropping. Defaults to `224`. resample (int, optional): Resampling filter to use when resizing images. Defaults to `3` (PIL.Image.BICUBIC). - seed (int, optional): Random seed for shuffling. Defaults to `42`. shuffle_buffer_size (int, optional): Buffer size for random shuffling. Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. """ def __init__( @@ -399,11 +413,11 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="uoft-cs/cifar10", @@ -418,10 +432,10 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, target_transform=target_transform, + rng=rng, ) @property @@ -439,14 +453,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "label" - @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) - @property @typing_extensions.override def num_val_examples(self) -> int: @@ -489,6 +495,8 @@ class CIFAR100DataModule(HuggingFaceImageDataModule): Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Defaults to `random.PRNGKey(42)`. """ def __init__( @@ -499,11 +507,11 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="uoft-cs/cifar100", @@ -518,10 +526,10 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, target_transform=target_transform, + rng=rng, ) @property @@ -539,14 +547,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "fine_label" - @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) - @property @typing_extensions.override def num_val_examples(self) -> int: @@ -583,11 +583,12 @@ class ImageNet1KDataModule(HuggingFaceImageDataModule): image to before cropping. Defaults to `224`. resample (int, optional): Resampling filter to use when resizing images. Defaults to `3` (PIL.Image.BICUBIC). - seed (int, optional): Random seed for shuffling. Defaults to `42`. shuffle_buffer_size (int, optional): Buffer size for random shuffling. Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. """ def __init__( @@ -598,11 +599,11 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="ILSVRC/imagenet-1k", @@ -617,10 +618,10 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, target_transform=target_transform, + rng=rng, ) @property @@ -638,14 +639,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "label" - @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) - class MNISTDataModule(HuggingFaceImageDataModule): r"""MNIST Handwritten Digit Dataset. @@ -667,11 +660,12 @@ class MNISTDataModule(HuggingFaceImageDataModule): image to before cropping. Defaults to `224`. resample (int, optional): Resampling filter to use when resizing images. Defaults to `3` (PIL.Image.BICUBIC). - seed (int, optional): Random seed for shuffling. Defaults to `42`. shuffle_buffer_size (int, optional): Buffer size for random shuffling. Defaults to `10_000`. streaming (bool, optional): Whether to stream the dataset using the `datasets` library. Defaults to `False`. + rng (jax.Array, optional): Random key for shuffling. + Default is `random.PRNGKey(42)`. """ def __init__( @@ -682,11 +676,11 @@ def __init__( num_workers: int = 4, resize: int = 224, resample: int = 3, - seed: int = 42, shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, target_transform: typing.Optional[typing.Callable] = None, + rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( path="ylecun/mnist", @@ -701,10 +695,10 @@ def __init__( num_workers=num_workers, resize=resize, resample=resample, - seed=seed, shuffle_buffer_size=shuffle_buffer_size, transform=transform, target_transform=target_transform, + rng=rng, ) @property @@ -723,12 +717,10 @@ def target_key(self) -> str: return "label" @property - def output_signature(self) -> typing.Tuple[tf.TensorSpec, tf.TensorSpec]: - r"""Tuple[tf.TensorSpec, tf.TensorSpec]: Tensor specifications.""" - return ( - tf.TensorSpec(shape=(224, 224, 3), dtype=tf.uint8), # type: ignore - tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - ) + @typing_extensions.override + def image_shape(self) -> typing.Tuple[int, int, int]: + r"""Tuple[int, int, int]: The shape of the images.""" + return (self._resize, self._resize, 1) @property @typing_extensions.override diff --git a/src/data/test_huggingface.py b/src/data/test_huggingface.py index 9dc153e..603eef2 100644 --- a/src/data/test_huggingface.py +++ b/src/data/test_huggingface.py @@ -1,6 +1,7 @@ import sys import typing +import jax import numpy as np import pytest import tensorflow as tf @@ -25,9 +26,9 @@ def test_cifar10_datamodule() -> None: dm = huggingface.CIFAR10DataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ -36,7 +37,7 @@ def test_cifar10_datamodule() -> None: assert dm.num_train_examples == 50000 assert dm.num_val_examples == 10000 assert dm.num_test_examples == 10000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "test"]) # test training dataloader @@ -62,9 +63,9 @@ def test_cifar100_datamodule() -> None: dm = huggingface.CIFAR100DataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ -73,7 +74,7 @@ def test_cifar100_datamodule() -> None: assert dm.num_train_examples == 50000 assert dm.num_val_examples == 10000 assert dm.num_test_examples == 10000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "test"]) # test training dataloader @@ -99,9 +100,9 @@ def test_imagenet1k_datamodule() -> None: dm = huggingface.ImageNet1KDataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ -110,7 +111,7 @@ def test_imagenet1k_datamodule() -> None: assert dm.num_train_examples == 1_281_167 assert dm.num_val_examples == 50_000 assert dm.num_test_examples == 100_000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "validation", "test"]) # test training dataloader @@ -136,9 +137,9 @@ def test_mnist_datamodule() -> None: dm = huggingface.MNISTDataModule( batch_size=2, num_workers=1, - seed=0, transform=_default_transform, streaming=False, + rng=jax.random.PRNGKey(0), ) assert dm.batch_size == 2 assert dm.deterministic is True @@ -147,7 +148,7 @@ def test_mnist_datamodule() -> None: assert dm.num_train_examples == 60000 assert dm.num_val_examples == 10000 assert dm.num_test_examples == 10000 - assert dm.seed == 0 + assert dm.rng == jax.random.PRNGKey(0) assert all(key in dm.splits for key in ["train", "test"]) # test training dataloader From 2e923ceecb69561bec927bc5b0317fa1fd998018 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 05:51:09 -0500 Subject: [PATCH 05/67] hotfix: Fixed build target for utility module Signed-off-by: Juanwu Lu --- src/utilities/BUILD | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/utilities/BUILD b/src/utilities/BUILD index 8c4078b..41825a8 100644 --- a/src/utilities/BUILD +++ b/src/utilities/BUILD @@ -1,6 +1,6 @@ -load("//learning:defs.bzl", "ml_py_library") +load("//third_party:defs.bzl", "ml_py_library") -package(default_visibility = ["//learning:__subpackages__"]) +package(default_visibility = ["//src:__subpackages__"]) ml_py_library( name = "logging", @@ -8,7 +8,6 @@ ml_py_library( deps = [ "absl-py", "jax", - "jaxlib", ], ) @@ -17,6 +16,5 @@ ml_py_library( srcs = ["rank_zero.py"], deps = [ "jax", - "jaxlib", ], ) From 7ca6e1e0d8fa3c6e395b859d4593e3df76ac97ae Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 05:51:34 -0500 Subject: [PATCH 06/67] hotfix: Fixed typo in build targets Signed-off-by: Juanwu Lu --- src/projects/generative/model/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projects/generative/model/BUILD b/src/projects/generative/model/BUILD index b8c1565..7b747b5 100644 --- a/src/projects/generative/model/BUILD +++ b/src/projects/generative/model/BUILD @@ -1,6 +1,6 @@ load("//third_party:defs.bzl", "ml_py_library", "ml_py_test") -package(default_visibility = ["//learning/generative:__subpackages__"]) +package(default_visibility = ["//src/projects/generative:__subpackages__"]) ml_py_library( name = "refinenet", From 35ad18407aeb40b1160dfae9fccb3b0921060224 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 05:52:26 -0500 Subject: [PATCH 07/67] feat: Updated implementation of MeanFlow Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 410 +++++++++++++---------- src/projects/generative/test_meanflow.py | 190 ----------- 2 files changed, 233 insertions(+), 367 deletions(-) delete mode 100644 src/projects/generative/test_meanflow.py diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 9201d26..3923b93 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -4,49 +4,35 @@ from flax import linen as nn from flax.core import frozen_dict import jax -import jax.core as jax_core -import jax.numpy as jnp +from jax import numpy as jnp +from jax._src import typing as jax_typing import jaxtyping import typing_extensions -from learning.core import mixin as _mixin -from learning.generative.model import refinenet +from src.core import model as _model +from src.core import train_state as _train_state +from src.projects.generative.model import refinenet # Type Aliases PyTree = jaxtyping.PyTree -# ============================================================================== -# Data Structures -# ============================================================================== -@chex.dataclass -class MeanFlowOutputs: - """Generic output structure from a `MeanFlow` model.""" - - loss: typing.Optional[jax.Array] = None - """jax.Array: The training loss.""" - velocity_loss: typing.Optional[jax.Array] = None - """jax.Array: The velocity loss for monitoring.""" - output: typing.Optional[jax.Array] = None - """jax.Array: The model output.""" - - # ============================================================================== # Helper functions # ============================================================================== def sample_t_r( *, - key: jax.random.KeyArray, - shape: jax_core.Shape, + key: jax.Array, + shape: jax_typing.Shape, dtype: typing.Any, - distribution: typing.Literal["uniform", "lognormal"], + distribution: str, **kwargs, ) -> typing.Tuple[jax.Array, jax.Array]: """Samples begin and end timestamps randomly from a given distribution. Attributes: - key (jax.random.KeyArray): JAX random key. - shape (jax_core.Shape): The shape of the output arrays. + key (jax.Array): JAX random key. + shape (jax.typing.Shape): The shape of the output arrays. dtype (dtype): The dtype of the output arrays. distribution (str): The distribution to sample from. One of `["uniform", "lognormal"]`. @@ -78,8 +64,8 @@ def sample_t_r( elif distribution == "lognormal": def _lognormal( - key: jax.random.KeyArray, - shape: jax_core.Shape, + key: jax.Array, + shape: jax_typing.Shape, dtype: typing.Any, mean: float, stddev: float, @@ -277,7 +263,7 @@ def __call__( def _drop_token( cond: jax.Array, dropout_rate: float, - rng: jax.random.KeyArray, + rng: jax.Array, ) -> jax.Array: """Drops class tokens for classifier-free guidance.""" raise NotImplementedError("This method is not yet implemented.") @@ -398,13 +384,15 @@ class MeanFlowUNetModule(nn.Module): """Optional[bool]: Whether to run deterministically.""" dropout_rate: float = 0.0 """float: Dropout rate for the classifier-free guidance.""" - dtype: typing.Any = jnp.float32 + dtype: typing.Any = None """typing.Any: The dtype of the computation.""" - param_dtype: typing.Any = jnp.float32 + param_dtype: typing.Any = None """typing.Any: The dtype of the parameters.""" + precision: typing.Any = None + """typing.Any: The precision of the computation.""" def setup(self) -> None: - """Instantiate a `MeanFlowUNetModel` module.""" + r"""Instantiate a `MeanFlowUNetModel` module.""" self.backbone = refinenet.ConditionalRefineNet( in_channels=self.in_channels, image_size=self.image_size, @@ -478,19 +466,43 @@ def __call__( ) y_emb = self.label_embed(label, deterministic=m_deterministic) - r_emb = self.r_embed(begin) - t_emb = self.t_embed(end) + if begin is not None: + r_emb = self.r_embed(begin) + else: + r_emb = jnp.zeros_like(y_emb) + if end is not None: + t_emb = self.t_embed(end) + else: + t_emb = jnp.zeros_like(y_emb) cond = t_emb + r_emb + y_emb output = self.backbone(inputs=image, cond=cond) return output -class MeanFlowUNetModel(_mixin.ModelMixin): - """`MeanFlow` generative model with a U-Net backbone.""" +class MeanFlowUNetModel(_model.Model): + r"""`MeanFlow` generative model with a U-Net backbone. - module_class = MeanFlowUNetModule - """Type[nn.Module]: The class of the model module.""" + Args: + in_channels (int): Number of channels in the input images. + image_size (int): Height and width of the (square) input images. + latent_channels (int): Number of channels in the latent feature maps. + num_classes (int): Number of conditioning classes. + use_cfg_embedding (bool): Whether to use classifier-free guidance (CFG). + dropout_rate (float): Dropout rate for the classifier-free guidance. + dtype (dtype): The dtype of the computation (default: float32). + param_dtype (dtype): The dtype of the parameters (default: float32). + timestamp_cond (Literal): The type of timestamp conditioning. + One of `["t_and_r", "t_and_t_minus_r", + "t_and_r_and_t_minus_r", "t_minus_r"]`. + timestamp_sampler (str): The distribution to sample timestamps from. + One of `["uniform", "lognormal"]`. + timestamp_sampler_kwargs (Dict[str, Any]): Additional keyword arguments + for the timestamp sampler. + timestamp_overlap_rate (float): The minimum overlap rate between + begin and end timestamps. + adaptive_weight_power (float): The power for adaptive weight scaling. + """ def __init__( self, @@ -500,8 +512,9 @@ def __init__( num_classes: int, use_cfg_embedding: bool, dropout_rate: float, - dtype: typing.Any = jnp.float32, - param_dtype: typing.Any = jnp.float32, + dtype: typing.Any = None, + param_dtype: typing.Any = None, + precision: typing.Any = None, timestamp_cond: typing.Literal[ "t_and_r", "t_and_t_minus_r", @@ -524,7 +537,7 @@ def __init__( self.timestamp_sampler_kwargs = timestamp_sampler_kwargs self.timestamp_overlap_rate = timestamp_overlap_rate self.adaptive_weight_power = adaptive_weight_power - self._module = MeanFlowUNetModule( + self._network = MeanFlowUNetModule( in_channels=in_channels, image_size=image_size, latent_channels=latent_channels, @@ -534,101 +547,112 @@ def __init__( name="unet", dtype=dtype, param_dtype=param_dtype, + precision=precision, ) + @property @typing_extensions.override - def compute_loss( + def network(self) -> MeanFlowUNetModule: + r"""MeanFlowUNetModule: The U-Net neural network module.""" + return self._network + + def init( self, *, - rngs: typing.Union[ - jax.random.KeyArray, - typing.Dict[str, jax.random.KeyArray], - ], - image: jax.Array, - label: jax.Array, - params: frozen_dict.FrozenDict, - deterministic: bool = False, + batch: typing.Any, + rngs: typing.Any, **kwargs, - ) -> MeanFlowOutputs: - """Computes the loss given parameters and model inputs. + ) -> PyTree: + del batch # unused - Args: - rngs (Union[jax.random.KeyArray, Dict[str, jax.random.KeyArray]]): - JAX random key or a dictionary of JAX random keys. - image (jax.Array): The input images of shape `(*, H, W, C)`. - label (jax.Array): The class labels of shape `(*,)`. - params (frozen_dict.FrozenDict): The model parameters. - deterministic (bool): Whether to run the model deterministically. - **kwargs: additional keyword arguments. - - Returns: - MeanFlowOutputs: The model outputs. - """ - # NOTE: following the notation in Algorithm 1 of the source paper - # sample t and r - batch_dims = image.shape[:-3] - rngs, tr_rng, mask_rng, e_rng = jax.random.split(rngs, num=4) - t, r = sample_t_r( - key=tr_rng, - shape=batch_dims, - dtype=image.dtype, - distribution=self.timestamp_sampler, - **self.timestamp_sampler_kwargs, - ) - t, r = jnp.maximum(t, r), jnp.minimum(t, r) - # ensure a portion of overlap between t and r - r_neq_t_mask = jnp.greater_equal( - jax.random.uniform( - key=mask_rng, - shape=batch_dims, - dtype=image.dtype, - minval=0.0, - maxval=1.0, + # create dummy inputs + dummy_inputs = { + "image": jnp.zeros( + (1, self.image_size, self.image_size, self.in_channels), + dtype=jnp.float32, ), - self.timestamp_overlap_rate, + "label": jnp.zeros((1,), dtype=jnp.int32), + "begin": jnp.zeros((1,), dtype=jnp.float32), + "end": jnp.zeros((1,), dtype=jnp.float32), + } + variables = self.network.init( + rngs=rngs, + image=dummy_inputs["image"], + label=dummy_inputs["label"], + begin=dummy_inputs["begin"], + end=dummy_inputs["end"], + deterministic=True, ) - r = jnp.where(r_neq_t_mask, t, r) + _tabulate_fn = nn.summary.tabulate(self.network, rngs=rngs) + print(_tabulate_fn(**dummy_inputs, deterministic=True)) - # sample e ~ N(0, I) - e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) + return variables["params"] - # generate z_{t} - z = jnp.add( - (1 - t[..., None, None, None]) * image, - t[..., None, None, None] * e, + @typing_extensions.override + def evaluation_step( + self, + *, + params: PyTree, + batch: typing.Any, + rngs: typing.Any, + **kwargs, + ) -> _model.StepOutputs: + del kwargs # unused + + local_rng = jax.random.fold_in(rngs, jax.process_index()) + e_rng = jax.random.fold_in(local_rng, 0) + + image, label = batch["image"], batch["label"] + + batch_dims = image.shape[:-3] + dims = chex.Dimensions( + H=self.image_size, + W=self.image_size, + C=self.in_channels, + ) + chex.assert_shape(image, (*batch_dims, *dims["HWC"])) + chex.assert_shape(label, batch_dims) + + r = jnp.zeros(batch_dims, dtype=image.dtype) + t = jnp.ones(batch_dims, dtype=image.dtype) + sample = jnp.subtract( + image, + jnp.einsum( + "...,...n->...n", + (t - r), + self._u_fn( + label=label, + params=params, + deterministic=True, + )(image, r, t), + ), ) - v = e - image - # applies Jacobian vector product drdt = jnp.zeros_like(r) dtdt = jnp.ones_like(t) - + e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) + v = e - image u, dudt = jax.jvp( - self.u_fn( + self._u_fn( label=label, params=params, - deterministic=deterministic, + deterministic=True, ), - (z, r, t), + (e, r, t), (v, drdt, dtdt), ) - - # computes the target u_target = jax.lax.stop_gradient( v - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] - * dudt + * dudt, ) - # NOTE: sum over all the pixels, following official implementation loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) - # applies adaptive weight power if self.adaptive_weight_power > 0.0: ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) loss = loss / jax.lax.stop_gradient(ada_wt) loss = jnp.mean(loss) - # calculate velocity loss for monitoring velocity_loss = jnp.where( jnp.equal(t, r)[..., None, None, None], jnp.square(u - v), @@ -636,103 +660,135 @@ def compute_loss( ) velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() - return MeanFlowOutputs( - loss=loss, - velocity_loss=velocity_loss, - output=u, + return _model.StepOutputs( + scalars={"loss": loss, "velocity_loss": velocity_loss}, + images={"samples": sample}, ) @typing_extensions.override - def forward( + def predict_step( self, *, - rngs: typing.Union[ - jax.random.KeyArray, - typing.Dict[str, jax.random.KeyArray], - ], - params: frozen_dict.FrozenDict, - image: jax.Array, - label: jax.Array, - begin: typing.Optional[jax.Array] = None, - end: typing.Optional[jax.Array] = None, - deterministic: bool = False, + params: jaxtyping.PyTree, + batch: typing.Any, + rngs: typing.Any, **kwargs, - ) -> MeanFlowOutputs: - """Forward sampling with average velocity prediction. + ) -> typing.Any: + # TODO (juanwulu): implement predict step + raise NotImplementedError("Predict step is not implemented yet.") - Args: - params (frozen_dict.FrozenDict): The model parameters. - image (jax.Array): Input latent image `z_t` of shape `(*, H, W, C)`. - label (jax.Array): Conditioning labels of shape `(*,)`. - begin (jax.Array): Begin timestamp `r` of shape `(*, )`. - end (jax.Array): End timestamp `t` of shape `(*, )`. - deterministic (bool): Whether to run the model deterministically. - **kwargs: Additional keyword arguments. + @typing_extensions.override + def training_step( + self, + *, + state: _train_state.TrainState, + batch: typing.Any, + rngs: typing.Any, + **kwargs, + ) -> typing.Tuple[_train_state.TrainState, _model.StepOutputs]: + del kwargs # unused - Returns: - MeanFlowOutputs: The model outputs. - """ + local_rng = jax.random.fold_in(rngs, jax.process_index()) + tr_rng = jax.random.fold_in(local_rng, 0) + mask_rng = jax.random.fold_in(local_rng, 1) + e_rng = jax.random.fold_in(local_rng, 2) + + image, label = batch["image"], batch["label"] batch_dims = image.shape[:-3] - dims = chex.Dimensions( - H=self.image_size, - W=self.image_size, - C=self.in_channels, + + # step 1: randomly sample begin and end timestamps + t, r = sample_t_r( + key=tr_rng, + shape=batch_dims, + dtype=image.dtype, + distribution=self.timestamp_sampler, + **self.timestamp_sampler_kwargs, ) - chex.assert_shape(image, (*batch_dims, *dims["HWC"])) - chex.assert_shape(label, batch_dims) + t, r = jnp.maximum(t, r), jnp.minimum(t, r) + # ensure a portion of overlap between t and r + r_neq_t_mask = jnp.greater_equal( + jax.random.uniform( + key=mask_rng, + shape=batch_dims, + dtype=image.dtype, + minval=0.0, + maxval=1.0, + ), + self.timestamp_overlap_rate, + ) + r = jnp.where(r_neq_t_mask, t, r) - if begin is None: - begin = jnp.zeros(batch_dims, dtype=image.dtype) - if end is None: - end = jnp.ones(batch_dims, dtype=image.dtype) - chex.assert_shape(begin, batch_dims) - assert jnp.all(begin >= 0) and jnp.all( - begin <= 1 - ), "Invalid input `r`." - chex.assert_shape(end, batch_dims) - assert jnp.all(end >= 0) and jnp.all(end <= 1), "Invalid input `t`." - r, t = jnp.minimum(begin, end), jnp.maximum( - begin, end - ) # ensure r <= t + # sample noise e ~ N(0, I) + e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) - sample = jnp.subtract( - image, - jnp.einsum( - "...,...n->...n", - (t - r), - self.u_fn( + # generate intermediate z(t) + z = jnp.add( + (1 - t[..., None, None, None]) * image, + t[..., None, None, None] * e, + ) + + # calculate velocity v + v = e - image + + # step 2: compute the loss + def _loss_fn(params: PyTree) -> typing.Tuple[jax.Array, jax.Array]: + drdt = jnp.zeros_like(r) + dtdt = jnp.ones_like(t) + u, dudt = jax.jvp( + self._u_fn( label=label, params=params, - deterministic=deterministic, - )(image, r, t), - ), - ) + deterministic=False, + ), + (z, r, t), + (v, drdt, dtdt), + ) + u_target = jax.lax.stop_gradient( + v + - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] + * dudt + ) - return MeanFlowOutputs(output=sample) + # step 3: compute the loss + loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) + if self.adaptive_weight_power > 0.0: + ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) + loss = loss / jax.lax.stop_gradient(ada_wt) + loss = jnp.mean(loss) + + # calculate velocity loss for monitoring + velocity_loss = jnp.where( + jnp.equal(t, r)[..., None, None, None], + jnp.square(u - v), + jnp.zeros_like(u), + ) + velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() - @property - def dummy_input(self) -> PyTree: - """PyTree: A dictionary mapping feature names to example arrays.""" - return { - "image": jnp.zeros( - (1, self.image_size, self.image_size, self.in_channels), - dtype=jnp.float32, + return loss, velocity_loss + + grad_fn = jax.value_and_grad(_loss_fn, argnums=0, has_aux=True) + (loss, velocity_loss), grads = grad_fn(state.params) + loss = jax.lax.pmean(loss, axis_name="batch") + velocity_loss = jax.lax.pmean(velocity_loss, axis_name="batch") + new_state = state.apply_gradients(grads=grads) + + return ( + new_state, + _model.StepOutputs( + scalars={"loss": loss, "velocity_loss": velocity_loss}, ), - "label": jnp.zeros((1,), dtype=jnp.int32), - "begin": jnp.zeros((1,), dtype=jnp.float32), - "end": jnp.zeros((1,), dtype=jnp.float32), - } + ) - def u_fn( + def _u_fn( self, *, label: jax.Array, params: frozen_dict.FrozenDict, deterministic: bool = True, - ) -> typing.Callable[[jax.Array, jax.Array, jax.Array], jax.Array]: - """Returns the average velocity function `u(z_t, r, t)`.""" + ) -> typing.Callable[[jax.Array, jax.Array, jax.Array], typing.Any]: + r"""Returns the average velocity function `u(z_t, r, t)`.""" if self.timestamp_cond == "t_and_r": - return lambda z_t, r, t: self._module.apply( + return lambda z_t, r, t: self.network.apply( variables={"params": params}, image=z_t, label=label, @@ -741,7 +797,7 @@ def u_fn( deterministic=deterministic, ) elif self.timestamp_cond == "t_and_t_minus_r": - return lambda z_t, r, t: self._module.apply( + return lambda z_t, r, t: self.network.apply( variables={"params": params}, image=z_t, label=label, @@ -755,7 +811,7 @@ def u_fn( "Conditioning on (t, r, t - r) is not implemented yet." ) elif self.timestamp_cond == "t_minus_r": - return lambda z_t, r, t: self._module.apply( + return lambda z_t, r, t: self.network.apply( variables={"params": params}, image=z_t, label=label, diff --git a/src/projects/generative/test_meanflow.py b/src/projects/generative/test_meanflow.py deleted file mode 100644 index 906f2da..0000000 --- a/src/projects/generative/test_meanflow.py +++ /dev/null @@ -1,190 +0,0 @@ -import sys -import typing - -import chex -from flax import linen as nn -import jax -import jax.numpy as jnp -import pytest - -from learning.generative import meanflow - - -@pytest.mark.parametrize("distribution", ["uniform", "normal", "lognormal"]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_sample_t_r(distribution: str, dtype: typing.Any) -> None: - """Test the `sample_t_r` function.""" - key = jax.random.PRNGKey(0) - shape = (2, 3) - - if distribution not in ["uniform", "lognormal"]: - with pytest.raises(ValueError): - meanflow.sample_t_r( - key=key, - shape=shape, - dtype=dtype, - distribution=distribution, - ) - return - - # Test uniform distribution - t, r = meanflow.sample_t_r( - key=key, - shape=shape, - dtype=dtype, - distribution="uniform", - ) - chex.assert_shape(t, shape) - chex.assert_shape(r, shape) - chex.assert_type(t, dtype) - chex.assert_type(r, dtype) - chex.assert_tree_all_finite(t) - chex.assert_tree_all_finite(r) - assert jnp.all(t >= 0) and jnp.all(t <= 1) - assert jnp.all(r >= 0) and jnp.all(r <= 1) - - -@pytest.mark.parametrize("features", [1, 8]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_timestamp_embed(features: int, dtype: typing.Any) -> None: - """Test the `TimestampEmbed` module.""" - embed = meanflow.TimestampEmbed( - features=features, - frequency=256, - name="timestamp_embed", - dtype=dtype, - param_dtype=dtype, - ) - assert isinstance(embed, nn.Module) - assert embed.features == features - assert embed.frequency == 256 - assert embed.dtype == dtype - assert embed.param_dtype == dtype - variables = embed.init( - jax.random.PRNGKey(0), - jnp.ones((2,), dtype=jnp.int32), - ) - chex.assert_shape(variables["params"]["fc_in"]["kernel"], (256, features)) - chex.assert_type(variables["params"]["fc_in"]["kernel"], dtype) - chex.assert_shape(variables["params"]["fc_in"]["bias"], (features,)) - chex.assert_type(variables["params"]["fc_in"]["bias"], dtype) - chex.assert_shape( - variables["params"]["fc_out"]["kernel"], - (features, features), - ) - chex.assert_type(variables["params"]["fc_out"]["kernel"], dtype) - chex.assert_shape(variables["params"]["fc_out"]["bias"], (features,)) - chex.assert_type(variables["params"]["fc_out"]["bias"], dtype) - - test_output = embed.apply( - variables, - jnp.array([10, 1000], dtype=jnp.int32), - ) - chex.assert_shape(test_output, (2, features)) - chex.assert_type(test_output, dtype) - chex.assert_tree_all_finite(test_output) - - -@pytest.mark.parametrize("features", [1, 8]) -@pytest.mark.parametrize("use_cfg_embedding", [False, True]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_condition_embed( - features: int, - use_cfg_embedding: bool, - dtype: typing.Any, -) -> None: - """Test the `ConditionEmbed` module.""" - if use_cfg_embedding: - # TODO: implement classifier-free guidance. - pytest.skip("Classifier-free guidance not supported yet.") - - embed = meanflow.ConditionEmbed( - features=features, - num_classes=10, - use_cfg_embedding=use_cfg_embedding, - name="condition_embed", - dtype=dtype, - param_dtype=dtype, - ) - assert isinstance(embed, nn.Module) - assert embed.features == features - assert embed.num_classes == 10 - assert embed.use_cfg_embedding == use_cfg_embedding - assert embed.dtype == dtype - assert embed.param_dtype == dtype - variables = embed.init( - jax.random.PRNGKey(0), - jnp.ones((2,), dtype=jnp.int32), - ) - chex.assert_shape( - variables["params"]["embedding_table"]["embedding"], - (10 + int(use_cfg_embedding), features), - ) - chex.assert_type( - variables["params"]["embedding_table"]["embedding"], dtype - ) - - test_output = embed.apply( - variables, - jnp.array([1, 9], dtype=jnp.int32), - ) - chex.assert_shape(test_output, (2, features)) - chex.assert_type(test_output, dtype) - chex.assert_tree_all_finite(test_output) - - -@pytest.mark.parametrize("features", [1, 8]) -@pytest.mark.parametrize("use_bias", [False, True]) -@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_conditional_instance_norm( - features: int, - use_bias: bool, - dtype: typing.Any, -) -> None: - """Test the `ConditionalInstanceNorm` module.""" - cond_features = 4 - cond = jnp.ones((2, cond_features), dtype=dtype) - - norm = meanflow.ConditionalInstanceNorm( - features=features, - use_bias=use_bias, - name="conditional_instance_norm", - dtype=dtype, - param_dtype=dtype, - ) - assert isinstance(norm, nn.Module) - assert norm.features == features - assert norm.use_bias == use_bias - assert norm.dtype == dtype - assert norm.param_dtype == dtype - variables = norm.init( - jax.random.PRNGKey(0), - jnp.ones((2, 16, 16, features), dtype=dtype), - cond, - ) - assert variables["params"].get("instance_norm") is None - if use_bias: - chex.assert_shape( - variables["params"]["embed"]["kernel"], - (cond_features, features * 3), - ) - assert variables["params"]["embed"].get("bias") is None - else: - chex.assert_shape( - variables["params"]["embed"]["kernel"], - (cond_features, features * 2), - ) - assert variables["params"]["embed"].get("bias") is None - - test_output = norm.apply( - variables, - jnp.ones((2, 16, 16, features), dtype=dtype), - cond, - ) - chex.assert_shape(test_output, (2, 16, 16, features)) - chex.assert_type(test_output, dtype) - chex.assert_tree_all_finite(test_output) - - -if __name__ == "__main__": - sys.exit(pytest.main(["-xv", __file__])) From 8ca3ed3366d81bda6223ef9c7c8c2b2e803efda8 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 05:53:10 -0500 Subject: [PATCH 08/67] feat: Added main entrypoint for training and evaluation of generative models Signed-off-by: Juanwu Lu --- src/projects/generative/main.py | 187 ++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 src/projects/generative/main.py diff --git a/src/projects/generative/main.py b/src/projects/generative/main.py new file mode 100644 index 0000000..df6b0a6 --- /dev/null +++ b/src/projects/generative/main.py @@ -0,0 +1,187 @@ +from datetime import datetime +import os +import platform +import typing + +from absl import app +from absl import flags +from clu import checkpoint +from clu import metric_writers +from clu import platform as clu_platform +from fiddle import absl_flags +import fiddle as fdl +import jax +import optax +import tensorflow as tf + +from src.core import config as _config +from src.core import evaluate as _evaluate +from src.core import train as _train +from src.core import train_state as _train_state +from src.utilities import logging + +CONFIG = absl_flags.DEFINE_fiddle_config( + name="experiment", + default=None, + help_string="Experiment configuration.", + required=True, +) +FLAGS = flags.FLAGS +flags.DEFINE_string( + name="work_dir", + default=None, + help="Directory to store the experiment results.", + required=True, +) + + +# toggle off GPU/TPU for TensorFlow +tf.config.experimental.set_visible_devices([], "GPU") +tf.config.experimental.set_visible_devices([], "TPU") +assert not tf.config.experimental.get_visible_devices("GPU") + + +def main(_: typing.List[str]) -> int: + r"""Main entry point for training and evaluate generative models.""" + del _ # unused. + + # Log the current platform + logging.rank_zero_info("Running on platform: %s", platform.node()) + + # Setup JAX runtime + logging.rank_zero_info("Running on JAX backend: %s", jax.default_backend()) + logging.rank_zero_info( + "Running on JAX process: %d / %d", + jax.process_index() + 1, + jax.process_count(), + ) + logging.rank_zero_info("Running on JAX devices: %r", jax.devices()) + + clu_platform.work_unit().set_task_status( + "process_index: %d, process_count: %d" + % (jax.process_index() + 1, jax.process_count()), + ) + clu_platform.work_unit().create_artifact( + clu_platform.ArtifactType.DIRECTORY, + FLAGS.work_dir, + "Working directory.", + ) + + # Setup Experiment + exp_config = CONFIG.value + if not isinstance(exp_config, _config.ExperimentConfig): + logging.rank_zero_error( + "Expect configuration to be of type `ExperimentConfig`, got %s.", + type(exp_config), + ) + return 1 + logging.rank_zero_info("Experiment Configuration:\n%s", exp_config) + + rng = jax.random.PRNGKey(exp_config.seed) + log_dir = os.path.join( + FLAGS.work_dir, + exp_config.name, + datetime.now().strftime("%Y%m%d_%H%M%S"), + ) + writer = metric_writers.create_default_writer( + logdir=log_dir, + just_logging=(jax.process_index() > 0), + ) + + logging.rank_zero_info("Building dataset...") + rng, data_rng = jax.random.split(rng, num=2) + p_datamodule = fdl.build(exp_config.data.module) + datamodule = p_datamodule( + batch_size=exp_config.data.batch_size, + deterministic=exp_config.data.deterministic, + drop_remainder=exp_config.data.drop_remainder, + num_workers=exp_config.data.num_workers, + rng=data_rng, + ) + logging.rank_zero_info( + "Building dataset %s... DONE!", + datamodule.__class__.__name__, + ) + + logging.rank_zero_info("Building model...") + rng, init_rng = jax.random.split(rng, num=2) + p_model = fdl.build(exp_config.model) + model = p_model( + dtype=exp_config.dtype, + param_dtype=exp_config.param_dtype, + precision=exp_config.precision, + ) + params = model.init(batch=None, rngs=init_rng) # NOTE: use dummy batch + logging.rank_zero_info( + "Building model %s... DONE!", + model.__class__.__name__, + ) + + logging.rank_zero_info("Building train state...") + lr_scheduler = fdl.build(exp_config.optimizer.lr_schedule) + p_optimizer = fdl.build(exp_config.optimizer.optimizer) + tx = p_optimizer(learning_rate=lr_scheduler) + if exp_config.optimizer.grad_clip_method == "norm": + tx = optax.chain( + optax.clip_by_global_norm(exp_config.optimizer.grad_clip_value), + tx, + ) + elif exp_config.optimizer.grad_clip_method == "value": + tx = optax.chain( + optax.clip(exp_config.optimizer.grad_clip_value), + tx, + ) + elif exp_config.optimizer.grad_clip_method is not None: + logging.rank_zero_error( + "Unknown grad clip method: %s", + exp_config.optimizer.grad_clip_method, + ) + return 1 + state = _train_state.TrainState.create( + params=params, + tx=tx, + ema_rate=exp_config.optimizer.ema_rate, + ) + logging.rank_zero_info("Building train state... DONE!") + + checkpoint_manager = checkpoint.MultihostCheckpoint( + os.path.join(log_dir, "checkpoints"), + max_to_keep=max(2, exp_config.trainer.max_checkpoints_to_keep), + ) + if exp_config.trainer.checkpoint_dir is not None: + logging.rank_zero_error("Resuming from checkpoint not implemented.") + return 1 + + if exp_config.mode == "train": + _train.run( + model=model, + state=state, + datamodule=datamodule, + num_train_steps=exp_config.trainer.num_train_steps, + checkpoint_manager=checkpoint_manager, + writer=writer, + work_dir=log_dir, + rng=rng, + log_every_n_steps=exp_config.trainer.log_every_n_steps, + eval_every_n_steps=exp_config.trainer.eval_every_n_steps, + profile=exp_config.trainer.profile, + ) + elif exp_config.mode == "evaluate": + _evaluate.run( + model=model, + datamodule=datamodule, + params=params, + writer=writer, + work_dir=log_dir, + rng=rng, + ) + else: + logging.rank_zero_error("Mode %s not implemented.", exp_config.mode) + return 1 + + return 0 + + +if __name__ == "__main__": + jax.config.config_with_absl() + app.run(main=main) From 0ab6cfc4c1c14f5202f2d655d31f27f1cab65c44 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 05:53:54 -0500 Subject: [PATCH 09/67] feat: Updated configuration for training U-Net meanflow on CIFAR-10 Signed-off-by: Juanwu Lu --- src/projects/generative/BUILD | 47 +++++++++++++++------------ src/projects/generative/config.py | 53 ++++++++++++++++++------------- 2 files changed, 58 insertions(+), 42 deletions(-) diff --git a/src/projects/generative/BUILD b/src/projects/generative/BUILD index e5bf7f3..9397e94 100644 --- a/src/projects/generative/BUILD +++ b/src/projects/generative/BUILD @@ -1,6 +1,6 @@ -load("//learning:defs.bzl", "ml_py_library", "ml_py_test") +load("//third_party:defs.bzl", "ml_py_binary", "ml_py_library") -package(default_visibility = ["//learning:__subpackages__"]) +package(default_visibility = ["//src/projects/generative:__subpackages__"]) ml_py_library( name = "config", @@ -9,35 +9,42 @@ ml_py_library( "fiddle", "optax", ":meanflow", - "//learning/core:config", - "//learning/data:cifar", - "//learning/data:preprocess", + "//src/core:config", + "//src/data:huggingface", + "//src/data:preprocess", ], ) -ml_py_library( - name = "meanflow", - srcs = ["meanflow.py"], +ml_py_binary( + name = "main", + srcs = ["main.py"], deps = [ - "chex", - "flax", + "absl-py", + "clu", + "fiddle", "jax", - "jaxlib", - "jaxtyping", - "typing_extensions", - "//learning/core:mixin", - "//learning/generative/model:refinenet", + "optax", + "tensorflow", + ":config", + "//src/core:config", + "//src/core:evaluate", + "//src/core:train", + "//src/core:train_state", + "//src/utilities:logging", ], ) -ml_py_test( - name = "test_meanflow", - srcs = ["test_meanflow.py"], +ml_py_library( + name = "meanflow", + srcs = ["meanflow.py"], deps = [ "chex", "flax", "jax", - "jaxlib", - ":meanflow", + "jaxtyping", + "typing_extensions", + "//src/core:model", + "//src/core:train_state", + "//src/projects/generative/model:refinenet", ], ) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index eb979e6..358d99f 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -3,31 +3,40 @@ import fiddle as fdl import optax -from learning.core import config as _config -from learning.data import cifar -from learning.data import preprocess -from learning.generative import meanflow +from src.core import config as _config +from src.data import huggingface +from src.data import preprocess +from src.projects.generative import meanflow +# ============================================================================== # MeanFlow Models def meanflow_unet_cifar_10() -> _config.ExperimentConfig: return _config.ExperimentConfig( name="meanflow_unet_cifar_10", - data=fdl.Partial( - cifar.CIFAR10DataModule, - preprocess_fn=preprocess.chain( - functools.partial( - preprocess.filter_keys, - keys=["image", "label"], - ), - functools.partial( - preprocess.normalize, - mean=(0.5, 0.5, 0.5), - std=(0.5, 0.5, 0.5), + mode="train", + data=_config.DataConfig( + module=fdl.Partial( + huggingface.CIFAR10DataModule, + resize=32, + transform=preprocess.chain( + functools.partial( + preprocess.filter_keys, + keys=["image", "label"], + ), + functools.partial( + preprocess.normalize, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), ), ), + batch_size=1024, + num_workers=4, + deterministic=True, + drop_remainder=True, ), - model=fdl.Config( + model=fdl.Partial( meanflow.MeanFlowUNetModel, in_channels=3, image_size=32, @@ -41,10 +50,10 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: timestamp_overlap_rate=0.25, adaptive_weight_power=0.75, ), - # TODO: implement the warmup in https://arxiv.org/abs/1706.02677 - batch_size=1024, - lr_scheduler=fdl.Config(optax.constant_schedule, value=6e-4), - optimizer=fdl.Partial(optax.adam, b1=0.9, b2=0.999), - ema_rate=0.99995, - num_train_steps=800_000, + trainer=_config.TrainerConfig(), + optimizer=_config.OptimizerConfig( + lr_schedule=fdl.Config(optax.constant_schedule, value=6e-4), + optimizer=fdl.Partial(optax.adam, b1=0.9, b2=0.999), + ), + seed=42, ) From f36fbeb13fd51d78712ad9904e24b50160962573 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 06:09:21 -0500 Subject: [PATCH 10/67] hotfix: Fixed issue with version of `chex` Signed-off-by: Juanwu Lu --- third_party/requirements.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/requirements.in b/third_party/requirements.in index 2305df3..57dd24c 100644 --- a/third_party/requirements.in +++ b/third_party/requirements.in @@ -1,5 +1,5 @@ absl-py==2.3.1 -chex==0.1.91 +chex==0.1.90 clu==0.0.12 datasets==4.4.1 flax==0.10.7 From e9816c81c7b8ad40281278a815852950bbc673b5 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 06:56:16 -0500 Subject: [PATCH 11/67] hotfix: Updated main train logic Signed-off-by: Juanwu Lu --- src/core/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/train.py b/src/core/train.py index c0b43f2..928fa0f 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -206,13 +206,13 @@ def run( ) # logging on the end of epoch - logging.rank_zero_info("Epoch %d done.", epoch) + logging.rank_zero_info("Epoch %d done.", epoch + 1) scalar_output = { f"train/{k.replace('_', ' ')}_epoch": sum(v) / len(v) for k, v in train_metrics.items() } writer.write_scalars( - step=epoch, + step=epoch + 1, scalars=scalar_output, ) epoch += 1 From c2fb1ba7b1e6244d943c7e0ad1055d139467214d Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 06:57:11 -0500 Subject: [PATCH 12/67] hotfix: Improve the log frequency for meanflow on CIFAR-10 Signed-off-by: Juanwu Lu --- src/projects/generative/config.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index 358d99f..13101b2 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -32,7 +32,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: ), ), batch_size=1024, - num_workers=4, + num_workers=2, deterministic=True, drop_remainder=True, ), @@ -50,7 +50,12 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: timestamp_overlap_rate=0.25, adaptive_weight_power=0.75, ), - trainer=_config.TrainerConfig(), + trainer=_config.TrainerConfig( + num_train_steps=800_000, + log_every_n_steps=5, + max_checkpoints_to_keep=3, + profile=False, + ), optimizer=_config.OptimizerConfig( lr_schedule=fdl.Config(optax.constant_schedule, value=6e-4), optimizer=fdl.Partial(optax.adam, b1=0.9, b2=0.999), From a2f8fbe8c5fa59046ecf7ebcea86ef8484aed6d9 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 06:57:32 -0500 Subject: [PATCH 13/67] feat: Updated the main logic for training step in MeanFlow Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 87 ++++++++++++++++------------- 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 3923b93..494c9d0 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -689,49 +689,59 @@ def training_step( del kwargs # unused local_rng = jax.random.fold_in(rngs, jax.process_index()) - tr_rng = jax.random.fold_in(local_rng, 0) - mask_rng = jax.random.fold_in(local_rng, 1) - e_rng = jax.random.fold_in(local_rng, 2) - - image, label = batch["image"], batch["label"] - batch_dims = image.shape[:-3] - - # step 1: randomly sample begin and end timestamps - t, r = sample_t_r( - key=tr_rng, - shape=batch_dims, - dtype=image.dtype, - distribution=self.timestamp_sampler, - **self.timestamp_sampler_kwargs, - ) - t, r = jnp.maximum(t, r), jnp.minimum(t, r) - # ensure a portion of overlap between t and r - r_neq_t_mask = jnp.greater_equal( - jax.random.uniform( - key=mask_rng, + local_rng = jax.random.fold_in(local_rng, state.step) + + def _loss_fn( + params: PyTree, + batch: typing.Any, + local_rng: jax.Array, + ) -> typing.Tuple[jax.Array, jax.Array]: + tr_rng = jax.random.fold_in(local_rng, 0) + mask_rng = jax.random.fold_in(local_rng, 1) + e_rng = jax.random.fold_in(local_rng, 2) + + image, label = batch["image"], batch["label"] + batch_dims = image.shape[:-3] + + # step 1: randomly sample begin and end timestamps + t, r = sample_t_r( + key=tr_rng, shape=batch_dims, dtype=image.dtype, - minval=0.0, - maxval=1.0, - ), - self.timestamp_overlap_rate, - ) - r = jnp.where(r_neq_t_mask, t, r) + distribution=self.timestamp_sampler, + **self.timestamp_sampler_kwargs, + ) + t, r = jnp.maximum(t, r), jnp.minimum(t, r) + # ensure a portion of overlap between t and r + r_neq_t_mask = jnp.greater_equal( + jax.random.uniform( + key=mask_rng, + shape=batch_dims, + dtype=image.dtype, + minval=0.0, + maxval=1.0, + ), + self.timestamp_overlap_rate, + ) + r = jnp.where(r_neq_t_mask, t, r) - # sample noise e ~ N(0, I) - e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) + # sample noise e ~ N(0, I) + e = jax.random.normal( + key=e_rng, + shape=image.shape, + dtype=image.dtype, + ) - # generate intermediate z(t) - z = jnp.add( - (1 - t[..., None, None, None]) * image, - t[..., None, None, None] * e, - ) + # generate intermediate z(t) + z = jnp.add( + (1 - t[..., None, None, None]) * image, + t[..., None, None, None] * e, + ) - # calculate velocity v - v = e - image + # calculate velocity v + v = e - image - # step 2: compute the loss - def _loss_fn(params: PyTree) -> typing.Tuple[jax.Array, jax.Array]: + # # compute the Jaxobian-vector product drdt = jnp.zeros_like(r) dtdt = jnp.ones_like(t) u, dudt = jax.jvp( @@ -767,7 +777,8 @@ def _loss_fn(params: PyTree) -> typing.Tuple[jax.Array, jax.Array]: return loss, velocity_loss grad_fn = jax.value_and_grad(_loss_fn, argnums=0, has_aux=True) - (loss, velocity_loss), grads = grad_fn(state.params) + (loss, velocity_loss), grads = grad_fn(state.params, batch, local_rng) + grads = jax.lax.pmean(grads, axis_name="batch") loss = jax.lax.pmean(loss, axis_name="batch") velocity_loss = jax.lax.pmean(velocity_loss, axis_name="batch") new_state = state.apply_gradients(grads=grads) From 91e96b3586953e47c11b7ce125be6f449511952e Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 10:21:41 -0500 Subject: [PATCH 14/67] feat: Updated implementation of MeanFlow Signed-off-by: Juanwu Lu --- src/projects/generative/config.py | 1 + src/projects/generative/meanflow.py | 152 ++++++---------------------- 2 files changed, 30 insertions(+), 123 deletions(-) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index 13101b2..df688f5 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -53,6 +53,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: trainer=_config.TrainerConfig( num_train_steps=800_000, log_every_n_steps=5, + eval_every_n_steps=1_000_000, # NOTE: never evaluate now max_checkpoints_to_keep=3, profile=False, ), diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 494c9d0..5d8b178 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -2,7 +2,6 @@ import chex from flax import linen as nn -from flax.core import frozen_dict import jax from jax import numpy as jnp from jax._src import typing as jax_typing @@ -598,72 +597,7 @@ def evaluation_step( **kwargs, ) -> _model.StepOutputs: del kwargs # unused - - local_rng = jax.random.fold_in(rngs, jax.process_index()) - e_rng = jax.random.fold_in(local_rng, 0) - - image, label = batch["image"], batch["label"] - - batch_dims = image.shape[:-3] - dims = chex.Dimensions( - H=self.image_size, - W=self.image_size, - C=self.in_channels, - ) - chex.assert_shape(image, (*batch_dims, *dims["HWC"])) - chex.assert_shape(label, batch_dims) - - r = jnp.zeros(batch_dims, dtype=image.dtype) - t = jnp.ones(batch_dims, dtype=image.dtype) - sample = jnp.subtract( - image, - jnp.einsum( - "...,...n->...n", - (t - r), - self._u_fn( - label=label, - params=params, - deterministic=True, - )(image, r, t), - ), - ) - - drdt = jnp.zeros_like(r) - dtdt = jnp.ones_like(t) - e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) - v = e - image - u, dudt = jax.jvp( - self._u_fn( - label=label, - params=params, - deterministic=True, - ), - (e, r, t), - (v, drdt, dtdt), - ) - u_target = jax.lax.stop_gradient( - v - - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] - * dudt, - ) - loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) - - if self.adaptive_weight_power > 0.0: - ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) - loss = loss / jax.lax.stop_gradient(ada_wt) - loss = jnp.mean(loss) - - velocity_loss = jnp.where( - jnp.equal(t, r)[..., None, None, None], - jnp.square(u - v), - jnp.zeros_like(u), - ) - velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() - - return _model.StepOutputs( - scalars={"loss": loss, "velocity_loss": velocity_loss}, - images={"samples": sample}, - ) + raise NotImplementedError("Evaluation not implemented yet.") @typing_extensions.override def predict_step( @@ -741,17 +675,36 @@ def _loss_fn( # calculate velocity v v = e - image - # # compute the Jaxobian-vector product - drdt = jnp.zeros_like(r) - dtdt = jnp.ones_like(t) - u, dudt = jax.jvp( - self._u_fn( + def u_fn( + p: PyTree, + z_t: jax.Array, + r_val: jax.Array, + t_val: jax.Array, + ) -> typing.Any: + if self.timestamp_cond == "t_and_r": + b_arg, e_arg = r_val, t_val + elif self.timestamp_cond == "t_and_t_minus_r": + b_arg, e_arg = t_val - r_val, t_val + elif self.timestamp_cond == "t_minus_r": + b_arg, e_arg = t_val - r_val, None + else: + # Fallback + b_arg, e_arg = t_val - r_val, t_val + + return self.network.apply( + variables={"params": p}, + image=z_t, label=label, - params=params, + begin=b_arg, + end=e_arg, deterministic=False, - ), - (z, r, t), - (v, drdt, dtdt), + ) + + params_tangent = jax.tree_util.tree_map(jnp.zeros_like, params) + u, dudt = jax.jvp( + u_fn, + (params, z, r, t), + (params_tangent, v, jnp.zeros_like(r), jnp.ones_like(t)), ) u_target = jax.lax.stop_gradient( v @@ -789,50 +742,3 @@ def _loss_fn( scalars={"loss": loss, "velocity_loss": velocity_loss}, ), ) - - def _u_fn( - self, - *, - label: jax.Array, - params: frozen_dict.FrozenDict, - deterministic: bool = True, - ) -> typing.Callable[[jax.Array, jax.Array, jax.Array], typing.Any]: - r"""Returns the average velocity function `u(z_t, r, t)`.""" - if self.timestamp_cond == "t_and_r": - return lambda z_t, r, t: self.network.apply( - variables={"params": params}, - image=z_t, - label=label, - begin=r, - end=t, - deterministic=deterministic, - ) - elif self.timestamp_cond == "t_and_t_minus_r": - return lambda z_t, r, t: self.network.apply( - variables={"params": params}, - image=z_t, - label=label, - begin=t - r, - end=t, - deterministic=deterministic, - ) - elif self.timestamp_cond == "t_and_r_and_t_minus_r": - # TODO: implement this - raise NotImplementedError( - "Conditioning on (t, r, t - r) is not implemented yet." - ) - elif self.timestamp_cond == "t_minus_r": - return lambda z_t, r, t: self.network.apply( - variables={"params": params}, - image=z_t, - label=label, - begin=t - r, - end=None, - deterministic=deterministic, - ) - else: - raise ValueError( - f"Unsupported timestamp condition: {self.timestamp_cond}. " - 'Must be one of ["t_and_r", "t_and_t_minus_r", ' - '"t_and_r_and_t_minus_r", "t_minus_r"].' - ) From dd4df9b454c64d523b3e8a6286d521fbba39b814 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 10:40:40 -0500 Subject: [PATCH 15/67] feat: Updated implementation of MeanFlow Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 235 +++++++++++++++------------- 1 file changed, 129 insertions(+), 106 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 5d8b178..d2cc3c1 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -377,6 +377,14 @@ class MeanFlowUNetModule(nn.Module): """int: Number of channels in the latent feature maps.""" num_classes: int """int: Number of conditioning classes.""" + timestamp_sampler: str + """str: The distribution to sample timestamps from.""" + timestamp_sampler_kwargs: typing.Dict[str, typing.Any] + """Dict[str, Any]: Keyword arguments for the timestamp sampler.""" + timestamp_overlap_rate: float + """float: The minimum overlap rate between begin and end timestamps.""" + adaptive_weight_power: typing.Optional[float] = None + """Optional[float]: The power for adaptive weight scaling.""" use_cfg_embedding: bool = False """bool: Whether to use classifier-free guidance (CFG) embedding.""" deterministic: typing.Optional[bool] = None @@ -444,7 +452,7 @@ def __call__( end (jax.Array): End timestamp `t` of shape `(*, )`. Returns: - jax.Array: The predicted average velocity of shape `(*, H, W, C)`. + The predicted average velocity of shape `(*, H, W, C)`. """ # sanity check for the input arrays batch_dims = image.shape[:-3] @@ -478,6 +486,102 @@ def __call__( return output + def compute_loss( + self, + image: jax.Array, + label: jax.Array, + ) -> typing.Tuple[jax.Array, jax.Array]: + r"""Compute the `MeanFlow` loss. + + Args: + image (jax.Array): Input images of shape `(*, H, W, C)`. + label (jax.Array): Conditioning labels of shape `(*,)`. + + Returns: + The mean flow loss and velocity loss. + """ + batch_dims = image.shape[:-3] + + # step 1: randomly sample begin and end timestamps + t, r = sample_t_r( + key=self.make_rng("timestamp"), + shape=batch_dims, + dtype=image.dtype, + distribution=self.timestamp_sampler, + **self.timestamp_sampler_kwargs, + ) + t, r = jnp.maximum(t, r), jnp.minimum(t, r) + # ensure a portion of overlap between t and r + r_neq_t_mask = jnp.greater_equal( + jax.random.uniform( + key=self.make_rng("mask"), + shape=batch_dims, + dtype=image.dtype, + minval=0.0, + maxval=1.0, + ), + self.timestamp_overlap_rate, + ) + r = jnp.where(r_neq_t_mask, t, r) + + # sample noise e ~ N(0, I) + e = jax.random.normal( + key=self.make_rng("noise"), + shape=image.shape, + dtype=image.dtype, + ) + + # generate intermediate z(t) + z = jnp.add( + (1 - t[..., None, None, None]) * image, + t[..., None, None, None] * e, + ) + + # calculate velocity v + v = e - image + + def u_fn( + z_t: jax.Array, + r_val: jax.Array, + t_val: jax.Array, + ) -> typing.Any: + b_arg, e_arg = t_val - r_val, t_val + + return self( + image=z_t, + label=label, + begin=b_arg, + end=e_arg, + deterministic=False, + ) + + u, dudt = jax.jvp( + u_fn, + (z, r, t), + (v, jnp.zeros_like(r), jnp.ones_like(t)), + ) + u_target = jax.lax.stop_gradient( + v + - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] + * dudt + ) + + # step 3: compute the loss + loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) + if self.adaptive_weight_power is not None: + ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) + loss = loss / jax.lax.stop_gradient(ada_wt) + loss = jnp.mean(loss) + + velocity_loss = jnp.where( + jnp.equal(t, r)[..., None, None, None], + jnp.square(u - v), + jnp.zeros_like(u), + ) + velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() + + return loss, velocity_loss + class MeanFlowUNetModel(_model.Model): r"""`MeanFlow` generative model with a U-Net backbone. @@ -532,15 +636,15 @@ def __init__( self.in_channels = in_channels self.image_size = image_size self.timestamp_cond = timestamp_cond - self.timestamp_sampler = timestamp_sampler - self.timestamp_sampler_kwargs = timestamp_sampler_kwargs - self.timestamp_overlap_rate = timestamp_overlap_rate - self.adaptive_weight_power = adaptive_weight_power self._network = MeanFlowUNetModule( in_channels=in_channels, image_size=image_size, latent_channels=latent_channels, num_classes=num_classes, + timestamp_sampler=timestamp_sampler, + timestamp_sampler_kwargs=timestamp_sampler_kwargs, + timestamp_overlap_rate=timestamp_overlap_rate, + adaptive_weight_power=adaptive_weight_power, use_cfg_embedding=use_cfg_embedding, dropout_rate=dropout_rate, name="unet", @@ -625,112 +729,31 @@ def training_step( local_rng = jax.random.fold_in(rngs, jax.process_index()) local_rng = jax.random.fold_in(local_rng, state.step) - def _loss_fn( - params: PyTree, - batch: typing.Any, - local_rng: jax.Array, - ) -> typing.Tuple[jax.Array, jax.Array]: - tr_rng = jax.random.fold_in(local_rng, 0) - mask_rng = jax.random.fold_in(local_rng, 1) - e_rng = jax.random.fold_in(local_rng, 2) - - image, label = batch["image"], batch["label"] - batch_dims = image.shape[:-3] - - # step 1: randomly sample begin and end timestamps - t, r = sample_t_r( - key=tr_rng, - shape=batch_dims, - dtype=image.dtype, - distribution=self.timestamp_sampler, - **self.timestamp_sampler_kwargs, - ) - t, r = jnp.maximum(t, r), jnp.minimum(t, r) - # ensure a portion of overlap between t and r - r_neq_t_mask = jnp.greater_equal( - jax.random.uniform( - key=mask_rng, - shape=batch_dims, - dtype=image.dtype, - minval=0.0, - maxval=1.0, - ), - self.timestamp_overlap_rate, - ) - r = jnp.where(r_neq_t_mask, t, r) - - # sample noise e ~ N(0, I) - e = jax.random.normal( - key=e_rng, - shape=image.shape, - dtype=image.dtype, - ) - - # generate intermediate z(t) - z = jnp.add( - (1 - t[..., None, None, None]) * image, - t[..., None, None, None] * e, - ) - - # calculate velocity v - v = e - image - - def u_fn( - p: PyTree, - z_t: jax.Array, - r_val: jax.Array, - t_val: jax.Array, - ) -> typing.Any: - if self.timestamp_cond == "t_and_r": - b_arg, e_arg = r_val, t_val - elif self.timestamp_cond == "t_and_t_minus_r": - b_arg, e_arg = t_val - r_val, t_val - elif self.timestamp_cond == "t_minus_r": - b_arg, e_arg = t_val - r_val, None - else: - # Fallback - b_arg, e_arg = t_val - r_val, t_val - - return self.network.apply( - variables={"params": p}, - image=z_t, - label=label, - begin=b_arg, - end=e_arg, - deterministic=False, - ) - - params_tangent = jax.tree_util.tree_map(jnp.zeros_like, params) - u, dudt = jax.jvp( - u_fn, - (params, z, r, t), - (params_tangent, v, jnp.zeros_like(r), jnp.ones_like(t)), - ) - u_target = jax.lax.stop_gradient( - v - - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] - * dudt - ) - - # step 3: compute the loss - loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) - if self.adaptive_weight_power > 0.0: - ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) - loss = loss / jax.lax.stop_gradient(ada_wt) - loss = jnp.mean(loss) - - # calculate velocity loss for monitoring - velocity_loss = jnp.where( - jnp.equal(t, r)[..., None, None, None], - jnp.square(u - v), - jnp.zeros_like(u), + tr_rng = jax.random.fold_in(local_rng, 0) + mask_rng = jax.random.fold_in(local_rng, 1) + e_rng = jax.random.fold_in(local_rng, 2) + + image, label = batch["image"], batch["label"] + + def _loss_fn(params: PyTree) -> typing.Tuple[jax.Array, jax.Array]: + loss, velocity_loss = self.network.apply( + variables={"params": params}, + image=image, + label=label, + rngs={ + "timestamp": tr_rng, + "mask": mask_rng, + "noise": e_rng, + }, + method=self.network.compute_loss, ) - velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() + assert isinstance(loss, jax.Array) + assert isinstance(velocity_loss, jax.Array) return loss, velocity_loss grad_fn = jax.value_and_grad(_loss_fn, argnums=0, has_aux=True) - (loss, velocity_loss), grads = grad_fn(state.params, batch, local_rng) + (loss, velocity_loss), grads = grad_fn(state.params) grads = jax.lax.pmean(grads, axis_name="batch") loss = jax.lax.pmean(loss, axis_name="batch") velocity_loss = jax.lax.pmean(velocity_loss, axis_name="batch") From d585b664df38afb4ac58bd260eb5559a98242564 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 14:51:46 -0500 Subject: [PATCH 16/67] feat: Added checkpoint frequency attribute to trainer config Signed-off-by: Juanwu Lu --- src/core/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/core/config.py b/src/core/config.py index 6084b57..e7544c1 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -45,6 +45,8 @@ class TrainerConfig: Attributes: num_train_steps (int): Total number of training steps. + checkpoint_every_n_steps (Optional[int]): Frequency of checkpointing. + If `None`, defaults to `eval_every_n_steps`. log_every_n_steps (int): Frequency of logging training metrics. eval_every_n_steps (int): Frequency of evaluation during training. checkpoint_dir (Optional[str]): Directory of checkpoint to resume from. @@ -53,6 +55,7 @@ class TrainerConfig: """ num_train_steps: int = 10_000 + checkpoint_every_n_steps: typing.Optional[int] = None log_every_n_steps: int = 50 eval_every_n_steps: int = 1_000 checkpoint_dir: typing.Optional[str] = None From 41223e057b84e04ebc6645e90e05731c5b2e6e9a Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 14:52:05 -0500 Subject: [PATCH 17/67] feat: Updated implementation of train logic Signed-off-by: Juanwu Lu --- src/core/train.py | 57 ++++++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/src/core/train.py b/src/core/train.py index 928fa0f..eb7f43d 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -1,15 +1,17 @@ import collections import functools +import os +import traceback import typing -from clu import checkpoint from clu import metric_writers from clu import periodic_actions from flax import jax_utils +from flax.training import checkpoints import jax import jaxtyping -from src.core import datamodule as _datamodule +from src.core import datamodule as _data from src.core import model as _model from src.core import train_state as _train_state from src.utilities import logging @@ -55,9 +57,8 @@ def _shard(tree: jaxtyping.PyTree) -> jaxtyping.PyTree: def run( model: _model.Model, state: _train_state.TrainState, - datamodule: _datamodule.DataModule, + datamodule: _data.DataModule, num_train_steps: int, - checkpoint_manager: checkpoint.Checkpoint, writer: metric_writers.MetricWriter, work_dir: str, rng: typing.Any, @@ -110,7 +111,7 @@ def run( num_profile_steps=5, ) ) - step, epoch = state.step, 0 + step = state.step state = jax_utils.replicate(state) logging.rank_zero_info("Training...") with metric_writers.ensure_flushes(writer): @@ -135,25 +136,23 @@ def run( if outputs.scalars is not None: for k, v in outputs.scalars.items(): train_metrics[k].append(jax.device_get(v).mean()) - step += 1 for hook in hooks: hook(step) if step % log_every_n_steps == 0: if outputs.scalars is not None: - scalar_output = { - f"train/{k.replace('_', ' ')}_step": sum(v) - / len(v) - for k, v in outputs.scalars.items() - } writer.write_scalars( step=step, - scalars=scalar_output, + scalars={ + f"train/{k}_step": sum(v) / len(v) + for k, v in outputs.scalars.items() + }, ) if outputs.images is not None: writer.write_images( step=step, images=outputs.images, ) + step += 1 # evaluation if ( @@ -182,7 +181,7 @@ def run( writer.write_scalars( step=step, scalars={ - f"eval/{k.replace('_', ' ')}": sum(v) / len(v) + f"eval/{k}": sum(v) / len(v) for k, v in eval_metrics.items() }, ) @@ -195,32 +194,40 @@ def run( # checkpointing if step % checkpoint_every_n_steps == 0: logging.rank_zero_info("Checkpointing...") - # TODO (juanwulu): resolve the error (no __enter__) - with report_progress.timed("checkpoint"): - filepath = checkpoint_manager.save( - state=jax_utils.unreplicate(state) + if jax.process_index() == 0: + with report_progress.timed("checkpoint"): + filepath = checkpoints.save_checkpoint( + ckpt_dir=os.path.join( + work_dir, + "checkpoints", + ), + target=jax_utils.unreplicate(state), + keep=3, + overwrite=True, + prefix="ckpt-", + step=step, + ) + logging.rank_zero_info( + "Checkpoint saved to %s", + filepath, ) - logging.rank_zero_info( - "Checkpoint saved to %s", - filepath, - ) # logging on the end of epoch - logging.rank_zero_info("Epoch %d done.", epoch + 1) scalar_output = { - f"train/{k.replace('_', ' ')}_epoch": sum(v) / len(v) + f"train/{k}_epoch": sum(v) / len(v) for k, v in train_metrics.items() } writer.write_scalars( - step=epoch + 1, + step=step, scalars=scalar_output, ) - epoch += 1 except Exception as e: logging.rank_zero_error( "Exception occurred during training: %s", e ) + error_trace = traceback.format_exc() + logging.rank_zero_error(error_trace) _status = 1 finally: state = jax_utils.unreplicate(state) From 4357a025595b4b79130fd9387e58bf3cfd8c0664 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 14:53:45 -0500 Subject: [PATCH 18/67] feat: Updated the main logic for training step in MeanFlow Signed-off-by: Juanwu Lu --- src/projects/generative/main.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/projects/generative/main.py b/src/projects/generative/main.py index df6b0a6..4201a18 100644 --- a/src/projects/generative/main.py +++ b/src/projects/generative/main.py @@ -5,7 +5,6 @@ from absl import app from absl import flags -from clu import checkpoint from clu import metric_writers from clu import platform as clu_platform from fiddle import absl_flags @@ -144,10 +143,6 @@ def main(_: typing.List[str]) -> int: ) logging.rank_zero_info("Building train state... DONE!") - checkpoint_manager = checkpoint.MultihostCheckpoint( - os.path.join(log_dir, "checkpoints"), - max_to_keep=max(2, exp_config.trainer.max_checkpoints_to_keep), - ) if exp_config.trainer.checkpoint_dir is not None: logging.rank_zero_error("Resuming from checkpoint not implemented.") return 1 @@ -158,10 +153,10 @@ def main(_: typing.List[str]) -> int: state=state, datamodule=datamodule, num_train_steps=exp_config.trainer.num_train_steps, - checkpoint_manager=checkpoint_manager, writer=writer, work_dir=log_dir, rng=rng, + checkpoint_every_n_steps=exp_config.trainer.checkpoint_every_n_steps, log_every_n_steps=exp_config.trainer.log_every_n_steps, eval_every_n_steps=exp_config.trainer.eval_every_n_steps, profile=exp_config.trainer.profile, From b7d196790ae248386af120a6ab037ae7a5f3a235 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 17:25:57 -0500 Subject: [PATCH 19/67] feat: Updated the model protocol Signed-off-by: Juanwu Lu --- src/core/model.py | 73 ++++++++++++++++------------------------------- 1 file changed, 25 insertions(+), 48 deletions(-) diff --git a/src/core/model.py b/src/core/model.py index cf09dfa..53169ad 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -2,11 +2,10 @@ import typing import chex -from flax import struct +from flax.core import frozen_dict +import jax import jaxtyping -from src.core import train_state as _train_state - @chex.dataclass class StepOutputs: @@ -51,67 +50,45 @@ def init( pass @abc.abstractmethod - def training_step( + def compute_loss( self, *, - state: _train_state.TrainState, - batch: typing.Any, - rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]], + rngs: typing.Any, + deterministic: bool = False, + params: frozen_dict.FrozenDict, **kwargs, - ) -> typing.Tuple[struct.PyTreeNode, StepOutputs]: - r"""Performs a single training step. + ) -> typing.Tuple[jax.Array, StepOutputs]: + """Computes the loss given parameters and model inputs. Args: - state (TrainState): The current training state. - batch (Any): A batch of data. - rngs (Union[Any, Dict[str, Any]]): Random generators. - **kwargs: Additional keyword arguments. + deterministic (bool): Whether to run the model in deterministic + mode (e.g., disable dropout). Default is `False`. + params (FrozenDict): The model parameters. + **kwargs: Keyword arguments consumed by the model. Returns: - A tuple containing the updated state and step outputs. + A dictionary containing the loss and other outputs. """ - pass + raise NotImplementedError @abc.abstractmethod - def evaluation_step( + def forward( self, *, - params: jaxtyping.PyTree, - batch: typing.Any, - rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]], + rngs: typing.Any, + deterministic: bool = True, + params: frozen_dict.FrozenDict, **kwargs, ) -> StepOutputs: - r"""Performs a single evaluation step. - - Args: - params (PyTree): The model parameters. - batch (Any): A batch of data. - rngs (Union[Any, Dict[str, Any]]): Random generators. - **kwargs: Additional keyword arguments. - - Returns: - The step outputs containing evaluation metrics. - """ - pass - - @abc.abstractmethod - def predict_step( - self, - *, - params: jaxtyping.PyTree, - batch: typing.Any, - rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]], - **kwargs, - ) -> typing.Any: - r"""Performs a single prediction step during inference. + """Forward pass the model and returns the output tree structure. Args: - params (PyTree): The model parameters. - batch (Any): A batch of data. - rngs (Union[Any, Dict[str, Any]]): Random generators. - **kwargs: Additional keyword arguments. + deterministic (bool): Whether to run the model in deterministic + mode (e.g., disable dropout). Default is `True`. + params (FrozenDict): The model parameters. + **kwargs: Keyword arguments consumed by the model. Returns: - The model's predictions. + The model outputs. """ - pass + raise NotImplementedError From 60db887b60bb6f57ba6124d51b3563b1d950fbcb Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 17:26:18 -0500 Subject: [PATCH 20/67] feat: Updated the training logic Signed-off-by: Juanwu Lu --- src/core/train.py | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/src/core/train.py b/src/core/train.py index eb7f43d..969b9fc 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -16,22 +16,8 @@ from src.core import train_state as _train_state from src.utilities import logging - -def _create_step_fn( - model: _model.Model, - rng: typing.Any, -) -> typing.Tuple[jax.Array, typing.Callable, typing.Callable]: - """Creates the step functions for training and evaluation.""" - # create training step function - rng, train_rng = jax.random.split(rng, num=2) - p_training_step = functools.partial(model.training_step, rngs=train_rng) - p_training_step = jax.pmap(p_training_step, axis_name="batch") - - rng, eval_rng = jax.random.split(rng, num=2) - p_evaluation_step = functools.partial(model.evaluation_step, rngs=eval_rng) - p_evaluation_step = jax.pmap(p_evaluation_step, axis_name="batch") - - return rng, p_training_step, p_evaluation_step +EVAL_STEP_OUTPUT = _model.StepOutputs +TRAIN_STEP_OUTPUT = typing.Tuple[_train_state.TrainState, _model.StepOutputs] def _shard(tree: jaxtyping.PyTree) -> jaxtyping.PyTree: @@ -55,9 +41,10 @@ def _shard(tree: jaxtyping.PyTree) -> jaxtyping.PyTree: def run( - model: _model.Model, state: _train_state.TrainState, datamodule: _data.DataModule, + training_step: typing.Callable[..., TRAIN_STEP_OUTPUT], + evaluation_step: typing.Callable[..., EVAL_STEP_OUTPUT], num_train_steps: int, writer: metric_writers.MetricWriter, work_dir: str, @@ -70,9 +57,9 @@ def run( """Runs training and evaluation loop with given model and dataloaders. Args: - model (Model): The model to run. - train_dataloader (Any): The training dataloaders. - eval_dataloader (Any): The evaluation dataloaders. + datamodule (DataModule): The data module for loading data. + training_step (Callable): The training step function. + evaluation_step (Callable): The evaluation step function. num_train_steps (int): Number of training steps. checkpoint_manager (Checkpoint): The checkpoint manager. writer (MetricWriter): The metric writer for logging. @@ -88,14 +75,21 @@ def run( Integer status code. """ _status = 0 - logging.rank_zero_debug(f"running {model.__class__.__name__} fit stage...") if checkpoint_every_n_steps is None: checkpoint_every_n_steps = eval_every_n_steps - rng, p_training_step, p_evaluation_step = _create_step_fn( - model=model, - rng=rng, - ) + + logging.rank_zero_info("Compiling training step function...") + rng, train_rng = jax.random.split(rng, num=2) + p_training_step = functools.partial(training_step, rngs=train_rng) + p_training_step = jax.pmap(p_training_step, axis_name="batch") + logging.rank_zero_info("Compiling training step function... DONE!") + + logging.rank_zero_info("Compiling evaluation step function...") + rng, eval_rng = jax.random.split(rng, num=2) + p_evaluation_step = functools.partial(evaluation_step, rngs=eval_rng) + p_evaluation_step = jax.pmap(p_evaluation_step, axis_name="batch") + logging.rank_zero_info("Compiling evaluation step function... DONE!") hooks = [] report_progress = periodic_actions.ReportProgress( From c435f024fa9d89cd8aba6f9a5d311a6a9be16e1a Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 17:26:52 -0500 Subject: [PATCH 21/67] feat: Implemented the new model protocol for MeanFlow Signed-off-by: Juanwu Lu --- src/core/BUILD | 2 +- src/projects/generative/meanflow.py | 316 +++++++++++++--------------- 2 files changed, 149 insertions(+), 169 deletions(-) diff --git a/src/core/BUILD b/src/core/BUILD index 168245e..623a926 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -36,8 +36,8 @@ ml_py_library( deps = [ "chex", "flax", + "jax", "jaxtyping", - ":train_state", ], ) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index d2cc3c1..5013190 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -2,6 +2,7 @@ import chex from flax import linen as nn +from flax.core import frozen_dict import jax from jax import numpy as jnp from jax._src import typing as jax_typing @@ -9,7 +10,6 @@ import typing_extensions from src.core import model as _model -from src.core import train_state as _train_state from src.projects.generative.model import refinenet # Type Aliases @@ -377,14 +377,6 @@ class MeanFlowUNetModule(nn.Module): """int: Number of channels in the latent feature maps.""" num_classes: int """int: Number of conditioning classes.""" - timestamp_sampler: str - """str: The distribution to sample timestamps from.""" - timestamp_sampler_kwargs: typing.Dict[str, typing.Any] - """Dict[str, Any]: Keyword arguments for the timestamp sampler.""" - timestamp_overlap_rate: float - """float: The minimum overlap rate between begin and end timestamps.""" - adaptive_weight_power: typing.Optional[float] = None - """Optional[float]: The power for adaptive weight scaling.""" use_cfg_embedding: bool = False """bool: Whether to use classifier-free guidance (CFG) embedding.""" deterministic: typing.Optional[bool] = None @@ -443,7 +435,7 @@ def __call__( end: typing.Optional[jax.Array] = None, deterministic: typing.Optional[bool] = None, ) -> jax.Array: - """Forward pass the `MeanFlowUNetModel`. + r"""Forward pass the `MeanFlowUNetModel`. Args: inputs (jax.Array): Input images of shape `(*, H, W, C)`. @@ -486,102 +478,6 @@ def __call__( return output - def compute_loss( - self, - image: jax.Array, - label: jax.Array, - ) -> typing.Tuple[jax.Array, jax.Array]: - r"""Compute the `MeanFlow` loss. - - Args: - image (jax.Array): Input images of shape `(*, H, W, C)`. - label (jax.Array): Conditioning labels of shape `(*,)`. - - Returns: - The mean flow loss and velocity loss. - """ - batch_dims = image.shape[:-3] - - # step 1: randomly sample begin and end timestamps - t, r = sample_t_r( - key=self.make_rng("timestamp"), - shape=batch_dims, - dtype=image.dtype, - distribution=self.timestamp_sampler, - **self.timestamp_sampler_kwargs, - ) - t, r = jnp.maximum(t, r), jnp.minimum(t, r) - # ensure a portion of overlap between t and r - r_neq_t_mask = jnp.greater_equal( - jax.random.uniform( - key=self.make_rng("mask"), - shape=batch_dims, - dtype=image.dtype, - minval=0.0, - maxval=1.0, - ), - self.timestamp_overlap_rate, - ) - r = jnp.where(r_neq_t_mask, t, r) - - # sample noise e ~ N(0, I) - e = jax.random.normal( - key=self.make_rng("noise"), - shape=image.shape, - dtype=image.dtype, - ) - - # generate intermediate z(t) - z = jnp.add( - (1 - t[..., None, None, None]) * image, - t[..., None, None, None] * e, - ) - - # calculate velocity v - v = e - image - - def u_fn( - z_t: jax.Array, - r_val: jax.Array, - t_val: jax.Array, - ) -> typing.Any: - b_arg, e_arg = t_val - r_val, t_val - - return self( - image=z_t, - label=label, - begin=b_arg, - end=e_arg, - deterministic=False, - ) - - u, dudt = jax.jvp( - u_fn, - (z, r, t), - (v, jnp.zeros_like(r), jnp.ones_like(t)), - ) - u_target = jax.lax.stop_gradient( - v - - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] - * dudt - ) - - # step 3: compute the loss - loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) - if self.adaptive_weight_power is not None: - ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) - loss = loss / jax.lax.stop_gradient(ada_wt) - loss = jnp.mean(loss) - - velocity_loss = jnp.where( - jnp.equal(t, r)[..., None, None, None], - jnp.square(u - v), - jnp.zeros_like(u), - ) - velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() - - return loss, velocity_loss - class MeanFlowUNetModel(_model.Model): r"""`MeanFlow` generative model with a U-Net backbone. @@ -636,15 +532,15 @@ def __init__( self.in_channels = in_channels self.image_size = image_size self.timestamp_cond = timestamp_cond + self.timestamp_sampler = timestamp_sampler + self.timestamp_sampler_kwargs = timestamp_sampler_kwargs + self.timestamp_overlap_rate = timestamp_overlap_rate + self.adaptive_weight_power = adaptive_weight_power self._network = MeanFlowUNetModule( in_channels=in_channels, image_size=image_size, latent_channels=latent_channels, num_classes=num_classes, - timestamp_sampler=timestamp_sampler, - timestamp_sampler_kwargs=timestamp_sampler_kwargs, - timestamp_overlap_rate=timestamp_overlap_rate, - adaptive_weight_power=adaptive_weight_power, use_cfg_embedding=use_cfg_embedding, dropout_rate=dropout_rate, name="unet", @@ -692,76 +588,160 @@ def init( return variables["params"] @typing_extensions.override - def evaluation_step( + def compute_loss( self, *, - params: PyTree, - batch: typing.Any, rngs: typing.Any, + image: jax.Array, + label: jax.Array, + params: frozen_dict.FrozenDict, + deterministic: bool = False, **kwargs, - ) -> _model.StepOutputs: - del kwargs # unused - raise NotImplementedError("Evaluation not implemented yet.") + ) -> typing.Tuple[jax.Array, _model.StepOutputs]: + r"""Computes the loss given parameters and model inputs. - @typing_extensions.override - def predict_step( - self, - *, - params: jaxtyping.PyTree, - batch: typing.Any, - rngs: typing.Any, - **kwargs, - ) -> typing.Any: - # TODO (juanwulu): implement predict step - raise NotImplementedError("Predict step is not implemented yet.") + Args: + rngs (Union[jax.random.KeyArray, Dict[str, jax.random.KeyArray]]): + JAX random key or a dictionary of JAX random keys. + image (jax.Array): The input images of shape `(*, H, W, C)`. + label (jax.Array): The class labels of shape `(*,)`. + params (frozen_dict.FrozenDict): The model parameters. + deterministic (bool): Whether to run the model deterministically. + **kwargs: additional keyword arguments. - @typing_extensions.override - def training_step( - self, - *, - state: _train_state.TrainState, - batch: typing.Any, - rngs: typing.Any, - **kwargs, - ) -> typing.Tuple[_train_state.TrainState, _model.StepOutputs]: + Returns: + The computed loss and other outputs. + """ del kwargs # unused - local_rng = jax.random.fold_in(rngs, jax.process_index()) - local_rng = jax.random.fold_in(local_rng, state.step) + # NOTE: following the notation in Algorithm 1 of the source paper + # sample t and r + batch_dims = image.shape[:-3] + tr_rng, mask_rng, e_rng = jax.random.split(rngs, num=3) + t, r = sample_t_r( + key=tr_rng, + shape=batch_dims, + dtype=image.dtype, + distribution=self.timestamp_sampler, + **self.timestamp_sampler_kwargs, + ) + t, r = jnp.maximum(t, r), jnp.minimum(t, r) + # ensure a portion of overlap between t and r + r_neq_t_mask = jnp.greater_equal( + jax.random.uniform( + key=mask_rng, + shape=batch_dims, + dtype=image.dtype, + minval=0.0, + maxval=1.0, + ), + self.timestamp_overlap_rate, + ) + r = jnp.where(r_neq_t_mask, t, r) - tr_rng = jax.random.fold_in(local_rng, 0) - mask_rng = jax.random.fold_in(local_rng, 1) - e_rng = jax.random.fold_in(local_rng, 2) + # sample e ~ N(0, I) + e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) - image, label = batch["image"], batch["label"] + # generate z_{t} + z = jnp.add( + (1 - t[..., None, None, None]) * image, + t[..., None, None, None] * e, + ) + v = e - image - def _loss_fn(params: PyTree) -> typing.Tuple[jax.Array, jax.Array]: - loss, velocity_loss = self.network.apply( + # applies Jacobian vector product + def u_fn( + z_t: jax.Array, + r_in: jax.Array, + t_in: jax.Array, + ) -> jax.Array: + if self.timestamp_cond == "t_and_r": + b_arg, e_arg = r_in, t_in + elif self.timestamp_cond == "t_and_t_minus_r": + b_arg, e_arg = t_in - r_in, t_in + elif self.timestamp_cond == "t_and_r_and_t_minus_r": + raise NotImplementedError( + "`t_and_r_and_t_minus_r` conditioning is not implemented." + ) + elif self.timestamp_cond == "t_minus_r": + b_arg, e_arg = t_in - r_in, None + else: + raise ValueError( + f"Unsupported timestamp conditioning: {self.timestamp_cond}." + ) + + out = self.network.apply( variables={"params": params}, - image=image, + image=z_t, label=label, - rngs={ - "timestamp": tr_rng, - "mask": mask_rng, - "noise": e_rng, - }, - method=self.network.compute_loss, + begin=b_arg, + end=e_arg, + deterministic=deterministic, ) - assert isinstance(loss, jax.Array) - assert isinstance(velocity_loss, jax.Array) - - return loss, velocity_loss - - grad_fn = jax.value_and_grad(_loss_fn, argnums=0, has_aux=True) - (loss, velocity_loss), grads = grad_fn(state.params) - grads = jax.lax.pmean(grads, axis_name="batch") - loss = jax.lax.pmean(loss, axis_name="batch") - velocity_loss = jax.lax.pmean(velocity_loss, axis_name="batch") - new_state = state.apply_gradients(grads=grads) - - return ( - new_state, - _model.StepOutputs( - scalars={"loss": loss, "velocity_loss": velocity_loss}, - ), + assert isinstance(out, jax.Array) + + return out + + drdt = jnp.zeros_like(r) + dtdt = jnp.ones_like(t) + u, dudt = jax.jvp(u_fn, (z, r, t), (v, drdt, dtdt)) + + # computes the target + u_target = jax.lax.stop_gradient( + v + - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] + * dudt ) + # NOTE: sum over all the pixels, following official implementation + loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) + + # applies adaptive weight power + if self.adaptive_weight_power > 0.0: + ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) + loss = loss / jax.lax.stop_gradient(ada_wt) + loss = jnp.mean(loss) + + # calculate velocity loss for monitoring + velocity_loss = jnp.where( + jnp.equal(t, r)[..., None, None, None], + jnp.square(u - v), + jnp.zeros_like(u), + ) + velocity_loss = jnp.sum(velocity_loss, axis=(-1, -2, -3)).mean() + + out = _model.StepOutputs( + scalars={"loss": loss, "velocity_loss": velocity_loss}, + ) + + return loss, out + + @typing_extensions.override + def forward( + self, + *, + rngs: typing.Any, + params: frozen_dict.FrozenDict, + image: jax.Array, + label: jax.Array, + begin: typing.Optional[jax.Array] = None, + end: typing.Optional[jax.Array] = None, + deterministic: bool = False, + **kwargs, + ) -> _model.StepOutputs: + r"""Forward sampling with average velocity prediction. + + Args: + params (frozen_dict.FrozenDict): The model parameters. + image (jax.Array): Input latent image `z_t` of shape `(*, H, W, C)`. + label (jax.Array): Conditioning labels of shape `(*,)`. + begin (jax.Array): Begin timestamp `r` of shape `(*, )`. + end (jax.Array): End timestamp `t` of shape `(*, )`. + deterministic (bool): Whether to run the model deterministically. + **kwargs: Additional keyword arguments. + + Returns: + The output samples. + """ + del kwargs # unused + + raise NotImplementedError From 3b12acbaee845554c71535452a8e443a905aaae7 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 17:27:14 -0500 Subject: [PATCH 22/67] feat: Implemented the new main entrypoint with train logic Signed-off-by: Juanwu Lu --- src/projects/generative/main.py | 42 ++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/projects/generative/main.py b/src/projects/generative/main.py index 4201a18..123194f 100644 --- a/src/projects/generative/main.py +++ b/src/projects/generative/main.py @@ -1,4 +1,5 @@ from datetime import datetime +import functools import os import platform import typing @@ -10,11 +11,13 @@ from fiddle import absl_flags import fiddle as fdl import jax +import jaxtyping import optax import tensorflow as tf from src.core import config as _config from src.core import evaluate as _evaluate +from src.core import model as _model from src.core import train as _train from src.core import train_state as _train_state from src.utilities import logging @@ -32,6 +35,7 @@ help="Directory to store the experiment results.", required=True, ) +PyTree = jaxtyping.PyTree # toggle off GPU/TPU for TensorFlow @@ -40,6 +44,40 @@ assert not tf.config.experimental.get_visible_devices("GPU") +def evaluation_step(rngs: jax.Array) -> _model.StepOutputs: + r"""Conduct a single evaluation step and compute metrics.""" + raise NotImplementedError + + +def training_step( + rngs: jax.Array, + model: _model.Model, + state: _train_state.TrainState, + batch: typing.Dict[str, typing.Any], + **kwargs, +) -> typing.Tuple[_train_state.TrainState, _model.StepOutputs]: + r"""Conduct a single training step and update train state.""" + local_rng = jax.random.fold_in(rngs, state.step) + local_rng = jax.random.fold_in(local_rng, jax.lax.axis_index("batch")) + + def loss_fn(params: PyTree) -> typing.Tuple[jax.Array, _model.StepOutputs]: + loss, outputs = model.compute_loss( + rngs=local_rng, + params=params, + deterministic=False, + **batch, + **kwargs, + ) + return loss, outputs + + grad_fn = jax.value_and_grad(loss_fn, argnums=0, has_aux=True) + (_, outputs), grads = grad_fn(state.params) + grads = jax.lax.pmean(grads, axis_name="batch") + new_state = state.apply_gradients(grads=grads) + + return new_state, outputs + + def main(_: typing.List[str]) -> int: r"""Main entry point for training and evaluate generative models.""" del _ # unused. @@ -148,10 +186,12 @@ def main(_: typing.List[str]) -> int: return 1 if exp_config.mode == "train": + p_training_step = functools.partial(training_step, model=model) _train.run( - model=model, state=state, datamodule=datamodule, + training_step=p_training_step, + evaluation_step=evaluation_step, num_train_steps=exp_config.trainer.num_train_steps, writer=writer, work_dir=log_dir, From cbfaa3c1b1fc93ce6bf355e59bc1ffa90d59ee34 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 17:54:47 -0500 Subject: [PATCH 23/67] feat: Updated implementation for huggingface dataset Signed-off-by: Juanwu Lu --- src/data/huggingface.py | 462 ++++++++++++++++++++++------------------ 1 file changed, 258 insertions(+), 204 deletions(-) diff --git a/src/data/huggingface.py b/src/data/huggingface.py index a445412..bd89772 100644 --- a/src/data/huggingface.py +++ b/src/data/huggingface.py @@ -1,6 +1,8 @@ import abc import functools import os +import shutil +import tempfile import typing import datasets @@ -27,26 +29,18 @@ class HuggingFaceDataModule(datamodule.DataModule): - `hf_dataset`: the HuggingFace dataset object. - `feature_key`: the key in the dataset features to use as input. - `target_key`: the key in the dataset features to use as target. - - `output_signature`: a (nested) structure of `tf.TensorSpec` objects. - - `_create_dataset`: method to create a `tf.data.Dataset` from the + - `create_dataset`: method to create a `tf.data.Dataset` from the HuggingFace dataset object. - Attributes: - path (str): The path to the HuggingFace dataset. - revision (str): The revision of the dataset for version control. - Args: batch_size (int): The batch size for data loading. - deterministic (bool): Whether to enforce deterministic loading behavior. + deterministic (bool): Whether enforce deterministic loading behavior. drop_remainder (bool): Whether to drop the last incomplete batch. num_workers (int): Number of shards for distributed loading. - shuffle_buffer_size (int): Buffer size for shuffling the dataset. transform (Optional[Callable], optional): An optional function to - transform the input features. Default is `None`. - target_transform (Optional[Callable], optional): An optional function - to transform the target features. Default is `None`. - rng (jax.Array, optional): Random key for shuffling. - Default is `random.PRNGKey(42)`. + transform the features. Default is `None`. + shuffle_buffer_size (int): Buffer size for shuffling the dataset. + rng (Any): Random seed for shuffling. Default is `PRNGKey(42)`. """ def __init__( @@ -57,17 +51,15 @@ def __init__( num_workers: int, shuffle_buffer_size: int, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, - rng: jax.Array = random.PRNGKey(42), + rng: typing.Any = jax.random.PRNGKey(42), ) -> None: self._batch_size = batch_size self._deterministic = deterministic self._drop_remainder = drop_remainder self._num_workers = num_workers self._shuffle_buffer_size = shuffle_buffer_size - self._rng = rng + self._rng = jax.random.fold_in(rng, jax.process_index()) self._transform = transform - self._target_transform = target_transform # ========================================= # Interface @@ -91,28 +83,34 @@ def target_key(self) -> typing.Optional[str]: @property @abc.abstractmethod - def output_signature(self) -> typing.Any: - r"""Any: A (nested) structure of `tf.TensorSpec` objects.""" + def train_dataset(self) -> typing.Iterable: + r"""Iterable: The training dataset split.""" ... + @property @abc.abstractmethod - def _create_dataset( - self, - *, - split: str, - shuffle_seed: typing.Optional[int] = None, - ) -> tf.data.Dataset: - r"""Create an `tf.data.Dataset` from the HuggingFace dataset object. + def eval_dataset(self) -> typing.Iterable: + r"""Iterable: The validation dataset split.""" + ... - Args: - split (str): The dataset split to create. - shuffle_seed (Optional[int], optional): Seed for shuffling. - If `None`, no shuffling is applied. + @property + @abc.abstractmethod + def test_dataset(self) -> typing.Iterable: + r"""Iterable: The test dataset split.""" + ... + + @staticmethod + @abc.abstractmethod + def create_dataset(*args, **kwargs) -> tf.data.Dataset: + r"""Create sharded `tf.data.Dataset` from the HuggingFace dataset. + + The default method is suitable for processing image datasets with + `Pillow` images. Override this method for custom dataset processing. Returns: The created `tf.data.Dataset` instance. """ - pass + ... # ========================================= @property @@ -132,7 +130,7 @@ def drop_remainder(self) -> bool: @property def num_workers(self) -> int: - r"""int: Number of shards for distributed loading.""" + r"""int: Number of workers for distributed loading.""" return self._num_workers @property @@ -151,6 +149,11 @@ def num_test_examples(self) -> int: r"""int: Number of test examples.""" return len(self.hf_dataset["test"]) # type: ignore + @property + def rng(self) -> typing.Any: + r"""Any: Random seed for shuffling.""" + return self._rng + @property def shuffle_buffer_size(self) -> int: r"""int: Buffer size for shuffling the dataset.""" @@ -166,42 +169,14 @@ def transform(self) -> typing.Optional[typing.Callable]: r"""Optional[Callable]: Transformation for the input features.""" return self._transform - @property - def target_transform(self) -> typing.Optional[typing.Callable]: - r"""Optional[Callable]: Transformation for the target features.""" - return self._target_transform - - @property - def rng(self) -> jax.Array: - r"""jax.Array: Random key for shuffling.""" - return self._rng - - def train_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the training dataset.""" - self._rng, shuffle_rng = random.split(self._rng, num=2) - ds = self._create_dataset( - split="train", - shuffle_seed=int(shuffle_rng[0]), # type: ignore - ) - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="validation") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - - def test_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the test dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - class HuggingFaceImageDataModule(HuggingFaceDataModule): r"""Data module for HuggingFace image datasets. + Attributes: + path (str): The path to the HuggingFace dataset. + revision (str): The revision of the dataset for version control. + Args: batch_size (int): The batch size for data loading. deterministic (bool): Whether the dataloaders are deterministic. @@ -209,12 +184,12 @@ class HuggingFaceImageDataModule(HuggingFaceDataModule): num_workers (int): Number of shards for distributed loading. resize (int): The size to resize images to (square). resample (int): Resampling filter to use for resizing images. + shuffle_buffer_size (int): Buffer size for random shuffling. transform (Optional[Callable], optional): An optional function to - transform the input images. Defaults to `None`. - target_transform (Optional[Callable], optional): An optional function - to transform the target features. Defaults to `None`. - rng (jax.Array, optional): Random key for shuffling. - Default is `random.PRNGKey(42)`. + transform the input images. Default is `None`. + use_cache (bool, optional): Whether to use cached dataset. + Default is `True`. + rng (Any): Random seed for shuffling. Default is `PRNGKey(42)`. """ def __init__( @@ -227,12 +202,23 @@ def __init__( resample: int, shuffle_buffer_size: int, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, - rng: jax.Array = random.PRNGKey(42), + use_cache: bool = True, + rng: typing.Any = jax.random.PRNGKey(42), ) -> None: - r"""Instantiates a `HuggingFaceImageDataModule` object.""" self._resize = resize self._resample = resample + if use_cache: + cache_dir = os.path.join( + tempfile.gettempdir(), + "chimera", + "huggingface", + ) + if os.path.exists(cache_dir): + # NOTE: clear the cache directory to avoid corrupted cache + shutil.rmtree(cache_dir) + os.makedirs(cache_dir, exist_ok=True) + else: + cache_dir = None super().__init__( batch_size=batch_size, @@ -241,133 +227,236 @@ def __init__( num_workers=num_workers, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, rng=rng, ) + # prepare the dataset splits + pre_transform = functools.partial( + self.pre_transform, + feature_key=self.feature_key, + target_key=self.target_key, + center_crop=True, + resample=self._resample, + resize=self._resize, + ) + self._train_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["train"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=int(self._rng[0]), + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "train_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + self._test_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["test"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=None, + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "test_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + if "validation" in self.hf_dataset: + self._eval_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["validation"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=None, + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "val_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + elif "val" in self.hf_dataset: + self._eval_dataset = self.create_dataset( + batch_size=self.batch_size, + dataset=self.hf_dataset["val"] + .map(pre_transform, batched=False, num_proc=1) + .to_tf_dataset(batch_size=None, prefetch=False), + deterministic=self.deterministic, + drop_remainder=self.drop_remainder, + shuffle_buffer_size=self.shuffle_buffer_size, + shuffle_seed=None, + transform=self.transform, + cache_dir=( + os.path.join(cache_dir, "val_" + self.__class__.__name__) + if cache_dir is not None + else None + ), + ) + else: + # NOTE: otherwise, use test set as validation set by default + self._eval_dataset = self._test_dataset + @property - def image_shape(self) -> typing.Tuple[int, int, int]: - r"""Tuple[int, int, int]: The shape of the images.""" - return (self._resize, self._resize, 3) + def train_dataset(self) -> tf.data.Dataset: + r"""tf.data.Dataset: The training dataset split.""" + return self._train_dataset @property - def output_signature(self) -> typing.Dict[str, tf.TensorSpec]: - r"""Dict[str, tf.TensorSpec]: Tensor specifications.""" - return { - "image": tf.TensorSpec(shape=self.image_shape, dtype=tf.uint8), # type: ignore - "label": tf.TensorSpec(shape=(), dtype=tf.int64), # type: ignore - } + def eval_dataset(self) -> tf.data.Dataset: + r"""tf.data.Dataset: The validation dataset split.""" + return self._eval_dataset - def _create_dataset( - self, + @property + def test_dataset(self) -> tf.data.Dataset: + r"""tf.data.Dataset: The test dataset split.""" + return self._test_dataset + + @staticmethod + def create_dataset( *, - split: str, + batch_size: int, + deterministic: bool, + drop_remainder: bool, + dataset: tf.data.Dataset, + shuffle_buffer_size: int, shuffle_seed: typing.Optional[int] = None, + transform: typing.Optional[typing.Callable] = None, + cache_dir: typing.Optional[str] = None, ) -> tf.data.Dataset: - r"""Create an `tf.data.Dataset` from the HuggingFace dataset object. + r"""Create sharded `tf.data.Dataset` from the HuggingFace dataset. The default method is suitable for processing image datasets with `Pillow` images. Override this method for custom dataset processing. Args: - split (str): The dataset split to create. + batch_size (int): The batch size for data loading. + deterministic (bool): Whether to enforce deterministic loading. + drop_remainder (bool): Whether to drop the last incomplete batch. + dataset (tf.data.Dataset): The converted HuggingFace dataset. + shuffle_buffer_size (int): Buffer size for random shuffling. shuffle_seed (Optional[int], optional): Seed for shuffling. If `None`, no shuffling is applied. + transform (Optional[Callable], optional): An optional function to + transform the features. Default is `None`. + cache_dir (Optional[str], optional): Directory to cache the dataset. Returns: The created `tf.data.Dataset` instance. """ - _hf_dataset = self.hf_dataset[split] - - def __hf_generator() -> typing.Generator[typing.Any, None, None]: - r"""Default iterator over HuggingFace dataset.""" - for example in _hf_dataset: - image = example[self.feature_key] # type: ignore - target = ( - example[self.target_key] # type: ignore - if self.target_key - else None - ) - if not isinstance(image, Image.Image): - raise ValueError( - "Default iterator expects the image to be a " - f"`PIL.Image.Image` object, but got {type(image)}." - ) - image = image.convert("RGB") - - # resize the image - width, height = image.size - scale = self._resize / min(width, height) - new_width, new_height = int(width * scale), int(height * scale) - image = image.resize( - size=(new_width, new_height), - resample=self._resample, - ) - - # center crop - left = (new_width - self._resize) / 2 - top = (new_height - self._resize) / 2 - right = (new_width + self._resize) / 2 - bottom = (new_height + self._resize) / 2 - image = image.crop((left, top, right, bottom)) + if isinstance(transform, typing.Callable): + dataset = dataset.map( + map_func=transform, + deterministic=deterministic, + num_parallel_calls=tf.data.AUTOTUNE, + ) - yield {"image": image, "label": target} + if shuffle_seed is not None: + dataset = dataset.shuffle( + buffer_size=shuffle_buffer_size, + seed=shuffle_seed, + reshuffle_each_iteration=True, + ) - ds = tf.data.Dataset.from_generator( - __hf_generator, - output_signature=self.output_signature, + if cache_dir is not None: + dataset = dataset.cache(filename=cache_dir) + + dataset = dataset.batch( + batch_size=batch_size, + deterministic=deterministic, + drop_remainder=drop_remainder, + num_parallel_calls=tf.data.AUTOTUNE, ) - def __make_shard_dataset( - shard_index: int, - num_workers: int, - dataset: tf.data.Dataset, - local_seed: typing.Optional[int] = None, - ) -> tf.data.Dataset: - r"""Shards the input TensorFlow dataset for parallel loading.""" - local_ds = dataset.shard(num_shards=num_workers, index=shard_index) - if local_seed is not None: - local_ds = local_ds.shuffle( - buffer_size=self.shuffle_buffer_size, - seed=int(local_seed), # type: ignore - ) - if self.transform is not None: - local_ds = local_ds.map( - map_func=self.transform, - deterministic=self.deterministic, - num_parallel_calls=tf.data.AUTOTUNE, - ) - local_ds = local_ds.batch( - batch_size=self.batch_size, - deterministic=self.deterministic, - drop_remainder=self.drop_remainder, - num_parallel_calls=tf.data.AUTOTUNE, + return dataset.prefetch(buffer_size=tf.data.AUTOTUNE) + + @staticmethod + def pre_transform( + example: typing.Dict[str, typing.Any], + feature_key: str, + target_key: typing.Optional[str], + center_crop: bool = True, + resample: typing.Optional[int] = None, + resize: typing.Optional[int] = None, + ) -> typing.Dict[str, typing.Any]: + r"""Pre-transformation function for input images. + + Args: + example (Dict[str, Any]): A dictionary of data from the dataset. + feature_key (str): The name of the input features to use. + target_key (Optional[str]): The name of the target features to use. + center_crop (bool, optional): Whether to apply center cropping + after resizing. Default is `True`. + resample (Optional[int], optional): The resampling filter to use for + resizing images. If `None`, use `PIL.Image.NEAREST`. + resize (Optional[int], optional): The size to resize images to + (square). If `None`, no resizing is applied. Default is `None`. + + Returns: + A dictionary with processed images and targets. + """ + image = example[feature_key] + target = example[target_key] if target_key is not None else None + if not isinstance(image, Image.Image): + raise ValueError( + "Default pre-transformation expects the image to be a " + f"`PIL.Image.Image` object, but got {type(image)}." ) - return local_ds + image = image.convert("RGB") - if shuffle_seed is not None: - local_seed = random.fold_in( - random.PRNGKey(shuffle_seed), - jax.process_index(), - )[0] - local_seed = int(local_seed) # type: ignore + # resize the image + if resize is not None: + width, height = image.size + scale = resize / min(width, height) + new_width, new_height = int(width * scale), int(height * scale) + image = image.resize( + size=(new_width, new_height), + resample=resample, + ) + + # center crop + if center_crop: + left = (new_width - resize) / 2 + top = (new_height - resize) / 2 + right = (new_width + resize) / 2 + bottom = (new_height + resize) / 2 + image = image.crop((left, top, right, bottom)) + + if target_key is None: + return {feature_key: image} else: - local_seed = None - - indices = tf.data.Dataset.range(self.num_workers) - out = indices.interleave( - map_func=functools.partial( - __make_shard_dataset, - num_workers=self.num_workers, - dataset=ds, - local_seed=local_seed, - ), - deterministic=self.deterministic, - num_parallel_calls=tf.data.AUTOTUNE, - ) + return { + feature_key: image, + target_key: target, + } + + def train_dataloader(self) -> typing.Generator[PyTree, None, None]: + r"""Generator[PyTree]: Returns an iterable over the training data.""" + for data in self.train_dataset.as_numpy_iterator(): + yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - return out.prefetch(buffer_size=tf.data.AUTOTUNE) + def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: + r"""Generator[PyTree]: Returns an iterable over the validation data.""" + for data in self.eval_dataset.as_numpy_iterator(): + yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) + + def test_dataloader(self) -> typing.Generator[PyTree, None, None]: + r"""Generator[PyTree]: Returns an iterable over the test data.""" + for data in self.test_dataset.as_numpy_iterator(): + yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) # ============================================================================== @@ -416,7 +505,6 @@ def __init__( shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( @@ -434,7 +522,6 @@ def __init__( resample=resample, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, rng=rng, ) @@ -460,13 +547,6 @@ def num_val_examples(self) -> int: # NOTE: using test set as validation set by default return len(self.hf_dataset["test"]) # type: ignore - @typing_extensions.override - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - class CIFAR100DataModule(HuggingFaceImageDataModule): r"""CIFAR-100 Image Classification Dataset. @@ -510,7 +590,6 @@ def __init__( shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( @@ -528,7 +607,6 @@ def __init__( resample=resample, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, rng=rng, ) @@ -554,13 +632,6 @@ def num_val_examples(self) -> int: # NOTE: using test set as validation set by default return len(self.hf_dataset["test"]) # type: ignore - @typing_extensions.override - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - class ImageNet1KDataModule(HuggingFaceImageDataModule): r"""ILSVRC2012 image dataset subset with :math:`1,000` classes. @@ -602,7 +673,6 @@ def __init__( shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( @@ -620,7 +690,6 @@ def __init__( resample=resample, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, rng=rng, ) @@ -679,7 +748,6 @@ def __init__( shuffle_buffer_size: int = 10_000, streaming: bool = False, transform: typing.Optional[typing.Callable] = None, - target_transform: typing.Optional[typing.Callable] = None, rng: jax.Array = random.PRNGKey(42), ) -> None: self._hf_dataset = datasets.load_dataset( @@ -697,7 +765,6 @@ def __init__( resample=resample, shuffle_buffer_size=shuffle_buffer_size, transform=transform, - target_transform=target_transform, rng=rng, ) @@ -716,12 +783,6 @@ def target_key(self) -> str: r"""str: The key in the dataset features to use as target.""" return "label" - @property - @typing_extensions.override - def image_shape(self) -> typing.Tuple[int, int, int]: - r"""Tuple[int, int, int]: The shape of the images.""" - return (self._resize, self._resize, 1) - @property @typing_extensions.override def num_val_examples(self) -> int: @@ -729,13 +790,6 @@ def num_val_examples(self) -> int: # NOTE: using test set as validation set by default return len(self.hf_dataset["test"]) # type: ignore - @typing_extensions.override - def eval_dataloader(self) -> typing.Generator[PyTree, None, None]: - r"""Returns an iterable over the validation dataset.""" - ds = self._create_dataset(split="test") - for data in ds.as_numpy_iterator(): - yield jax.tree_util.tree_map(lambda x: jnp.asarray(x), data) - __all__ = [ "HuggingFaceDataModule", From d368e52ffb275810b9780606cdf8094fc1315a9a Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Wed, 19 Nov 2025 17:57:21 -0500 Subject: [PATCH 24/67] hotfix: Fixed error in huggingface datamodule Signed-off-by: Juanwu Lu --- src/data/huggingface.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/data/huggingface.py b/src/data/huggingface.py index bd89772..9992f00 100644 --- a/src/data/huggingface.py +++ b/src/data/huggingface.py @@ -436,12 +436,9 @@ def pre_transform( image = image.crop((left, top, right, bottom)) if target_key is None: - return {feature_key: image} + return {"image": image} else: - return { - feature_key: image, - target_key: target, - } + return {"image": image, "label": target} def train_dataloader(self) -> typing.Generator[PyTree, None, None]: r"""Generator[PyTree]: Returns an iterable over the training data.""" From d91518fa7b8fef70777739764791664815f547a2 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Thu, 20 Nov 2025 05:24:05 -0500 Subject: [PATCH 25/67] hotfix: Updated checkpoint frequency Signed-off-by: Juanwu Lu --- src/projects/generative/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index df688f5..79651f7 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -53,6 +53,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: trainer=_config.TrainerConfig( num_train_steps=800_000, log_every_n_steps=5, + checkpoint_every_n_steps=10_000, # save every 10k steps eval_every_n_steps=1_000_000, # NOTE: never evaluate now max_checkpoints_to_keep=3, profile=False, From 07ab0aa13b47c3454f7ba25b64509375af52bbba Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Fri, 21 Nov 2025 11:34:28 -0500 Subject: [PATCH 26/67] feat: Added visualization utility to create a grid of images Signed-off-by: Juanwu Lu --- src/utilities/BUILD | 8 ++++++ src/utilities/visualization.py | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 src/utilities/visualization.py diff --git a/src/utilities/BUILD b/src/utilities/BUILD index 41825a8..fb1fc18 100644 --- a/src/utilities/BUILD +++ b/src/utilities/BUILD @@ -18,3 +18,11 @@ ml_py_library( "jax", ], ) + +ml_py_library( + name = "visualization", + srcs = ["visualization.py"], + deps = [ + "numpy", + ], +) diff --git a/src/utilities/visualization.py b/src/utilities/visualization.py new file mode 100644 index 0000000..e810556 --- /dev/null +++ b/src/utilities/visualization.py @@ -0,0 +1,45 @@ +import typing + +import numpy as np +import numpy.typing as npt + + +def make_grid( + images: typing.Union[typing.Any, npt.NDArray], + n_rows: int = 8, + padding: int = 2, +) -> npt.NDArray: + r"""Convert a batch of images into a grid for visualization. + + Args: + images (Any | NDArray): Batch of images with shape `(B, H, W, C)`. + n_rows (int, optional): Number of rows in grid. Default is :math:`8`. + padding (int, optional): Number of pixels between pair of images. + Default is :math:`2`. + + Returns: + The array containing a grid of input images. + """ + images = np.asarray(images) + if images.ndim != 4: + raise ValueError( + "Input images must be a 4-D numpy array with shape (B, H, W, C). " + f"But got {images}." + ) + bz, h, w, c = images.shape + n_cols = int(np.ceil(bz / n_rows)) + shape = ( + h * n_rows + padding * (n_rows - 1), + w * n_cols + padding * (n_cols - 1), + c, + ) + out = np.zeros(shape, dtype=images.dtype) + + for idx, img in enumerate(images): + row = idx // n_cols + col = idx % n_cols + top = row * (h + padding) + left = col * (w + padding) + out[top : top + h, left : left + w] = img + + return out From f462f553a09bcee175f1aa1f55e314adc0db9312 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Fri, 21 Nov 2025 14:58:04 -0500 Subject: [PATCH 27/67] hotfix: Fixed wrong implementation of t\neq{r} in meanflow Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 5013190..f82662c 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -637,7 +637,7 @@ def compute_loss( ), self.timestamp_overlap_rate, ) - r = jnp.where(r_neq_t_mask, t, r) + r = jnp.where(r_neq_t_mask, r, t) # sample e ~ N(0, I) e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) From 3649da173c1f70308a5104a5c9095a300e00e2f5 Mon Sep 17 00:00:00 2001 From: jiaru Date: Fri, 21 Nov 2025 16:42:40 -0500 Subject: [PATCH 28/67] feat: symmetric mean flow --- .gitignore | 2 ++ src/projects/generative/meanflow.py | 23 ++++++++++++++++++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 36293a4..709699b 100644 --- a/.gitignore +++ b/.gitignore @@ -212,3 +212,5 @@ cython_debug/ /data/ /logs/ requirements_*.txt + +.specstory/ diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index f82662c..155a379 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -682,16 +682,29 @@ def u_fn( return out - drdt = jnp.zeros_like(r) - dtdt = jnp.ones_like(t) - u, dudt = jax.jvp(u_fn, (z, r, t), (v, drdt, dtdt)) - - # computes the target + # NOTE: following the original meanflow + # drdt = jnp.zeros_like(r) + # dtdt = jnp.ones_like(t) + # u, dudt = jax.jvp(u_fn, (z, r, t), (v, drdt, dtdt)) + # u_target = jax.lax.stop_gradient( + # v + # - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] + # * dudt + # ) + + # NOTE: following the symmetric meanflow + drdt = jnp.ones_like(r) + dtdt = - jnp.ones_like(t) + u, dudt = jax.jvp(u_fn, (z, r, t), (-v, drdt, dtdt)) u_target = jax.lax.stop_gradient( v - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] * dudt + * 0.5 ) + + # computes the target + # NOTE: sum over all the pixels, following official implementation loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) From 4bad08216f4eb90cb4afd8f4b16356e6b6551dfe Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Fri, 21 Nov 2025 17:24:26 -0500 Subject: [PATCH 29/67] hotfix: Switch back to original meanflow loss Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 155a379..1de1da9 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -627,6 +627,7 @@ def compute_loss( ) t, r = jnp.maximum(t, r), jnp.minimum(t, r) # ensure a portion of overlap between t and r + # NOTE: the following code randomly mask by uniform samples r_neq_t_mask = jnp.greater_equal( jax.random.uniform( key=mask_rng, @@ -683,28 +684,27 @@ def u_fn( return out # NOTE: following the original meanflow - # drdt = jnp.zeros_like(r) - # dtdt = jnp.ones_like(t) - # u, dudt = jax.jvp(u_fn, (z, r, t), (v, drdt, dtdt)) - # u_target = jax.lax.stop_gradient( - # v - # - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] - # * dudt - # ) - - # NOTE: following the symmetric meanflow - drdt = jnp.ones_like(r) - dtdt = - jnp.ones_like(t) - u, dudt = jax.jvp(u_fn, (z, r, t), (-v, drdt, dtdt)) + drdt = jnp.zeros_like(r) + dtdt = jnp.ones_like(t) + u, dudt = jax.jvp(u_fn, (z, r, t), (v, drdt, dtdt)) u_target = jax.lax.stop_gradient( v - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] * dudt - * 0.5 ) + # NOTE: following the symmetric meanflow + # drdt = jnp.ones_like(r) + # dtdt = jnp.negative(jnp.ones_like(t)) + # u, dudt = jax.jvp(u_fn, (z, r, t), (-v, drdt, dtdt)) + # u_target = jax.lax.stop_gradient( + # v + # - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] + # * dudt + # * 0.5 + # ) + # computes the target - # NOTE: sum over all the pixels, following official implementation loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) From 63b4cdfff6556b7f9baa54b4847c799f232c5c87 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 03:57:25 -0600 Subject: [PATCH 30/67] feat: Added dependencies for running on MPS framework Signed-off-by: Juanwu --- third_party/BUILD | 19 +++++++++++++++++++ third_party/defs.bzl | 2 ++ third_party/requirements_mps.in | 1 + 3 files changed, 22 insertions(+) create mode 100644 third_party/requirements_mps.in diff --git a/third_party/BUILD b/third_party/BUILD index f9e5f7c..bdd1155 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -18,6 +18,11 @@ config_setting( define_values = {"ml_platform": "tpu"}, ) +config_setting( + name = "is_mps", + define_values = {"ml_platform": "mps"}, +) + # compile pip requirements into a lock file compile_pip_requirements( name = "requirements_3_10_cpu", @@ -57,3 +62,17 @@ compile_pip_requirements( ], requirements_txt = "requirements_3_10_tpu_lock.txt", ) + +compile_pip_requirements( + name = "requirements_3_10_mps", + timeout = "moderate", + srcs = [ + "requirements.in", + "requirements_mps.in", + ], + extra_args = [ + "--allow-unsafe", + "--resolver=backtracking", + ], + requirements_txt = "requirements_3_10_mps_lock.txt", +) diff --git a/third_party/defs.bzl b/third_party/defs.bzl index 269135e..387a995 100644 --- a/third_party/defs.bzl +++ b/third_party/defs.bzl @@ -2,6 +2,7 @@ load("@ml_infra_cpu_3_10//:requirements.bzl", cpu_req = "requirement") load("@ml_infra_cuda_3_10//:requirements.bzl", cuda_req = "requirement") +load("@ml_infra_mps_3_10//:requirements.bzl", mps_req = "requirement") load("@ml_infra_tpu_3_10//:requirements.bzl", tpu_req = "requirement") load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") @@ -17,6 +18,7 @@ def _select_requirement(name): return select({ "//third_party:is_cpu": [cpu_req(name)], "//third_party:is_cuda": [cuda_req(name)], + "//third_party:is_mps": [mps_req(name)], "//third_party:is_tpu": [tpu_req(name)], }) diff --git a/third_party/requirements_mps.in b/third_party/requirements_mps.in new file mode 100644 index 0000000..872ef87 --- /dev/null +++ b/third_party/requirements_mps.in @@ -0,0 +1 @@ +jax-metal==0.1.1 From 222a88408012077e8b1c7b8c4c2b041018a73fc2 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 03:58:00 -0600 Subject: [PATCH 31/67] feat: Added MPS dependencies to PIP hubs Signed-off-by: Juanwu --- MODULE.bazel | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/MODULE.bazel b/MODULE.bazel index 9a5854a..7345325 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -61,9 +61,15 @@ pip.parse( python_version = "3.10", requirements_lock = "//third_party:requirements_3_10_tpu_lock.txt", ) +pip.parse( + hub_name = "ml_infra_mps_3_10", + python_version = "3.10", + requirements_lock = "//third_party:requirements_3_10_mps_lock.txt", +) use_repo( pip, ml_infra_cpu_3_10 = "ml_infra_cpu_3_10", ml_infra_cuda_3_10 = "ml_infra_cuda_3_10", + ml_infra_mps_3_10 = "ml_infra_mps_3_10", ml_infra_tpu_3_10 = "ml_infra_tpu_3_10", ) From df11ccba639af3754b1e5404176665e3532cceb5 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 03:58:42 -0600 Subject: [PATCH 32/67] feat: Added implementation for downsampling residual block in U-Net Signed-off-by: Juanwu --- src/projects/generative/model/BUILD | 19 +++ src/projects/generative/model/test_unet.py | 39 ++++++ src/projects/generative/model/unet.py | 154 +++++++++++++++++++++ 3 files changed, 212 insertions(+) create mode 100644 src/projects/generative/model/test_unet.py create mode 100644 src/projects/generative/model/unet.py diff --git a/src/projects/generative/model/BUILD b/src/projects/generative/model/BUILD index 7b747b5..7b57aef 100644 --- a/src/projects/generative/model/BUILD +++ b/src/projects/generative/model/BUILD @@ -22,3 +22,22 @@ ml_py_test( ":refinenet", ], ) + +ml_py_library( + name = "unet", + srcs = ["unet.py"], + deps = [ + "chex", + "flax", + "jax", + ], +) + +ml_py_test( + name = "test_unet", + srcs = ["test_unet.py"], + deps = [ + "jax", + ":unet", + ], +) diff --git a/src/projects/generative/model/test_unet.py b/src/projects/generative/model/test_unet.py new file mode 100644 index 0000000..5a29d0e --- /dev/null +++ b/src/projects/generative/model/test_unet.py @@ -0,0 +1,39 @@ +import sys +import typing + +import jax +from jax import numpy as jnp +import pytest + +from src.projects.generative.model import unet + + +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_down_block(dtype: typing.Any) -> None: + r"""Tests the residual downsampling block in U-Net models.""" + rng = jax.random.PRNGKey(42) + + block = unet.DownResNetBlock(features=64, dtype=dtype, param_dtype=dtype) + test_input = jnp.ones((2, 32, 32, 32), dtype=dtype) + test_cond = jnp.ones((2, 16), dtype=dtype) + params_rng, dropout_rng = jax.random.split(rng, num=2) + variables = block.init( + rngs={"params": params_rng}, + inputs=test_input, + cond=test_cond, + deterministic=False, + ) + outputs = block.apply( + variables=variables, + inputs=test_input, + cond=test_cond, + deterministic=False, + rngs={"dropout": dropout_rng}, + ) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 32, 32, 64) + assert outputs.dtype == dtype + + +if __name__ == "__main__": + sys.exit(pytest.main(["-xv", __file__])) diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py new file mode 100644 index 0000000..e8025cc --- /dev/null +++ b/src/projects/generative/model/unet.py @@ -0,0 +1,154 @@ +import typing + +import chex +from flax import linen as nn +import jax + + +class DownResNetBlock(nn.Module): + r"""A residual downsampling block with two convolutional layers. + + Args: + features (int): Dimensionality of the latent feaatures. + num_groups (int, optional): Number of groups for `GroupNorm`. + Default is :math:`32`. + epsilon (float, optional): Small float added to variance to avoid + dividing by zero in `GroupNorm`. Default is :math:`1e-5`. + deterministic (bool, optional): If true, the model is run in + deterministic mode (e.g., no dropout). Defaults to `None`. + dropout_rate (float, optional): Dropout rate. Default is :math:`0`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + precision (Any, optional): Numerical precision of the computation. + """ + + features: int + num_groups: int = 32 + epsilon: float = 1e-5 + deterministic: typing.Optional[bool] = None + dropout_rate: float = 0.0 + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + def setup(self) -> None: + r"""Instantiates a `ResNetBlock` instance.""" + self.norm_1 = nn.GroupNorm( + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm0", + ) + self.conv_1 = nn.Conv( + features=self.features, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv0", + ) + self.cond_linear = nn.Dense( + features=self.features, + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="cond_in", + ) + self.dropout = nn.Dropout(rate=self.dropout_rate, name="dropout") + + self.norm_2 = nn.GroupNorm( + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm1", + ) + self.conv_2 = nn.Conv( + features=self.features, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv1", + ) + self.conv_shortcut = nn.Dense( + features=self.features, + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv_shortcut", + ) + + def __call__( + self, + inputs: jax.Array, + cond: typing.Optional[jax.Array] = None, + deterministic: typing.Optional[bool] = None, + ) -> jax.Array: + r"""Forward pass of the `ResNetBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C_in)`. + cond (Optional[jax.Array], optional): Optional conditioning array + of shape `(*, C_cond)`. + deterministic (bool, optional): If true, the model is run in + deterministic mode (e.g., no dropout). Defaults to `None`. + + Returns: + Output array of shape `(*, H, W, C_out)`, where `C_out` is the + `features` specified during instantiation. + """ + m_deterministic = nn.merge_param( + "deterministic", + self.deterministic, + deterministic, + ) + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + W=inputs.shape[-2], + C=inputs.shape[-1], + ) + + out = self.conv_1(nn.silu(self.norm_1(inputs))) + + if cond is not None: + chex.assert_shape(cond, (*batch_dims, cond.shape[-1])) + out = out + self.cond_linear(cond)[..., None, None, :] + out = nn.silu(self.norm_2(out)) + out = self.dropout(out, deterministic=m_deterministic) + out = self.conv_2(out) + + if inputs.shape[-1] != self.features: + shortcut = self.conv_shortcut(inputs) + else: + shortcut = inputs + out = out + shortcut + chex.assert_shape(out, (*batch_dims, *dims["HW"], self.features)) + + return out From e75bc79e3340b70815172cfa78f7a05e8cd5e1bd Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 04:52:18 -0600 Subject: [PATCH 33/67] hotfix: Rename `DownResNetBlock` to `ResNetBlock` Signed-off-by: Juanwu --- src/projects/generative/model/test_unet.py | 4 ++-- src/projects/generative/model/unet.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/projects/generative/model/test_unet.py b/src/projects/generative/model/test_unet.py index 5a29d0e..56ee062 100644 --- a/src/projects/generative/model/test_unet.py +++ b/src/projects/generative/model/test_unet.py @@ -9,11 +9,11 @@ @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) -def test_down_block(dtype: typing.Any) -> None: +def test_resnet_block(dtype: typing.Any) -> None: r"""Tests the residual downsampling block in U-Net models.""" rng = jax.random.PRNGKey(42) - block = unet.DownResNetBlock(features=64, dtype=dtype, param_dtype=dtype) + block = unet.ResNetBlock(features=64, dtype=dtype, param_dtype=dtype) test_input = jnp.ones((2, 32, 32, 32), dtype=dtype) test_cond = jnp.ones((2, 16), dtype=dtype) params_rng, dropout_rng = jax.random.split(rng, num=2) diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index e8025cc..ea74835 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -5,7 +5,7 @@ import jax -class DownResNetBlock(nn.Module): +class ResNetBlock(nn.Module): r"""A residual downsampling block with two convolutional layers. Args: From 9fec7d4eec460702685e35a955123fcf8825da65 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 05:02:59 -0600 Subject: [PATCH 34/67] feat: Added implementation for downsampling block in U-Net Signed-off-by: Juanwu --- src/projects/generative/model/test_unet.py | 34 +++++++++++++ src/projects/generative/model/unet.py | 56 ++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/src/projects/generative/model/test_unet.py b/src/projects/generative/model/test_unet.py index 56ee062..53366ac 100644 --- a/src/projects/generative/model/test_unet.py +++ b/src/projects/generative/model/test_unet.py @@ -35,5 +35,39 @@ def test_resnet_block(dtype: typing.Any) -> None: assert outputs.dtype == dtype +@pytest.mark.parametrize("with_conv", [True, False]) +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_downsample_block(with_conv: bool, dtype: typing.Any) -> None: + r"""Tests the downsampling block in U-Net models.""" + rng = jax.random.PRNGKey(42) + + block = unet.DownsampleBlock( + with_conv=with_conv, + dtype=dtype, + param_dtype=dtype, + ) + test_input = jnp.ones((2, 32, 32, 32), dtype=dtype) + variables = block.init( + rngs={"params": rng}, + inputs=test_input, + ) + if with_conv: + assert "conv0" in variables["params"] + kernel = variables["params"]["conv0"]["kernel"] + assert isinstance(kernel, jax.Array) + assert kernel.shape == (3, 3, 32, 32) + bias = variables["params"]["conv0"]["bias"] + assert isinstance(bias, jax.Array) + assert bias.shape == (32,) + + outputs = block.apply( + variables=variables, + inputs=test_input, + ) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 16, 16, 32) + assert outputs.dtype == dtype + + if __name__ == "__main__": sys.exit(pytest.main(["-xv", __file__])) diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index ea74835..bb52bdd 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -152,3 +152,59 @@ def __call__( chex.assert_shape(out, (*batch_dims, *dims["HW"], self.features)) return out + + +class DownsampleBlock(nn.Module): + r"""A downsampling block using averaging pooling or strided convolution. + + Args: + with_conv (bool, optional): If true, uses a strided convolution for + downsampling. If `False`, uses average pooling. Default is `True`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + """ + + with_conv: bool = True + dtype: typing.Any = None + param_dtype: typing.Any = None + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass of the `DownsampleBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C)`. + + Returns: + Output array of shape `(*, H / 2, W / 2, C)`. + """ + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + h=inputs.shape[-3] // 2, + W=inputs.shape[-2], + w=inputs.shape[-2] // 2, + C=inputs.shape[-1], + ) + + if self.with_conv: + out = nn.Conv( + features=inputs.shape[-1], + kernel_size=(3, 3), + strides=(2, 2), + padding=((0, 1), (0, 1)), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv0", + )(inputs) + else: + out = nn.avg_pool(inputs, window_shape=(2, 2), strides=(2, 2)) + chex.assert_shape(out, (*batch_dims, *dims["hwC"])) + + return out From 147427d549c35804c15f63b7d9b141092920f6b2 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 05:11:15 -0600 Subject: [PATCH 35/67] feat: Added implementation for upsampling block in U-Net Signed-off-by: Juanwu --- src/projects/generative/model/test_unet.py | 34 +++++++++- src/projects/generative/model/unet.py | 73 ++++++++++++++++++++-- 2 files changed, 99 insertions(+), 8 deletions(-) diff --git a/src/projects/generative/model/test_unet.py b/src/projects/generative/model/test_unet.py index 53366ac..f5057c6 100644 --- a/src/projects/generative/model/test_unet.py +++ b/src/projects/generative/model/test_unet.py @@ -60,12 +60,40 @@ def test_downsample_block(with_conv: bool, dtype: typing.Any) -> None: assert isinstance(bias, jax.Array) assert bias.shape == (32,) - outputs = block.apply( - variables=variables, + outputs = block.apply(variables=variables, inputs=test_input) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 16, 16, 32) + assert outputs.dtype == dtype + + +@pytest.mark.parametrize("with_conv", [True, False]) +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_upsample_block(with_conv: bool, dtype: typing.Any) -> None: + r"""Tests the upsampling block in U-Net models.""" + rng = jax.random.PRNGKey(42) + + block = unet.UpsampleBlock( + with_conv=with_conv, + dtype=dtype, + param_dtype=dtype, + ) + test_input = jnp.ones((2, 16, 16, 32), dtype=dtype) + variables = block.init( + rngs={"params": rng}, inputs=test_input, ) + if with_conv: + assert "conv0" in variables["params"] + kernel = variables["params"]["conv0"]["kernel"] + assert isinstance(kernel, jax.Array) + assert kernel.shape == (3, 3, 32, 32) + bias = variables["params"]["conv0"]["bias"] + assert isinstance(bias, jax.Array) + assert bias.shape == (32,) + + outputs = block.apply(variables=variables, inputs=test_input) assert isinstance(outputs, jax.Array) - assert outputs.shape == (2, 16, 16, 32) + assert outputs.shape == (2, 32, 32, 32) assert outputs.dtype == dtype diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index bb52bdd..c8362a4 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -47,7 +47,7 @@ def setup(self) -> None: padding=(1, 1), kernel_init=jax.nn.initializers.variance_scaling( scale=1.0, - mode="fan_in", + mode="fan_avg", distribution="uniform", ), bias_init=jax.nn.initializers.zeros, @@ -59,7 +59,7 @@ def setup(self) -> None: features=self.features, kernel_init=jax.nn.initializers.variance_scaling( scale=1.0, - mode="fan_in", + mode="fan_avg", distribution="uniform", ), bias_init=jax.nn.initializers.zeros, @@ -83,7 +83,7 @@ def setup(self) -> None: padding=(1, 1), kernel_init=jax.nn.initializers.variance_scaling( scale=1.0, - mode="fan_in", + mode="fan_avg", distribution="uniform", ), bias_init=jax.nn.initializers.zeros, @@ -95,7 +95,7 @@ def setup(self) -> None: features=self.features, kernel_init=jax.nn.initializers.variance_scaling( scale=1.0, - mode="fan_in", + mode="fan_avg", distribution="uniform", ), bias_init=jax.nn.initializers.zeros, @@ -195,7 +195,7 @@ def __call__(self, inputs: jax.Array) -> jax.Array: padding=((0, 1), (0, 1)), kernel_init=jax.nn.initializers.variance_scaling( scale=1.0, - mode="fan_in", + mode="fan_avg", distribution="uniform", ), bias_init=jax.nn.initializers.zeros, @@ -208,3 +208,66 @@ def __call__(self, inputs: jax.Array) -> jax.Array: chex.assert_shape(out, (*batch_dims, *dims["hwC"])) return out + + +class UpsampleBlock(nn.Module): + r"""An upsampling block using nearest-neighbor interpolation. + + Args: + with_conv (bool, optional): If true, applies a convolution after + upsampling. Default is `True`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + precision (Any, optional): Numerical precision of the computation. + """ + + with_conv: bool = True + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass of the `UpsampleBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C)` + + Returns: + Output array of shape `(*, H * 2, W * 2, C)`. + """ + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + h=inputs.shape[-3] * 2, + W=inputs.shape[-2], + w=inputs.shape[-2] * 2, + C=inputs.shape[-1], + ) + + out = jax.image.resize( + inputs, + shape=(*batch_dims, *dims["hwC"]), + method="nearest", + antialias=True, + precision=self.precision, + ) + if self.with_conv: + out = nn.Conv( + features=inputs.shape[-1], + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="conv0", + )(out) + + chex.assert_shape(out, (*batch_dims, *dims["hwC"])) + return out From 6ce9152cf5e0fc236e717fe6a9358be88725a33e Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 05:55:53 -0600 Subject: [PATCH 36/67] feat: Added full implementation of U-Net for score-based generative models Signed-off-by: Juanwu --- src/projects/generative/model/test_unet.py | 33 ++++ src/projects/generative/model/unet.py | 204 +++++++++++++++++++++ 2 files changed, 237 insertions(+) diff --git a/src/projects/generative/model/test_unet.py b/src/projects/generative/model/test_unet.py index f5057c6..8f5e2c1 100644 --- a/src/projects/generative/model/test_unet.py +++ b/src/projects/generative/model/test_unet.py @@ -97,5 +97,38 @@ def test_upsample_block(with_conv: bool, dtype: typing.Any) -> None: assert outputs.dtype == dtype +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_score_net(dtype: typing.Any) -> None: + r"""Tests the full U-Net model for score-based generative modeling.""" + rng = jax.random.PRNGKey(42) + + model = unet.ScoreNet( + features=128, + attn_resolutions=(), + dropout_rate=0.2, + dtype=dtype, + param_dtype=dtype, + ) + test_input = jnp.ones((2, 32, 32, 3), dtype=dtype) + test_cond = jnp.ones((2, 16), dtype=dtype) + params_rng, dropout_rng = jax.random.split(rng, num=2) + variables = model.init( + rngs={"params": params_rng}, + inputs=test_input, + cond=test_cond, + deterministic=True, + ) + outputs = model.apply( + variables=variables, + inputs=test_input, + cond=test_cond, + deterministic=True, + rngs={"dropout": dropout_rng}, + ) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 32, 32, 3) + assert outputs.dtype == dtype + + if __name__ == "__main__": sys.exit(pytest.main(["-xv", __file__])) diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index c8362a4..e65bb34 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -3,6 +3,7 @@ import chex from flax import linen as nn import jax +from jax import numpy as jnp class ResNetBlock(nn.Module): @@ -271,3 +272,206 @@ def __call__(self, inputs: jax.Array) -> jax.Array: chex.assert_shape(out, (*batch_dims, *dims["hwC"])) return out + + +class ScoreNet(nn.Module): + r"""U-Net architecture for score-function estimation. + + This module is adapted from the original implementation of the U-Net + architecture from "Score-Based Generative Modeling through Stochastic + Differential Equations" by Yang Song et al. and the original implementation + is available at `https://github.com/yang-song/score_sde_pytorch`. + + Args: + features (int): Base number of features for the latent representations. + ch_mults (typing.Sequence[int], optional): Sequence of multipliers + for the number of features at each level of the U-Net. + num_groups (int, optional): Number of groups for `GroupNorm`. + num_res_blocks (int, optional): Number of residual blocks per level. + attn_resolutions (typing.Sequence[int], optional): Sequence of + resolutions at which to apply attention mechanisms. + dropout_rate (float, optional): Dropout rate. Default is :math:`0.0`. + epsilon (float, optional): Small float added to variance to avoid + dividing by zero in `GroupNorm`. Default is :math:`1e-5`. + deterministic (bool, optional): If true, the model is run in + deterministic mode (e.g., no dropout). Defaults to `None`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + precision (Any, optional): Numerical precision of the computation. + """ + + features: int + ch_mults: typing.Sequence[int] = (1, 2, 2, 2) + num_groups: int = 32 + num_res_blocks: int = 4 + attn_resolutions: typing.Sequence[int] = (16,) + dropout_rate: float = 0.0 + epsilon: float = 1e-5 + deterministic: typing.Optional[bool] = None + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + @nn.compact + def __call__( + self, + inputs: jax.Array, + cond: jax.Array, + deterministic: typing.Optional[bool] = None, + ) -> jax.Array: + r"""Forward pass of the `ScoreNet`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C_in)`. + cond (jax.Array): Conditioning array of shape `(*, C_cond)`. + deterministic (bool, optional): If true, the model is run in + deterministic mode (e.g., no dropout). Defaults to `None`. + + Returns: + Output array of shape `(*, H, W, C_out)`, where `C_out` is the + number of channels in the input. + """ + m_deterministic = nn.merge_param( + "deterministic", + self.deterministic, + deterministic, + ) + batch_dims = inputs.shape[:-3] + dims = chex.Dimensions( + H=inputs.shape[-3], + W=inputs.shape[-2], + C=inputs.shape[-1], + ) + skips = [] + + # forward pass the input convolution + conv_in = nn.Conv( + features=self.features, + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + name="conv_in", + ) + out = conv_in(inputs) + skips.append(out) + + # forward pass the downsampling path + for level, mult in enumerate(self.ch_mults): + out_ch = self.features * mult + for i in range(self.num_res_blocks): + res_block = ResNetBlock( + features=out_ch, + num_groups=self.num_groups, + dropout_rate=self.dropout_rate, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"down_resnet_{level + 1:d}_{i + 1:d}", + ) + out = res_block( + inputs=out, + cond=cond, + deterministic=m_deterministic, + ) + if out.shape[-3] in self.attn_resolutions: + # TODO (juanwulu): Attention block would go here + raise NotImplementedError + skips.append(out) + if level != len(self.ch_mults) - 1: + downsample = DownsampleBlock( + with_conv=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + name=f"downsample_{level + 1:d}", + ) + out = downsample(out) + skips.append(out) + + # forward pass the middle blocks + block = ResNetBlock( + features=out.shape[-1], + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name="mid_resnet_1", + ) + out = block(out, cond=cond, deterministic=m_deterministic) + # TODO (juanwulu): Attention block would go here + block = ResNetBlock( + features=out.shape[-1], + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name="mid_resnet_2", + ) + out = block(out, cond=cond, deterministic=m_deterministic) + + # forward pass the upsampling path + for level, mult in reversed(list(enumerate(self.ch_mults))): + out_ch = self.features * mult + for i in range(self.num_res_blocks + 1): + skip = skips.pop() + out = jnp.concatenate([out, skip], axis=-1) + res_block = ResNetBlock( + features=out_ch, + dropout_rate=self.dropout_rate, + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"up_resnet_{level + 1:d}_{i + 1:d}", + ) + out = res_block( + inputs=out, + cond=cond, + deterministic=m_deterministic, + ) + if out.shape[-3] in self.attn_resolutions: + # TODO (juanwulu): Attention block would go here + raise NotImplementedError + if level != 0: + upsample = UpsampleBlock( + with_conv=True, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"upsample_{level + 1:d}", + ) + out = upsample(out) + + # forward pass the output convolution + norm_out = nn.GroupNorm( + num_groups=self.num_groups, + epsilon=self.epsilon, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm_out", + ) + out = nn.silu(norm_out(out)) + conv_out = nn.Conv( + features=dims.C, # type: ignore + kernel_size=(3, 3), + strides=(1, 1), + padding=(1, 1), + kernel_init=jax.nn.initializers.zeros, + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + name="conv_out", + ) + out = conv_out(out) + chex.assert_shape(out, (*batch_dims, *dims["HWC"])) + + return out From e3e0c4f93b7c16b4c969e4a3c1a140841008dc2f Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 06:17:02 -0600 Subject: [PATCH 37/67] feat: Added implementation for scaled dot-product attention block Signed-off-by: Juanwu --- src/projects/generative/model/test_unet.py | 19 +++ src/projects/generative/model/unet.py | 153 +++++++++++++++++++++ 2 files changed, 172 insertions(+) diff --git a/src/projects/generative/model/test_unet.py b/src/projects/generative/model/test_unet.py index 8f5e2c1..45c69c9 100644 --- a/src/projects/generative/model/test_unet.py +++ b/src/projects/generative/model/test_unet.py @@ -97,6 +97,25 @@ def test_upsample_block(with_conv: bool, dtype: typing.Any) -> None: assert outputs.dtype == dtype +@pytest.mark.parametrize("num_heads", [1, 4]) +@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) +def test_attn_block(num_heads: int, dtype: typing.Any) -> None: + r"""Tests the attention block in U-Net models.""" + rng = jax.random.PRNGKey(42) + + block = unet.AttnBlock(num_heads=num_heads, dtype=dtype, param_dtype=dtype) + test_input = jnp.ones((2, 16, 16, 32), dtype=dtype) + variables = block.init( + rngs={"params": rng}, + inputs=test_input, + ) + + outputs = block.apply(variables=variables, inputs=test_input) + assert isinstance(outputs, jax.Array) + assert outputs.shape == (2, 16, 16, 32) + assert outputs.dtype == dtype + + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_score_net(dtype: typing.Any) -> None: r"""Tests the full U-Net model for score-based generative modeling.""" diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index e65bb34..b57e4e1 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -274,6 +274,159 @@ def __call__(self, inputs: jax.Array) -> jax.Array: return out +class AttnBlock(nn.Module): + r"""Self-attention block with group normalization in U-Net models. + + Args: + """ + + num_heads: int = 1 + dtype: typing.Any = None + param_dtype: typing.Any = None + precision: typing.Any = None + + @nn.compact + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass of the `AttnBlock`. + + Args: + inputs (jax.Array): Input array of shape `(*, H, W, C)`. + + Returns: + Output array of shape `(*, H, W, C)`. + """ + + norm_in = nn.GroupNorm( + num_groups=32, + epsilon=1e-5, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="norm", + ) + out = norm_in(inputs) + + if self.num_heads == 1: + # scaled dot-product attention + q_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="q_proj", + ) + query = q_proj(out) + k_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="k_proj", + ) + key = k_proj(out) + v_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="v_proj", + ) + value = v_proj(out) + out = nn.dot_product_attention( + query[..., None, :], + key[..., None, :], + value[..., None, :], + broadcast_dropout=False, + dropout_rate=0.0, + dtype=self.dtype, + precision=self.precision, + ) + print("out shape after attention:", out.shape) + out_proj = nn.Dense( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.zeros, + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="out_proj", + ) + out = out_proj(out[..., 0, :]) + else: + head_dim = inputs.shape[-1] // self.num_heads + if head_dim * self.num_heads != inputs.shape[-1]: + raise ValueError( + f"Number of heads {self.num_heads} not compatible with " + f"input channels {inputs.shape[-1]}." + ) + q_proj = nn.DenseGeneral( + features=(self.num_heads, head_dim), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="q_proj", + ) + query = q_proj(out) + k_proj = nn.DenseGeneral( + features=(self.num_heads, head_dim), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="k_proj", + ) + key = k_proj(out) + v_proj = nn.DenseGeneral( + features=(self.num_heads, head_dim), + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="v_proj", + ) + value = v_proj(out) + out_proj = nn.DenseGeneral( + features=inputs.shape[-1], + kernel_init=jax.nn.initializers.zeros, + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="out_proj", + ) + out = out_proj(out) + + chex.assert_equal_shape([out, inputs]) + out = out + inputs + + return out + + class ScoreNet(nn.Module): r"""U-Net architecture for score-function estimation. From 19964baf0d9c3aaea58dbaa07dc1e517e5328936 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 06:19:42 -0600 Subject: [PATCH 38/67] feat: Integrate attention block to score U-Net architecture Signed-off-by: Juanwu --- src/projects/generative/model/unet.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index b57e4e1..62ee7f6 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -535,8 +535,14 @@ def __call__( deterministic=m_deterministic, ) if out.shape[-3] in self.attn_resolutions: - # TODO (juanwulu): Attention block would go here - raise NotImplementedError + block = AttnBlock( + num_heads=1, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"down_attn_{level + 1:d}_{i + 1:d}", + ) + out = block(out) skips.append(out) if level != len(self.ch_mults) - 1: downsample = DownsampleBlock( @@ -559,7 +565,14 @@ def __call__( name="mid_resnet_1", ) out = block(out, cond=cond, deterministic=m_deterministic) - # TODO (juanwulu): Attention block would go here + block = AttnBlock( + num_heads=1, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name="mid_attn", + ) + out = block(out) block = ResNetBlock( features=out.shape[-1], num_groups=self.num_groups, From 76781b3143c2bb43053af049fcd28f34856091ab Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 06:28:31 -0600 Subject: [PATCH 39/67] hotfix: Adds missing attention block in upsampling path of U-Net Signed-off-by: Juanwu --- src/projects/generative/model/test_unet.py | 1 - src/projects/generative/model/unet.py | 11 ++++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/projects/generative/model/test_unet.py b/src/projects/generative/model/test_unet.py index 45c69c9..82fe4dd 100644 --- a/src/projects/generative/model/test_unet.py +++ b/src/projects/generative/model/test_unet.py @@ -123,7 +123,6 @@ def test_score_net(dtype: typing.Any) -> None: model = unet.ScoreNet( features=128, - attn_resolutions=(), dropout_rate=0.2, dtype=dtype, param_dtype=dtype, diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index 62ee7f6..6c6d673 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -355,7 +355,6 @@ def __call__(self, inputs: jax.Array) -> jax.Array: dtype=self.dtype, precision=self.precision, ) - print("out shape after attention:", out.shape) out_proj = nn.Dense( features=inputs.shape[-1], kernel_init=jax.nn.initializers.zeros, @@ -606,8 +605,14 @@ def __call__( deterministic=m_deterministic, ) if out.shape[-3] in self.attn_resolutions: - # TODO (juanwulu): Attention block would go here - raise NotImplementedError + block = AttnBlock( + num_heads=1, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name=f"up_attn_{level + 1:d}_{i + 1:d}", + ) + out = block(out) if level != 0: upsample = UpsampleBlock( with_conv=True, From 6716140f0da461c185f4fa1a304c53957c6189b0 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 06:29:19 -0600 Subject: [PATCH 40/67] feat: Integrates score-based U-Net for meanflow experiment on CIFAR-10 Signed-off-by: Juanwu --- src/projects/generative/BUILD | 2 +- src/projects/generative/config.py | 2 +- src/projects/generative/meanflow.py | 23 ++++++++++++++++------- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/projects/generative/BUILD b/src/projects/generative/BUILD index 9397e94..a929e1d 100644 --- a/src/projects/generative/BUILD +++ b/src/projects/generative/BUILD @@ -45,6 +45,6 @@ ml_py_library( "typing_extensions", "//src/core:model", "//src/core:train_state", - "//src/projects/generative/model:refinenet", + "//src/projects/generative/model:unet", ], ) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index 79651f7..9b52001 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -40,7 +40,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: meanflow.MeanFlowUNetModel, in_channels=3, image_size=32, - latent_channels=16, + latent_channels=128, num_classes=10, use_cfg_embedding=False, dropout_rate=0.2, diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 1de1da9..45e3e50 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -10,7 +10,7 @@ import typing_extensions from src.core import model as _model -from src.projects.generative.model import refinenet +from src.projects.generative.model import unet # Type Aliases PyTree = jaxtyping.PyTree @@ -392,11 +392,16 @@ class MeanFlowUNetModule(nn.Module): def setup(self) -> None: r"""Instantiate a `MeanFlowUNetModel` module.""" - self.backbone = refinenet.ConditionalRefineNet( - in_channels=self.in_channels, - image_size=self.image_size, - latent_channels=self.latent_channels, - norm_module=ConditionalInstanceNorm, + # self.backbone = refinenet.ConditionalRefineNet( + # in_channels=self.in_channels, + # image_size=self.image_size, + # latent_channels=self.latent_channels, + # norm_module=ConditionalInstanceNorm, + # dtype=self.dtype, + # param_dtype=self.param_dtype, + # ) + self.backbone = unet.ScoreNet( + features=self.latent_channels, dtype=self.dtype, param_dtype=self.param_dtype, ) @@ -474,7 +479,11 @@ def __call__( else: t_emb = jnp.zeros_like(y_emb) cond = t_emb + r_emb + y_emb - output = self.backbone(inputs=image, cond=cond) + output = self.backbone( + inputs=image, + cond=cond, + deterministic=m_deterministic, + ) return output From 70bb67cb737e7157e4164bd968c9edf4b77ac2be Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 06:37:14 -0600 Subject: [PATCH 41/67] hotfix: Fixes issue of `dropout_rate` in U-Net for meanflow Signed-off-by: Juanwu --- src/projects/generative/meanflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 45e3e50..6fc6125 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -402,6 +402,7 @@ def setup(self) -> None: # ) self.backbone = unet.ScoreNet( features=self.latent_channels, + dropout_rate=self.dropout_rate, dtype=self.dtype, param_dtype=self.param_dtype, ) From 27bc1ec101e97f1c1dc98b89b694b31c6916af13 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 06:40:31 -0600 Subject: [PATCH 42/67] hotfix: Fixes issue of missing dropout rng in U-Net for meanflow Signed-off-by: Juanwu --- src/projects/generative/meanflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 6fc6125..6258033 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -627,7 +627,7 @@ def compute_loss( # NOTE: following the notation in Algorithm 1 of the source paper # sample t and r batch_dims = image.shape[:-3] - tr_rng, mask_rng, e_rng = jax.random.split(rngs, num=3) + tr_rng, dropout_rng, mask_rng, e_rng = jax.random.split(rngs, num=4) t, r = sample_t_r( key=tr_rng, shape=batch_dims, @@ -682,7 +682,7 @@ def u_fn( ) out = self.network.apply( - variables={"params": params}, + variables={"params": params, "dropout": dropout_rng}, image=z_t, label=label, begin=b_arg, From d0e352a24f69203b6263364dd9b9d54926587195 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 06:42:18 -0600 Subject: [PATCH 43/67] hotfix: Fixes issue of missing dropout rng in U-Net for meanflow Signed-off-by: Juanwu --- src/projects/generative/meanflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 6258033..12b704a 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -682,12 +682,13 @@ def u_fn( ) out = self.network.apply( - variables={"params": params, "dropout": dropout_rng}, + variables={"params": params}, image=z_t, label=label, begin=b_arg, end=e_arg, deterministic=deterministic, + rngs={"dropout": dropout_rng}, ) assert isinstance(out, jax.Array) From 9d45b9c05226f2b1e5cdfcf774aee78fe6b2ca78 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 18:20:06 -0600 Subject: [PATCH 44/67] hotfix: Updated configurations for training U-Net on CIFAR-10 Signed-off-by: Juanwu --- src/projects/generative/config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index 9b52001..d0ca266 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -31,8 +31,8 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: ), ), ), - batch_size=1024, - num_workers=2, + batch_size=128, + num_workers=4, deterministic=True, drop_remainder=True, ), @@ -61,6 +61,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: optimizer=_config.OptimizerConfig( lr_schedule=fdl.Config(optax.constant_schedule, value=6e-4), optimizer=fdl.Partial(optax.adam, b1=0.9, b2=0.999), + ema_rate=0.9999, ), seed=42, ) From 03655eda897bc0556c7ca1e1d9b944af5545dc7b Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 18:22:36 -0600 Subject: [PATCH 45/67] hotfix: Fixed implementation of logit-normal timestamp sampler Signed-off-by: Juanwu --- src/projects/generative/meanflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 12b704a..73cb2a0 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -70,7 +70,7 @@ def _lognormal( stddev: float, ) -> jax.Array: z = jax.random.normal(key=key, shape=shape, dtype=dtype) - return jnp.exp(mean + stddev * z) + return jax.nn.sigmoid(mean + stddev * z) mean = kwargs.get("mean", -0.4) stddev = kwargs.get("stddev", 1.0) @@ -94,7 +94,7 @@ def _lognormal( 'Must be one of ["uniform", "lognormal"].' ) - return jnp.clip(t, a_min=0.0, a_max=1.0), jnp.clip(r, a_min=0.0, a_max=1.0) + return jnp.clip(t, 0.0, 1.0), jnp.clip(r, 0.0, 1.0) # ============================================================================== From f307e8cc6b8aa7d87798b7bb6b18b78f500c52a8 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Fri, 28 Nov 2025 18:28:02 -0600 Subject: [PATCH 46/67] hotfix: Fixed implementation of logit-normal timestamp sampler Signed-off-by: Juanwu --- src/projects/generative/config.py | 6 +++--- src/projects/generative/meanflow.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index d0ca266..3a9897a 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -31,8 +31,8 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: ), ), ), - batch_size=128, - num_workers=4, + batch_size=1024, + num_workers=2, deterministic=True, drop_remainder=True, ), @@ -45,7 +45,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: use_cfg_embedding=False, dropout_rate=0.2, timestamp_cond="t_and_t_minus_r", - timestamp_sampler="lognormal", + timestamp_sampler="logit-normal", timestamp_sampler_kwargs=dict(mean=-2.0, stddev=2.0), timestamp_overlap_rate=0.25, adaptive_weight_power=0.75, diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 73cb2a0..66a1f83 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -34,7 +34,7 @@ def sample_t_r( shape (jax.typing.Shape): The shape of the output arrays. dtype (dtype): The dtype of the output arrays. distribution (str): The distribution to sample from. - One of `["uniform", "lognormal"]`. + One of `["uniform", "logit-normal"]`. **kwargs: Additional keyword arguments for the distribution. Returns: @@ -60,9 +60,9 @@ def sample_t_r( minval=minval, maxval=maxval, ) - elif distribution == "lognormal": + elif distribution == "logit-normal": - def _lognormal( + def _logit_normal( key: jax.Array, shape: jax_typing.Shape, dtype: typing.Any, @@ -74,14 +74,14 @@ def _lognormal( mean = kwargs.get("mean", -0.4) stddev = kwargs.get("stddev", 1.0) - t = _lognormal( + t = _logit_normal( key=t_key, shape=shape, dtype=dtype, mean=mean, stddev=stddev, ) - r = _lognormal( + r = _logit_normal( key=r_key, shape=shape, dtype=dtype, @@ -91,7 +91,7 @@ def _lognormal( else: raise ValueError( f"Unsupported distribution: {distribution}. " - 'Must be one of ["uniform", "lognormal"].' + 'Must be one of ["uniform", "logit-normal"].' ) return jnp.clip(t, 0.0, 1.0), jnp.clip(r, 0.0, 1.0) @@ -505,7 +505,7 @@ class MeanFlowUNetModel(_model.Model): One of `["t_and_r", "t_and_t_minus_r", "t_and_r_and_t_minus_r", "t_minus_r"]`. timestamp_sampler (str): The distribution to sample timestamps from. - One of `["uniform", "lognormal"]`. + One of `["uniform", "logit-normal"]`. timestamp_sampler_kwargs (Dict[str, Any]): Additional keyword arguments for the timestamp sampler. timestamp_overlap_rate (float): The minimum overlap rate between @@ -530,7 +530,7 @@ def __init__( "t_and_r_and_t_minus_r", "t_minus_r", ] = "t_and_t_minus_r", - timestamp_sampler: str = "lognormal", + timestamp_sampler: str = "logit-normal", timestamp_sampler_kwargs: typing.Dict[str, typing.Any] = { "mean": -0.4, "stddev": 1.0, From 90b5d4e553093e662a8c6f0d048505358c8b487c Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sat, 29 Nov 2025 01:31:13 -0600 Subject: [PATCH 47/67] hotfix: Fixed implementation for JAX in MacOS Signed-off-by: Juanwu --- third_party/defs.bzl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/third_party/defs.bzl b/third_party/defs.bzl index 387a995..baa820a 100644 --- a/third_party/defs.bzl +++ b/third_party/defs.bzl @@ -38,8 +38,13 @@ def _select_all_requirements(names = []): if "fiddle" in names and "etils" not in names: reqs += _select_requirement("etils") - if "jax" in names and "jaxlib" not in names: - reqs += _select_requirement("jaxlib") + if "jax" in names: + if "jaxlib" not in names: + reqs += _select_requirement("jaxlib") + reqs += select({ + "//third_party:is_mps": [mps_req("jax-metal")], + "//conditions:default": [], + }) return reqs From 12ee1023b35190aa17a47e1b4945448896b3d359 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sat, 29 Nov 2025 01:33:15 -0600 Subject: [PATCH 48/67] feat: Implements sinusoidal positional encoding for U-Net Signed-off-by: Juanwu --- src/projects/generative/config.py | 9 +++- src/projects/generative/meanflow.py | 69 ++++++++++++++++++++++------- 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index 3a9897a..523bf2a 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -31,7 +31,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: ), ), ), - batch_size=1024, + batch_size=128, num_workers=2, deterministic=True, drop_remainder=True, @@ -59,7 +59,12 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: profile=False, ), optimizer=_config.OptimizerConfig( - lr_schedule=fdl.Config(optax.constant_schedule, value=6e-4), + lr_schedule=fdl.Config( + optax.warmup_constant_schedule, + init_value=1e-8, + peak_value=6e-4, + warmup_steps=10_000, + ), optimizer=fdl.Partial(optax.adam, b1=0.9, b2=0.999), ema_rate=0.9999, ), diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 66a1f83..421cfb9 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -100,6 +100,40 @@ def _logit_normal( # ============================================================================== # Helper modules # ============================================================================== +class SinusoidalEmbed(nn.Module): + r"""Sinusoidal positional embeddings. + + Args: + features (int): Dimensionality of the output embeddings. + max_indx (int): Maximum index value. + endpoint (bool): Whether to include the endpoint frequency. + """ + + features: int + max_indx: int = 10_000 + endpoint: bool = False + + def setup(self) -> None: + """Instantiate a `SinusoidalEmbed` module.""" + half_dim = self.features >> 1 + freqs = jnp.arange(0, half_dim, dtype=jnp.float32) + freqs = freqs / (half_dim - (1 if self.endpoint else 0)) + self.freqs = jnp.power(1.0 / self.max_indx, freqs) + + def __call__(self, inputs: jax.Array) -> jax.Array: + r"""Forward pass and returns the sinusoidal embeddings. + + Args: + inputs (jax.Array): Input indexes of shape `(*, )`. + + Returns: + Sinusoidal embedding array of shape `(..., features)`. + """ + out = jnp.outer(inputs[..., None], self.freqs) + out = jnp.concatenate([jnp.cos(out), jnp.sin(out)], axis=-1) + return out + + class TimestampEmbed(nn.Module): """Encode scalar timestamps to vectors. @@ -400,28 +434,31 @@ def setup(self) -> None: # dtype=self.dtype, # param_dtype=self.param_dtype, # ) + # self.r_embed = TimestampEmbed( + # features=self.latent_channels, + # frequency=256, + # max_stamp=10_000, + # name="r_embedder", + # dtype=self.dtype, + # param_dtype=self.param_dtype, + # ) + # self.t_embed = TimestampEmbed( + # features=self.latent_channels, + # frequency=256, + # max_stamp=10_000, + # name="t_embedder", + # dtype=self.dtype, + # param_dtype=self.param_dtype, + # ) + self.backbone = unet.ScoreNet( features=self.latent_channels, dropout_rate=self.dropout_rate, dtype=self.dtype, param_dtype=self.param_dtype, ) - self.r_embed = TimestampEmbed( - features=self.latent_channels, - frequency=256, - max_stamp=10_000, - name="r_embedder", - dtype=self.dtype, - param_dtype=self.param_dtype, - ) - self.t_embed = TimestampEmbed( - features=self.latent_channels, - frequency=256, - max_stamp=10_000, - name="t_embedder", - dtype=self.dtype, - param_dtype=self.param_dtype, - ) + self.r_embed = SinusoidalEmbed(self.latent_channels, endpoint=True) + self.t_embed = SinusoidalEmbed(self.latent_channels, endpoint=True) self.label_embed = ConditionEmbed( features=self.latent_channels, num_classes=self.num_classes, From 6f80f07732330e2a7c258026ea3d6999653fb828 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sat, 29 Nov 2025 21:31:31 -0600 Subject: [PATCH 49/67] hotfix: Updated step output to contain model output array Signed-off-by: Juanwu --- src/core/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/core/model.py b/src/core/model.py index 53169ad..7f4eb04 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -12,10 +12,12 @@ class StepOutputs: """A base container for outputs from a single step. Attributes: + output (Optional[jax.Array]): The main output of the model. scalars (Optional[Dict[str, Any]]): A dictionary of scalar metrics. images (Optional[Dict[str, Any]]): A dictionary of image outputs. """ + output: typing.Optional[jax.Array] = None scalars: typing.Optional[typing.Dict[str, typing.Any]] = None images: typing.Optional[typing.Dict[str, typing.Any]] = None From a42758a028560d69c7b4416b702ddf7a854843c0 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sat, 29 Nov 2025 21:32:03 -0600 Subject: [PATCH 50/67] feat: Updated grid visualization function to use jax array Signed-off-by: Juanwu --- src/utilities/BUILD | 3 ++- src/utilities/visualization.py | 38 ++++++++++++++++++++-------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/utilities/BUILD b/src/utilities/BUILD index fb1fc18..40126c8 100644 --- a/src/utilities/BUILD +++ b/src/utilities/BUILD @@ -23,6 +23,7 @@ ml_py_library( name = "visualization", srcs = ["visualization.py"], deps = [ - "numpy", + "jax", + ":logging", ], ) diff --git a/src/utilities/visualization.py b/src/utilities/visualization.py index e810556..ffa8e14 100644 --- a/src/utilities/visualization.py +++ b/src/utilities/visualization.py @@ -1,45 +1,51 @@ import typing -import numpy as np -import numpy.typing as npt +import jax +from jax import numpy as jnp + +from src.utilities import logging def make_grid( - images: typing.Union[typing.Any, npt.NDArray], + images: jax.Array, n_rows: int = 8, + n_cols: int = 8, padding: int = 2, -) -> npt.NDArray: +) -> jax.Array: r"""Convert a batch of images into a grid for visualization. Args: - images (Any | NDArray): Batch of images with shape `(B, H, W, C)`. - n_rows (int, optional): Number of rows in grid. Default is :math:`8`. + images (jax.Array): Batch of images with shape `(B, H, W, C)`. + n_rows (int): Number of rows in grid. Default is :math:`8`. + n_cols (int): Number of columns in grid. Default is :math:`8`. padding (int, optional): Number of pixels between pair of images. Default is :math:`2`. Returns: The array containing a grid of input images. """ - images = np.asarray(images) - if images.ndim != 4: - raise ValueError( - "Input images must be a 4-D numpy array with shape (B, H, W, C). " - f"But got {images}." - ) - bz, h, w, c = images.shape - n_cols = int(np.ceil(bz / n_rows)) + images = jnp.reshape(images, (-1,) + images.shape[-3:]) + _, h, w, c = images.shape shape = ( h * n_rows + padding * (n_rows - 1), w * n_cols + padding * (n_cols - 1), c, ) - out = np.zeros(shape, dtype=images.dtype) + out = jnp.zeros(shape, dtype=images.dtype) for idx, img in enumerate(images): row = idx // n_cols col = idx % n_cols top = row * (h + padding) left = col * (w + padding) - out[top : top + h, left : left + w] = img + out = out.at[top : top + h, left : left + w].set(img) + + if idx + 1 >= n_rows * n_cols: + logging.rank_zero_warning( + "Number of images exceed grid capacity; " + + "only the first %d images are used.", + n_rows * n_cols, + ) + break return out From cef9f5a2c041a6ddf4efbddc5fb9146f59b3a67f Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sat, 29 Nov 2025 21:32:51 -0600 Subject: [PATCH 51/67] feat: Implements the evaluation step for meanflow with visualization Signed-off-by: Juanwu --- src/projects/generative/BUILD | 1 + src/projects/generative/config.py | 4 +-- src/projects/generative/main.py | 30 +++++++++++++++--- src/projects/generative/meanflow.py | 47 ++++++++++++++++++++--------- 4 files changed, 61 insertions(+), 21 deletions(-) diff --git a/src/projects/generative/BUILD b/src/projects/generative/BUILD index a929e1d..b6d261a 100644 --- a/src/projects/generative/BUILD +++ b/src/projects/generative/BUILD @@ -31,6 +31,7 @@ ml_py_binary( "//src/core:train", "//src/core:train_state", "//src/utilities:logging", + "//src/utilities:visualization", ], ) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index 523bf2a..c3c25b8 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -52,9 +52,9 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: ), trainer=_config.TrainerConfig( num_train_steps=800_000, - log_every_n_steps=5, + log_every_n_steps=50, checkpoint_every_n_steps=10_000, # save every 10k steps - eval_every_n_steps=1_000_000, # NOTE: never evaluate now + eval_every_n_steps=1_000, max_checkpoints_to_keep=3, profile=False, ), diff --git a/src/projects/generative/main.py b/src/projects/generative/main.py index 123194f..232dc32 100644 --- a/src/projects/generative/main.py +++ b/src/projects/generative/main.py @@ -11,6 +11,7 @@ from fiddle import absl_flags import fiddle as fdl import jax +from jax import numpy as jnp import jaxtyping import optax import tensorflow as tf @@ -21,6 +22,7 @@ from src.core import train as _train from src.core import train_state as _train_state from src.utilities import logging +from src.utilities import visualization CONFIG = absl_flags.DEFINE_fiddle_config( name="experiment", @@ -44,9 +46,28 @@ assert not tf.config.experimental.get_visible_devices("GPU") -def evaluation_step(rngs: jax.Array) -> _model.StepOutputs: +def evaluation_step( + rngs: jax.Array, + model: _model.Model, + params: PyTree, + batch: PyTree, + **kwargs, +) -> _model.StepOutputs: r"""Conduct a single evaluation step and compute metrics.""" - raise NotImplementedError + local_rng = jax.random.fold_in(rngs, jax.lax.axis_index("batch")) + outputs = model.forward( + rngs=local_rng, + params=params, + deterministic=True, + batch=batch, + **kwargs, + ) + out = outputs.output + assert isinstance(out, jax.Array) + out = jnp.clip(out * 0.5 + 0.5, 0.0, 1.0) + img_grid = visualization.make_grid(out, n_rows=4, n_cols=8, padding=2) + outputs.images = {"sampled images": img_grid} + return outputs def training_step( @@ -65,7 +86,7 @@ def loss_fn(params: PyTree) -> typing.Tuple[jax.Array, _model.StepOutputs]: rngs=local_rng, params=params, deterministic=False, - **batch, + batch=batch, **kwargs, ) return loss, outputs @@ -187,11 +208,12 @@ def main(_: typing.List[str]) -> int: if exp_config.mode == "train": p_training_step = functools.partial(training_step, model=model) + p_evaluation_step = functools.partial(evaluation_step, model=model) _train.run( state=state, datamodule=datamodule, training_step=p_training_step, - evaluation_step=evaluation_step, + evaluation_step=p_evaluation_step, num_train_steps=exp_config.trainer.num_train_steps, writer=writer, work_dir=log_dir, diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 421cfb9..9567033 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -639,8 +639,7 @@ def compute_loss( self, *, rngs: typing.Any, - image: jax.Array, - label: jax.Array, + batch: typing.Dict[str, typing.Any], params: frozen_dict.FrozenDict, deterministic: bool = False, **kwargs, @@ -650,8 +649,9 @@ def compute_loss( Args: rngs (Union[jax.random.KeyArray, Dict[str, jax.random.KeyArray]]): JAX random key or a dictionary of JAX random keys. - image (jax.Array): The input images of shape `(*, H, W, C)`. - label (jax.Array): The class labels of shape `(*,)`. + batch (Dict[str, Any]): A batch of data containing: + - image (jax.Array): Input images of shape `(*, H, W, C)`. + - label (jax.Array): Conditioning labels of shape `(*, )`. params (frozen_dict.FrozenDict): The model parameters. deterministic (bool): Whether to run the model deterministically. **kwargs: additional keyword arguments. @@ -663,6 +663,7 @@ def compute_loss( # NOTE: following the notation in Algorithm 1 of the source paper # sample t and r + image, label = batch["image"], batch["label"] batch_dims = image.shape[:-3] tr_rng, dropout_rng, mask_rng, e_rng = jax.random.split(rngs, num=4) t, r = sample_t_r( @@ -780,23 +781,22 @@ def u_fn( def forward( self, *, - rngs: typing.Any, + rngs: jax.Array, params: frozen_dict.FrozenDict, - image: jax.Array, - label: jax.Array, - begin: typing.Optional[jax.Array] = None, - end: typing.Optional[jax.Array] = None, - deterministic: bool = False, + batch: typing.Dict[str, typing.Any], + deterministic: bool = True, **kwargs, ) -> _model.StepOutputs: r"""Forward sampling with average velocity prediction. Args: + rngs (jax.Array): Random key for sampling. params (frozen_dict.FrozenDict): The model parameters. - image (jax.Array): Input latent image `z_t` of shape `(*, H, W, C)`. - label (jax.Array): Conditioning labels of shape `(*,)`. - begin (jax.Array): Begin timestamp `r` of shape `(*, )`. - end (jax.Array): End timestamp `t` of shape `(*, )`. + batch (Dict[str, Any]): A batch of data containing: + - image (jax.Array): Input images of shape `(*, H, W, C)`. + - label (jax.Array): Conditioning labels of shape `(*, )`. + shape (jax.typing.Shape): The shape of the output samples. + dtype (Any): The dtype of the output samples. deterministic (bool): Whether to run the model deterministically. **kwargs: Additional keyword arguments. @@ -805,4 +805,21 @@ def forward( """ del kwargs # unused - raise NotImplementedError + # TODO (juanwulu): unconditional generation + image = batch["image"] + label = batch.get("label", None) + shape, dtype = image.shape, image.dtype + + e = jax.random.normal(key=rngs, shape=shape, dtype=dtype) + r = jnp.zeros(e.shape[:-3], dtype=dtype) + t = jnp.ones(e.shape[:-3], dtype=dtype) + out = e - self.network.apply( + variables={"params": params}, + image=e, + label=label, + begin=t - r, + end=t, + deterministic=deterministic, + ) + + return _model.StepOutputs(output=out) From de1f95c6414d0b38ab41c3ff7e5e088329505caf Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sun, 30 Nov 2025 04:58:33 -0600 Subject: [PATCH 52/67] feat: Moved evaluation to before the training inner loop Signed-off-by: Juanwu --- src/core/train.py | 80 +++++++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 37 deletions(-) diff --git a/src/core/train.py b/src/core/train.py index 969b9fc..02315ff 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -113,6 +113,49 @@ def run( train_metrics = collections.defaultdict(list) while True: for batch in datamodule.train_dataloader(): + # evaluation and sanity check running + if ( + step % eval_every_n_steps == 0 + or step == num_train_steps + ): + logging.rank_zero_info("Running evaluation...") + eval_metrics = collections.defaultdict(list) + outputs = None + for batch in datamodule.eval_dataloader(): + batch = _shard(batch) + outputs = p_evaluation_step( + params=state.params, + batch=batch, + ) + if not isinstance(outputs, _model.StepOutputs): + raise RuntimeError( + "FATAL: Output from `evaluation_step` is " + "not a `StepOutputs` object." + ) + if outputs.scalars is not None: + for k, v in outputs.scalars.items(): + eval_metrics[k].append( + jax.device_get(v).mean() + ) + logging.rank_zero_info("Evaluation done.") + + if isinstance(outputs, _model.StepOutputs): + writer.write_scalars( + step=step, + scalars={ + f"eval/{k}": sum(v) / len(v) + for k, v in eval_metrics.items() + }, + ) + if outputs.images is not None: + writer.write_images( + step=step, + images={ + f"eval/{k}": v + for k, v in outputs.images.items() + }, + ) + batch = _shard(batch) with jax.profiler.StepTraceAnnotation( name="train", @@ -148,43 +191,6 @@ def run( ) step += 1 - # evaluation - if ( - step % eval_every_n_steps == 0 - or step == num_train_steps - ): - logging.rank_zero_info("Running evaluation...") - eval_metrics = collections.defaultdict(list) - for batch in datamodule.eval_dataloader(): - batch = _shard(batch) - outputs = p_evaluation_step( - params=state.params, - batch=batch, - ) - if not isinstance(outputs, _model.StepOutputs): - raise RuntimeError( - "FATAL: Output from `evaluation_step` is " - "not a `StepOutputs` object." - ) - if outputs.scalars is not None: - for k, v in outputs.scalars.items(): - eval_metrics[k].append( - jax.device_get(v).mean() - ) - logging.rank_zero_info("Evaluation done.") - writer.write_scalars( - step=step, - scalars={ - f"eval/{k}": sum(v) / len(v) - for k, v in eval_metrics.items() - }, - ) - if outputs.images is not None: - writer.write_images( - step=step, - images=outputs.images, - ) - # checkpointing if step % checkpoint_every_n_steps == 0: logging.rank_zero_info("Checkpointing...") From 74e69ce469bd78e8fd0b8d2d0663997fdb951263 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sun, 30 Nov 2025 05:38:17 -0600 Subject: [PATCH 53/67] feat: Added random left-right flip in training loop Signed-off-by: Juanwu --- src/projects/generative/meanflow.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 9567033..5df29c6 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -663,9 +663,21 @@ def compute_loss( # NOTE: following the notation in Algorithm 1 of the source paper # sample t and r - image, label = batch["image"], batch["label"] + image = batch["image"] + assert isinstance(image, jax.Array) + label = batch.get("label", None) batch_dims = image.shape[:-3] - tr_rng, dropout_rng, mask_rng, e_rng = jax.random.split(rngs, num=4) + tr_rng, dropout_rng, f_rng, m_rng, e_rng = jax.random.split(rngs, 5) + + # randomly flip image horizontally for data augmentation + flip_mask = jax.random.bernoulli(key=f_rng, p=0.5, shape=batch_dims) + image = jnp.where( + flip_mask[..., None, None, None], + jnp.flip(image, axis=-2), + image, + ) + + # sample begin and end timestamps t, r = sample_t_r( key=tr_rng, shape=batch_dims, @@ -676,17 +688,11 @@ def compute_loss( t, r = jnp.maximum(t, r), jnp.minimum(t, r) # ensure a portion of overlap between t and r # NOTE: the following code randomly mask by uniform samples - r_neq_t_mask = jnp.greater_equal( - jax.random.uniform( - key=mask_rng, - shape=batch_dims, - dtype=image.dtype, - minval=0.0, - maxval=1.0, - ), + r_eq_t_mask = jnp.less( + jax.random.uniform(key=m_rng, shape=batch_dims, dtype=image.dtype), self.timestamp_overlap_rate, ) - r = jnp.where(r_neq_t_mask, r, t) + r = jnp.where(r_eq_t_mask, t, r) # sample e ~ N(0, I) e = jax.random.normal(key=e_rng, shape=image.shape, dtype=image.dtype) From 000f1a6b2a492afbe5ff171a7b5b36f26fe241eb Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sun, 30 Nov 2025 08:16:12 -0600 Subject: [PATCH 54/67] feat: Updated implementation for MeanFlow network and remove label conditions Signed-off-by: Juanwu --- src/projects/generative/meanflow.py | 30 +++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 5df29c6..3e5325a 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -473,18 +473,19 @@ def setup(self) -> None: def __call__( self, image: jax.Array, - label: jax.Array, begin: typing.Optional[jax.Array] = None, end: typing.Optional[jax.Array] = None, + label: typing.Optional[jax.Array] = None, deterministic: typing.Optional[bool] = None, ) -> jax.Array: r"""Forward pass the `MeanFlowUNetModel`. Args: inputs (jax.Array): Input images of shape `(*, H, W, C)`. - cond (jax.Array): Conditioning labels of shape `(*,)`. begin (jax.Array): Begin timestamp `r` of shape `(*, )`. end (jax.Array): End timestamp `t` of shape `(*, )`. + label (jax.Array): Integer class labels of shape `(*, )`. + deterministic (bool, optional): Whether to run deterministically. Returns: The predicted average velocity of shape `(*, H, W, C)`. @@ -507,7 +508,13 @@ def __call__( deterministic, ) - y_emb = self.label_embed(label, deterministic=m_deterministic) + if label is not None: + y_emb = self.label_embed(label, deterministic=m_deterministic) + else: + y_emb = jnp.zeros( + shape=(*batch_dims, self.latent_channels), + dtype=image.dtype, + ) if begin is not None: r_emb = self.r_embed(begin) else: @@ -665,7 +672,6 @@ def compute_loss( # sample t and r image = batch["image"] assert isinstance(image, jax.Array) - label = batch.get("label", None) batch_dims = image.shape[:-3] tr_rng, dropout_rng, f_rng, m_rng, e_rng = jax.random.split(rngs, 5) @@ -728,7 +734,6 @@ def u_fn( out = self.network.apply( variables={"params": params}, image=z_t, - label=label, begin=b_arg, end=e_arg, deterministic=deterministic, @@ -742,11 +747,7 @@ def u_fn( drdt = jnp.zeros_like(r) dtdt = jnp.ones_like(t) u, dudt = jax.jvp(u_fn, (z, r, t), (v, drdt, dtdt)) - u_target = jax.lax.stop_gradient( - v - - jnp.clip(t - r, a_min=0.0, a_max=1.0)[..., None, None, None] - * dudt - ) + u_target = v - (t - r)[..., None, None, None] * dudt # NOTE: following the symmetric meanflow # drdt = jnp.ones_like(r) @@ -761,11 +762,14 @@ def u_fn( # computes the target # NOTE: sum over all the pixels, following official implementation - loss = jnp.sum(jnp.square(u - u_target), axis=(-1, -2, -3)) + loss = jnp.sum( + jnp.square(u - jax.lax.stop_gradient(u_target)), + axis=(-1, -2, -3), + ) # applies adaptive weight power if self.adaptive_weight_power > 0.0: - ada_wt = jnp.power(loss + 1e-2, self.adaptive_weight_power) + ada_wt = jnp.power(loss + 1e-3, self.adaptive_weight_power) loss = loss / jax.lax.stop_gradient(ada_wt) loss = jnp.mean(loss) @@ -813,7 +817,6 @@ def forward( # TODO (juanwulu): unconditional generation image = batch["image"] - label = batch.get("label", None) shape, dtype = image.shape, image.dtype e = jax.random.normal(key=rngs, shape=shape, dtype=dtype) @@ -822,7 +825,6 @@ def forward( out = e - self.network.apply( variables={"params": params}, image=e, - label=label, begin=t - r, end=t, deterministic=deterministic, From d8e4b763c6ffcf4238c6b3292198715e2b4d4140 Mon Sep 17 00:00:00 2001 From: Juanwu Date: Sun, 30 Nov 2025 08:21:00 -0600 Subject: [PATCH 55/67] hotfix: Fixed error raised by wrong shape checking Signed-off-by: Juanwu --- src/projects/generative/meanflow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 3e5325a..f6fa62c 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -498,9 +498,6 @@ def __call__( C=self.in_channels, ) chex.assert_shape(image, (*batch_dims, *dims["HWC"])) - chex.assert_shape(label, batch_dims) - chex.assert_shape(begin, batch_dims) - chex.assert_shape(end, batch_dims) m_deterministic = nn.merge_param( "deterministic", @@ -509,6 +506,7 @@ def __call__( ) if label is not None: + chex.assert_shape(label, batch_dims) y_emb = self.label_embed(label, deterministic=m_deterministic) else: y_emb = jnp.zeros( @@ -516,10 +514,12 @@ def __call__( dtype=image.dtype, ) if begin is not None: + chex.assert_shape(begin, batch_dims) r_emb = self.r_embed(begin) else: r_emb = jnp.zeros_like(y_emb) if end is not None: + chex.assert_shape(end, batch_dims) t_emb = self.t_embed(end) else: t_emb = jnp.zeros_like(y_emb) From a50c12e4db63889f11931d8752c56b16242b33f2 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 04:12:50 -0500 Subject: [PATCH 56/67] feat: Fixed training collapse by adding fc layers for timestamp conditions Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 94 ++++++++++++--------------- src/projects/generative/model/unet.py | 1 - 2 files changed, 42 insertions(+), 53 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index f6fa62c..2668940 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -130,7 +130,7 @@ def __call__(self, inputs: jax.Array) -> jax.Array: Sinusoidal embedding array of shape `(..., features)`. """ out = jnp.outer(inputs[..., None], self.freqs) - out = jnp.concatenate([jnp.cos(out), jnp.sin(out)], axis=-1) + out = jnp.concatenate([jnp.sin(out), jnp.cos(out)], axis=-1) return out @@ -426,48 +426,42 @@ class MeanFlowUNetModule(nn.Module): def setup(self) -> None: r"""Instantiate a `MeanFlowUNetModel` module.""" - # self.backbone = refinenet.ConditionalRefineNet( - # in_channels=self.in_channels, - # image_size=self.image_size, - # latent_channels=self.latent_channels, - # norm_module=ConditionalInstanceNorm, - # dtype=self.dtype, - # param_dtype=self.param_dtype, - # ) - # self.r_embed = TimestampEmbed( - # features=self.latent_channels, - # frequency=256, - # max_stamp=10_000, - # name="r_embedder", - # dtype=self.dtype, - # param_dtype=self.param_dtype, - # ) - # self.t_embed = TimestampEmbed( - # features=self.latent_channels, - # frequency=256, - # max_stamp=10_000, - # name="t_embedder", - # dtype=self.dtype, - # param_dtype=self.param_dtype, - # ) - + # backbone U-Net model self.backbone = unet.ScoreNet( features=self.latent_channels, dropout_rate=self.dropout_rate, dtype=self.dtype, param_dtype=self.param_dtype, ) - self.r_embed = SinusoidalEmbed(self.latent_channels, endpoint=True) - self.t_embed = SinusoidalEmbed(self.latent_channels, endpoint=True) - self.label_embed = ConditionEmbed( + + # conditional embeddings + self.time_embed = SinusoidalEmbed( + self.latent_channels * 2, + endpoint=True, + ) + self.cond_in = nn.Dense( features=self.latent_channels, - num_classes=self.num_classes, - use_cfg_embedding=self.use_cfg_embedding, - deterministic=self.deterministic, - dropout_rate=self.dropout_rate, - name="y_embedder", + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, dtype=self.dtype, param_dtype=self.param_dtype, + name="cond_fc_1", + ) + self.cond_out = nn.Dense( + features=self.latent_channels, + kernel_init=jax.nn.initializers.variance_scaling( + scale=1.0, + mode="fan_avg", + distribution="uniform", + ), + bias_init=jax.nn.initializers.zeros, + dtype=self.dtype, + param_dtype=self.param_dtype, + name="cond_fc_2", ) def __call__( @@ -475,7 +469,6 @@ def __call__( image: jax.Array, begin: typing.Optional[jax.Array] = None, end: typing.Optional[jax.Array] = None, - label: typing.Optional[jax.Array] = None, deterministic: typing.Optional[bool] = None, ) -> jax.Array: r"""Forward pass the `MeanFlowUNetModel`. @@ -484,12 +477,12 @@ def __call__( inputs (jax.Array): Input images of shape `(*, H, W, C)`. begin (jax.Array): Begin timestamp `r` of shape `(*, )`. end (jax.Array): End timestamp `t` of shape `(*, )`. - label (jax.Array): Integer class labels of shape `(*, )`. deterministic (bool, optional): Whether to run deterministically. Returns: The predicted average velocity of shape `(*, H, W, C)`. """ + # sanity check for the input arrays batch_dims = image.shape[:-3] dims = chex.Dimensions( @@ -505,25 +498,26 @@ def __call__( deterministic, ) - if label is not None: - chex.assert_shape(label, batch_dims) - y_emb = self.label_embed(label, deterministic=m_deterministic) + if begin is not None: + chex.assert_shape(begin, batch_dims) + r_emb = self.time_embed(begin) else: - y_emb = jnp.zeros( + r_emb = jnp.zeros( shape=(*batch_dims, self.latent_channels), dtype=image.dtype, ) - if begin is not None: - chex.assert_shape(begin, batch_dims) - r_emb = self.r_embed(begin) - else: - r_emb = jnp.zeros_like(y_emb) if end is not None: chex.assert_shape(end, batch_dims) - t_emb = self.t_embed(end) + t_emb = self.time_embed(end) else: - t_emb = jnp.zeros_like(y_emb) - cond = t_emb + r_emb + y_emb + t_emb = jnp.zeros( + shape=(*batch_dims, self.latent_channels), + dtype=image.dtype, + ) + cond = jnp.concatenate([t_emb, r_emb], axis=-1) + cond = jax.nn.silu(self.cond_in(cond)) + cond = jax.nn.silu(self.cond_out(cond)) + output = self.backbone( inputs=image, cond=cond, @@ -624,14 +618,12 @@ def init( (1, self.image_size, self.image_size, self.in_channels), dtype=jnp.float32, ), - "label": jnp.zeros((1,), dtype=jnp.int32), "begin": jnp.zeros((1,), dtype=jnp.float32), "end": jnp.zeros((1,), dtype=jnp.float32), } variables = self.network.init( rngs=rngs, image=dummy_inputs["image"], - label=dummy_inputs["label"], begin=dummy_inputs["begin"], end=dummy_inputs["end"], deterministic=True, @@ -658,7 +650,6 @@ def compute_loss( JAX random key or a dictionary of JAX random keys. batch (Dict[str, Any]): A batch of data containing: - image (jax.Array): Input images of shape `(*, H, W, C)`. - - label (jax.Array): Conditioning labels of shape `(*, )`. params (frozen_dict.FrozenDict): The model parameters. deterministic (bool): Whether to run the model deterministically. **kwargs: additional keyword arguments. @@ -804,7 +795,6 @@ def forward( params (frozen_dict.FrozenDict): The model parameters. batch (Dict[str, Any]): A batch of data containing: - image (jax.Array): Input images of shape `(*, H, W, C)`. - - label (jax.Array): Conditioning labels of shape `(*, )`. shape (jax.typing.Shape): The shape of the output samples. dtype (Any): The dtype of the output samples. deterministic (bool): Whether to run the model deterministically. diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index 6c6d673..6f24bb0 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -139,7 +139,6 @@ def __call__( out = self.conv_1(nn.silu(self.norm_1(inputs))) if cond is not None: - chex.assert_shape(cond, (*batch_dims, cond.shape[-1])) out = out + self.cond_linear(cond)[..., None, None, :] out = nn.silu(self.norm_2(out)) out = self.dropout(out, deterministic=m_deterministic) From 1c4f404dd4b8b69b40f00dbc9d8fdf86d559c352 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 04:14:29 -0500 Subject: [PATCH 57/67] hotfix: Fixed wrong implementation of timestamp conditioning in forward pass of meanflow model Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 2668940..4efd010 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -812,11 +812,26 @@ def forward( e = jax.random.normal(key=rngs, shape=shape, dtype=dtype) r = jnp.zeros(e.shape[:-3], dtype=dtype) t = jnp.ones(e.shape[:-3], dtype=dtype) + if self.timestamp_cond == "t_and_r": + b_arg, e_arg = r, t + elif self.timestamp_cond == "t_and_t_minus_r": + b_arg, e_arg = t - r, t + elif self.timestamp_cond == "t_and_r_and_t_minus_r": + raise NotImplementedError( + "`t_and_r_and_t_minus_r` conditioning is not implemented." + ) + elif self.timestamp_cond == "t_minus_r": + b_arg, e_arg = t - r, None + else: + raise ValueError( + f"Unsupported timestamp conditioning: {self.timestamp_cond}." + ) + out = e - self.network.apply( variables={"params": params}, image=e, - begin=t - r, - end=t, + begin=b_arg, + end=e_arg, deterministic=deterministic, ) From 551c366f3e1ef2d5ee1e7a4d3faaf53402005c7b Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 15:01:46 -0500 Subject: [PATCH 58/67] feat: Added histogram attribute to the model step output Signed-off-by: Juanwu Lu --- src/core/model.py | 3 +++ src/projects/generative/meanflow.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/core/model.py b/src/core/model.py index 7f4eb04..0f1cd3b 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -15,11 +15,14 @@ class StepOutputs: output (Optional[jax.Array]): The main output of the model. scalars (Optional[Dict[str, Any]]): A dictionary of scalar metrics. images (Optional[Dict[str, Any]]): A dictionary of image outputs. + histograms (Optional[Dict[str, Array]]): A dictionary of array to + plot as histograms. """ output: typing.Optional[jax.Array] = None scalars: typing.Optional[typing.Dict[str, typing.Any]] = None images: typing.Optional[typing.Dict[str, typing.Any]] = None + histograms: typing.Optional[typing.Dict[str, jax.Array]] = None class Model(abc.ABC): diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index 4efd010..a1d2aaf 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -774,6 +774,7 @@ def u_fn( out = _model.StepOutputs( scalars={"loss": loss, "velocity_loss": velocity_loss}, + histograms={"t": t, "r": r, "t - r": t - r}, ) return loss, out From a2060a7cc76d0c102eef9e345dc647c1b1e6ded9 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 15:33:31 -0500 Subject: [PATCH 59/67] feat: Added histogram logging for training and evaluation Signed-off-by: Juanwu Lu --- src/core/train.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/core/train.py b/src/core/train.py index 02315ff..8ecdd3a 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -155,6 +155,14 @@ def run( for k, v in outputs.images.items() }, ) + if outputs.histograms is not None: + writer.write_histograms( + step=step, + histograms={ + f"eval/{k}": v + for k, v in outputs.histograms.items() + }, + ) batch = _shard(batch) with jax.profiler.StepTraceAnnotation( @@ -187,7 +195,18 @@ def run( if outputs.images is not None: writer.write_images( step=step, - images=outputs.images, + images={ + f"train/{k}": v + for k, v in outputs.images.items() + }, + ) + if outputs.histograms is not None: + writer.write_histograms( + step=step, + histograms={ + f"train/{k}": v + for k, v in outputs.histograms.items() + }, ) step += 1 From 06ee4c5082ad62a1735383df66a07d74c9596d64 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 15:34:00 -0500 Subject: [PATCH 60/67] feat: Updated implementation for meanflow model to take arbitrary tuple of timestamps Signed-off-by: Juanwu Lu --- src/projects/generative/meanflow.py | 72 +++++++++++++---------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index a1d2aaf..eec4a9a 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -467,8 +467,7 @@ def setup(self) -> None: def __call__( self, image: jax.Array, - begin: typing.Optional[jax.Array] = None, - end: typing.Optional[jax.Array] = None, + timestamps: typing.Tuple[jax.Array], deterministic: typing.Optional[bool] = None, ) -> jax.Array: r"""Forward pass the `MeanFlowUNetModel`. @@ -498,23 +497,8 @@ def __call__( deterministic, ) - if begin is not None: - chex.assert_shape(begin, batch_dims) - r_emb = self.time_embed(begin) - else: - r_emb = jnp.zeros( - shape=(*batch_dims, self.latent_channels), - dtype=image.dtype, - ) - if end is not None: - chex.assert_shape(end, batch_dims) - t_emb = self.time_embed(end) - else: - t_emb = jnp.zeros( - shape=(*batch_dims, self.latent_channels), - dtype=image.dtype, - ) - cond = jnp.concatenate([t_emb, r_emb], axis=-1) + emb = [self.time_embed(time) for time in timestamps] + cond = jnp.concatenate(emb, axis=-1) cond = jax.nn.silu(self.cond_in(cond)) cond = jax.nn.silu(self.cond_out(cond)) @@ -613,19 +597,35 @@ def init( del batch # unused # create dummy inputs + if self.timestamp_cond in ["t_and_r", "t_and_t_minus_r"]: + timestamps = ( + jnp.zeros((1,), dtype=jnp.float32), + jnp.zeros((1,), dtype=jnp.float32), + ) + elif self.timestamp_cond == "t_and_r_and_t_minus_r": + timestamps = ( + jnp.zeros((1,), dtype=jnp.float32), + jnp.zeros((1,), dtype=jnp.float32), + jnp.zeros((1,), dtype=jnp.float32), + ) + elif self.timestamp_cond == "t_minus_r": + timestamps = (jnp.zeros((1,), dtype=jnp.float32),) + else: + raise ValueError( + f"Unsupported timestamp conditioning: {self.timestamp_cond}." + ) + dummy_inputs = { "image": jnp.zeros( (1, self.image_size, self.image_size, self.in_channels), dtype=jnp.float32, ), - "begin": jnp.zeros((1,), dtype=jnp.float32), - "end": jnp.zeros((1,), dtype=jnp.float32), + "timestamps": timestamps, } variables = self.network.init( rngs=rngs, image=dummy_inputs["image"], - begin=dummy_inputs["begin"], - end=dummy_inputs["end"], + timestamps=dummy_inputs["timestamps"], deterministic=True, ) _tabulate_fn = nn.summary.tabulate(self.network, rngs=rngs) @@ -708,15 +708,13 @@ def u_fn( t_in: jax.Array, ) -> jax.Array: if self.timestamp_cond == "t_and_r": - b_arg, e_arg = r_in, t_in + timestamps = (r_in, t_in) elif self.timestamp_cond == "t_and_t_minus_r": - b_arg, e_arg = t_in - r_in, t_in + timestamps = (t_in - r_in, t_in) elif self.timestamp_cond == "t_and_r_and_t_minus_r": - raise NotImplementedError( - "`t_and_r_and_t_minus_r` conditioning is not implemented." - ) + timestamps = (t_in, r_in, t_in - r_in) elif self.timestamp_cond == "t_minus_r": - b_arg, e_arg = t_in - r_in, None + timestamps = (t_in - r_in,) else: raise ValueError( f"Unsupported timestamp conditioning: {self.timestamp_cond}." @@ -725,8 +723,7 @@ def u_fn( out = self.network.apply( variables={"params": params}, image=z_t, - begin=b_arg, - end=e_arg, + timestamps=timestamps, deterministic=deterministic, rngs={"dropout": dropout_rng}, ) @@ -814,15 +811,13 @@ def forward( r = jnp.zeros(e.shape[:-3], dtype=dtype) t = jnp.ones(e.shape[:-3], dtype=dtype) if self.timestamp_cond == "t_and_r": - b_arg, e_arg = r, t + timestamps = (t, r) elif self.timestamp_cond == "t_and_t_minus_r": - b_arg, e_arg = t - r, t + timestamps = (t, t - r) elif self.timestamp_cond == "t_and_r_and_t_minus_r": - raise NotImplementedError( - "`t_and_r_and_t_minus_r` conditioning is not implemented." - ) + timestamps = (t, r, t - r) elif self.timestamp_cond == "t_minus_r": - b_arg, e_arg = t - r, None + timestamps = (t - r,) else: raise ValueError( f"Unsupported timestamp conditioning: {self.timestamp_cond}." @@ -831,8 +826,7 @@ def forward( out = e - self.network.apply( variables={"params": params}, image=e, - begin=b_arg, - end=e_arg, + timestamps=timestamps, deterministic=deterministic, ) From 50df9aa95c63877619f1f1b996bf552a2bf9d1f5 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 16:16:32 -0500 Subject: [PATCH 61/67] hotfix: Fixed error in logging histograms Signed-off-by: Juanwu Lu --- src/core/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/core/train.py b/src/core/train.py index 8ecdd3a..e8cd287 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -158,11 +158,12 @@ def run( if outputs.histograms is not None: writer.write_histograms( step=step, - histograms={ + arrays={ f"eval/{k}": v for k, v in outputs.histograms.items() }, ) + writer.flush() batch = _shard(batch) with jax.profiler.StepTraceAnnotation( @@ -203,11 +204,12 @@ def run( if outputs.histograms is not None: writer.write_histograms( step=step, - histograms={ + arrays={ f"train/{k}": v for k, v in outputs.histograms.items() }, ) + writer.flush() step += 1 # checkpointing @@ -240,6 +242,7 @@ def run( step=step, scalars=scalar_output, ) + writer.flush() except Exception as e: logging.rank_zero_error( @@ -250,6 +253,7 @@ def run( _status = 1 finally: state = jax_utils.unreplicate(state) + writer.close() logging.rank_zero_info( "Training finished. Final step: %d. Exit with code %d.", state.step, From ef2bcacd24a3f15c047a673920bc94dbbb9e6345 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 16:17:12 -0500 Subject: [PATCH 62/67] feat: Updated implementation for U-Net model in MeanFlow Signed-off-by: Juanwu Lu --- src/projects/generative/config.py | 7 ++- src/projects/generative/meanflow.py | 83 ++++++++++----------------- src/projects/generative/model/unet.py | 56 +++++++++++++++--- 3 files changed, 84 insertions(+), 62 deletions(-) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index c3c25b8..75a1733 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -1,4 +1,5 @@ import functools +import math import fiddle as fdl import optax @@ -40,10 +41,10 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: meanflow.MeanFlowUNetModel, in_channels=3, image_size=32, - latent_channels=128, - num_classes=10, - use_cfg_embedding=False, + features=128, dropout_rate=0.2, + epsilon=1e-6, + skip_scale=math.sqrt(0.5), timestamp_cond="t_and_t_minus_r", timestamp_sampler="logit-normal", timestamp_sampler_kwargs=dict(mean=-2.0, stddev=2.0), diff --git a/src/projects/generative/meanflow.py b/src/projects/generative/meanflow.py index eec4a9a..8c17158 100644 --- a/src/projects/generative/meanflow.py +++ b/src/projects/generative/meanflow.py @@ -395,52 +395,43 @@ class MeanFlowUNetModule(nn.Module): """Generative model with a RefineNet backbone trained with `MeanFlow`. Attributes: - in_channels (int): Number of channels in the input images. - image_size (int): Height and width of the input images. - latent_channels (int): Number of channels in the latent feature maps. - num_classes (int): Number of conditioning classes. + features (int): Number of channels in the latent feature maps. + num_groups (int): Number of groups for `GroupNorm` layers. + dropout_rate (float): Dropout rate for the attention blocks. + epsilon (float): Small constant for numerical stability in `GroupNorm`. + skip_scale (float): Scaling factor for skip connections. + deterministic (Optional[bool]): Whether to run deterministically. dtype (dtype): The dtype of the computation (default: float32). param_dtype (dtype): The dtype of the parameters (default: float32). """ - in_channels: int - """int: Number of channels in the input images.""" - image_size: int - """int: Height and width of the input images.""" - latent_channels: int - """int: Number of channels in the latent feature maps.""" - num_classes: int - """int: Number of conditioning classes.""" - use_cfg_embedding: bool = False - """bool: Whether to use classifier-free guidance (CFG) embedding.""" - deterministic: typing.Optional[bool] = None - """Optional[bool]: Whether to run deterministically.""" + features: int + num_groups: int = 32 dropout_rate: float = 0.0 - """float: Dropout rate for the classifier-free guidance.""" + epsilon: float = 1e-5 + skip_scale: float = 1.0 + deterministic: typing.Optional[bool] = None dtype: typing.Any = None - """typing.Any: The dtype of the computation.""" param_dtype: typing.Any = None - """typing.Any: The dtype of the parameters.""" precision: typing.Any = None - """typing.Any: The precision of the computation.""" def setup(self) -> None: r"""Instantiate a `MeanFlowUNetModel` module.""" # backbone U-Net model self.backbone = unet.ScoreNet( - features=self.latent_channels, + features=self.features, + num_groups=self.num_groups, + epsilon=self.epsilon, dropout_rate=self.dropout_rate, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, ) # conditional embeddings - self.time_embed = SinusoidalEmbed( - self.latent_channels * 2, - endpoint=True, - ) + self.time_embed = SinusoidalEmbed(self.features * 2, endpoint=True) self.cond_in = nn.Dense( - features=self.latent_channels, + features=self.features * 4, kernel_init=jax.nn.initializers.variance_scaling( scale=1.0, mode="fan_avg", @@ -452,7 +443,7 @@ def setup(self) -> None: name="cond_fc_1", ) self.cond_out = nn.Dense( - features=self.latent_channels, + features=self.features * 4, kernel_init=jax.nn.initializers.variance_scaling( scale=1.0, mode="fan_avg", @@ -481,16 +472,6 @@ def __call__( Returns: The predicted average velocity of shape `(*, H, W, C)`. """ - - # sanity check for the input arrays - batch_dims = image.shape[:-3] - dims = chex.Dimensions( - H=self.image_size, - W=self.image_size, - C=self.in_channels, - ) - chex.assert_shape(image, (*batch_dims, *dims["HWC"])) - m_deterministic = nn.merge_param( "deterministic", self.deterministic, @@ -499,8 +480,7 @@ def __call__( emb = [self.time_embed(time) for time in timestamps] cond = jnp.concatenate(emb, axis=-1) - cond = jax.nn.silu(self.cond_in(cond)) - cond = jax.nn.silu(self.cond_out(cond)) + cond = self.cond_out(jax.nn.silu(self.cond_in(cond))) output = self.backbone( inputs=image, @@ -515,12 +495,12 @@ class MeanFlowUNetModel(_model.Model): r"""`MeanFlow` generative model with a U-Net backbone. Args: - in_channels (int): Number of channels in the input images. - image_size (int): Height and width of the (square) input images. - latent_channels (int): Number of channels in the latent feature maps. - num_classes (int): Number of conditioning classes. - use_cfg_embedding (bool): Whether to use classifier-free guidance (CFG). + in_channels (int): Number of input image channels. + image_size (int): Height and width of the input images. + features (int): Dimensionality of the latent feature map. dropout_rate (float): Dropout rate for the classifier-free guidance. + epsilon (float): Small constant for numerical stability in `GroupNorm`. + skip_scale (float): Scaling factor for skip connections. dtype (dtype): The dtype of the computation (default: float32). param_dtype (dtype): The dtype of the parameters (default: float32). timestamp_cond (Literal): The type of timestamp conditioning. @@ -539,10 +519,10 @@ def __init__( self, in_channels: int, image_size: int, - latent_channels: int, - num_classes: int, - use_cfg_embedding: bool, + features: int, dropout_rate: float, + epsilon: float, + skip_scale: float, dtype: typing.Any = None, param_dtype: typing.Any = None, precision: typing.Any = None, @@ -563,18 +543,17 @@ def __init__( """Initializes the `MeanFlow` model.""" self.in_channels = in_channels self.image_size = image_size + self.features = features self.timestamp_cond = timestamp_cond self.timestamp_sampler = timestamp_sampler self.timestamp_sampler_kwargs = timestamp_sampler_kwargs self.timestamp_overlap_rate = timestamp_overlap_rate self.adaptive_weight_power = adaptive_weight_power self._network = MeanFlowUNetModule( - in_channels=in_channels, - image_size=image_size, - latent_channels=latent_channels, - num_classes=num_classes, - use_cfg_embedding=use_cfg_embedding, + features=features, + skip_scale=skip_scale, dropout_rate=dropout_rate, + epsilon=epsilon, name="unet", dtype=dtype, param_dtype=param_dtype, diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index 6f24bb0..7bd0a27 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -18,6 +18,8 @@ class ResNetBlock(nn.Module): deterministic (bool, optional): If true, the model is run in deterministic mode (e.g., no dropout). Defaults to `None`. dropout_rate (float, optional): Dropout rate. Default is :math:`0`. + skip_scale (float, optional): Scaling factor for the residual + connection output. Default is :math:`1.0`. dtype (Any, optional): The dtype of the computation. param_dtype (Any, optional): The dtype of the parameters. precision (Any, optional): Numerical precision of the computation. @@ -28,6 +30,7 @@ class ResNetBlock(nn.Module): epsilon: float = 1e-5 deterministic: typing.Optional[bool] = None dropout_rate: float = 0.0 + skip_scale: float = 1.0 dtype: typing.Any = None param_dtype: typing.Any = None precision: typing.Any = None @@ -136,11 +139,11 @@ def __call__( C=inputs.shape[-1], ) - out = self.conv_1(nn.silu(self.norm_1(inputs))) + out = self.conv_1(jax.nn.silu(self.norm_1(inputs))) if cond is not None: - out = out + self.cond_linear(cond)[..., None, None, :] - out = nn.silu(self.norm_2(out)) + out = out + self.cond_linear(jax.nn.silu(cond))[..., None, None, :] + out = jax.nn.silu(self.norm_2(out)) out = self.dropout(out, deterministic=m_deterministic) out = self.conv_2(out) @@ -149,6 +152,7 @@ def __call__( else: shortcut = inputs out = out + shortcut + out = out * self.skip_scale chex.assert_shape(out, (*batch_dims, *dims["HW"], self.features)) return out @@ -277,9 +281,21 @@ class AttnBlock(nn.Module): r"""Self-attention block with group normalization in U-Net models. Args: + num_heads (int): Number of attention heads. + num_groups (int): Number of groups for `GroupNorm`. + epsilon (float, optional): Small float added to variance to avoid + dividing by zero in `GroupNorm`. Default is :math:`1e-5`. + skip_scale (float, optional): Scaling factor for the residual + connection output. Default is :math:`1.0`. + dtype (Any, optional): The dtype of the computation. + param_dtype (Any, optional): The dtype of the parameters. + precision (Any, optional): Numerical precision of the computation. """ - num_heads: int = 1 + num_heads: int + num_groups: int + epsilon: float = 1e-5 + skip_scale: float = 1.0 dtype: typing.Any = None param_dtype: typing.Any = None precision: typing.Any = None @@ -296,8 +312,8 @@ def __call__(self, inputs: jax.Array) -> jax.Array: """ norm_in = nn.GroupNorm( - num_groups=32, - epsilon=1e-5, + num_groups=self.num_groups, + epsilon=self.epsilon, dtype=self.dtype, param_dtype=self.param_dtype, name="norm", @@ -409,6 +425,15 @@ def __call__(self, inputs: jax.Array) -> jax.Array: name="v_proj", ) value = v_proj(out) + out = nn.dot_product_attention( + query, + key, + value, + broadcast_dropout=False, + dropout_rate=0.0, + dtype=self.dtype, + precision=self.precision, + ) out_proj = nn.DenseGeneral( features=inputs.shape[-1], kernel_init=jax.nn.initializers.zeros, @@ -421,6 +446,7 @@ def __call__(self, inputs: jax.Array) -> jax.Array: chex.assert_equal_shape([out, inputs]) out = out + inputs + out = out * self.skip_scale return out @@ -444,6 +470,8 @@ class ScoreNet(nn.Module): dropout_rate (float, optional): Dropout rate. Default is :math:`0.0`. epsilon (float, optional): Small float added to variance to avoid dividing by zero in `GroupNorm`. Default is :math:`1e-5`. + skip_scale (float, optional): Scaling factor for the residual + connection outputs. Default is :math:`1.0`. deterministic (bool, optional): If true, the model is run in deterministic mode (e.g., no dropout). Defaults to `None`. dtype (Any, optional): The dtype of the computation. @@ -458,6 +486,7 @@ class ScoreNet(nn.Module): attn_resolutions: typing.Sequence[int] = (16,) dropout_rate: float = 0.0 epsilon: float = 1e-5 + skip_scale: float = 1.0 deterministic: typing.Optional[bool] = None dtype: typing.Any = None param_dtype: typing.Any = None @@ -522,6 +551,7 @@ def __call__( num_groups=self.num_groups, dropout_rate=self.dropout_rate, epsilon=self.epsilon, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, @@ -535,6 +565,9 @@ def __call__( if out.shape[-3] in self.attn_resolutions: block = AttnBlock( num_heads=1, + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, @@ -557,6 +590,7 @@ def __call__( features=out.shape[-1], num_groups=self.num_groups, epsilon=self.epsilon, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, @@ -565,6 +599,9 @@ def __call__( out = block(out, cond=cond, deterministic=m_deterministic) block = AttnBlock( num_heads=1, + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, @@ -575,6 +612,7 @@ def __call__( features=out.shape[-1], num_groups=self.num_groups, epsilon=self.epsilon, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, @@ -593,6 +631,7 @@ def __call__( dropout_rate=self.dropout_rate, num_groups=self.num_groups, epsilon=self.epsilon, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, @@ -606,6 +645,9 @@ def __call__( if out.shape[-3] in self.attn_resolutions: block = AttnBlock( num_heads=1, + num_groups=self.num_groups, + epsilon=self.epsilon, + skip_scale=self.skip_scale, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, @@ -630,7 +672,7 @@ def __call__( param_dtype=self.param_dtype, name="norm_out", ) - out = nn.silu(norm_out(out)) + out = jax.nn.silu(norm_out(out)) conv_out = nn.Conv( features=dims.C, # type: ignore kernel_size=(3, 3), From 2b9761426c3955c390bcd87e5933b153e53596d7 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 16:56:06 -0500 Subject: [PATCH 63/67] hotfix: Increased data loading batch size to 1024 for CIFAR-10 Signed-off-by: Juanwu Lu --- src/projects/generative/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projects/generative/config.py b/src/projects/generative/config.py index 75a1733..3da65e4 100644 --- a/src/projects/generative/config.py +++ b/src/projects/generative/config.py @@ -32,7 +32,7 @@ def meanflow_unet_cifar_10() -> _config.ExperimentConfig: ), ), ), - batch_size=128, + batch_size=1024, num_workers=2, deterministic=True, drop_remainder=True, From 063551773a0ad1bb6980c3d30bf2351d72ecd729 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 17:04:53 -0500 Subject: [PATCH 64/67] feat: Updated implementation for evaluation step Signed-off-by: Juanwu Lu --- src/core/evaluate.py | 50 ++++++++++++++++++++++----------- src/projects/generative/main.py | 6 ++-- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/core/evaluate.py b/src/core/evaluate.py index 44ebf22..a247dd2 100644 --- a/src/core/evaluate.py +++ b/src/core/evaluate.py @@ -1,10 +1,12 @@ import collections import functools +import traceback import typing from clu import metric_writers from clu import periodic_actions import jax +from jax import numpy as jnp import jaxtyping from src.core import datamodule as _datamodule @@ -13,8 +15,8 @@ def run( - model: _model.Model, datamodule: _datamodule.DataModule, + evaluation_step: typing.Callable[..., _model.StepOutputs], params: jaxtyping.PyTree, writer: metric_writers.MetricWriter, work_dir: str, @@ -24,8 +26,8 @@ def run( """Runs evaluation loop with the given model and datamodule. Args: - model (Model): The model to evaluate. datamodule (DataModule): The datamodule providing the evaluation data. + evaluation_step (Callable): The pmapped evaluation step function. params (PyTree): The model parameters to use for evaluation. writer (MetricWriter): The metric writer for logging evaluation metrics. work_dir (str): The working directory for saving outputs. @@ -36,11 +38,11 @@ def run( Integer status code (0 for success). """ _status = 0 - logging.rank_zero_debug(f"running {model.__class__.__name__} eval...") - eval_rng = jax.random.fold_in(rng, jax.process_index()) - p_evaluation_step = functools.partial(model.evaluation_step, rng=eval_rng) + logging.rank_zero_info("Compiling evaluation step...") + p_evaluation_step = functools.partial(evaluation_step, rng=rng) p_evaluation_step = jax.pmap(p_evaluation_step, axis_name="batch") + logging.rank_zero_info("Compiling evaluation step...DONE!") hooks = [] if jax.process_index() == 0: @@ -69,7 +71,7 @@ def run( batch, ) with jax.profiler.StepTraceAnnotation( - name="train", + name="evaluation", step_num=step, ): outputs = p_evaluation_step( @@ -85,38 +87,52 @@ def run( # logging at the end of batch if outputs.scalars is not None: - _scalars = {} - for k, v in outputs.scalars.items(): - eval_metrics[k].append(jax.device_get(v).mean()) - _scalars[ - f"eval/{k.replace('_', ' ')}" - ] = jax.device_get(v).mean() writer.write_scalars( - step=step + 1, - scalars=_scalars, + step=step, + scalars={ + f"eval/{k}_step": sum(v) / len(v) + for k, v in outputs.scalars.items() + }, ) if outputs.images is not None: writer.write_images( - step=step + 1, - images=outputs.images, + step=step, + images={ + f"eval/{k}_step": v + for k, v in outputs.images.items() + }, ) + if outputs.histograms is not None: + writer.write_histograms( + step=step, + arrays={ + f"eval/{k}_step": v + for k, v in outputs.histograms.items() + }, + ) + writer.flush() # logging at the end of evaluation logging.rank_zero_info("Evaluation done.") scalar_output = { - f"eval/{k.replace('_', ' ')}": sum(v) / len(v) + f"eval/{k.replace('_', ' ')}_epoch": sum(v) / len(v) for k, v in eval_metrics.items() } writer.write_scalars( step=step, scalars=scalar_output, ) + writer.flush() + except Exception as e: logging.rank_zero_error( "Exception occurred during evaluation: %s", e ) + error_trace = traceback.format_exc() + logging.rank_zero_error("Stack trace:\n%s", error_trace) _status = 1 finally: + writer.close() logging.rank_zero_info( "Evaluation done. Exit with code %d.", _status, diff --git a/src/projects/generative/main.py b/src/projects/generative/main.py index 232dc32..7a3f740 100644 --- a/src/projects/generative/main.py +++ b/src/projects/generative/main.py @@ -206,9 +206,9 @@ def main(_: typing.List[str]) -> int: logging.rank_zero_error("Resuming from checkpoint not implemented.") return 1 + p_training_step = functools.partial(training_step, model=model) + p_evaluation_step = functools.partial(evaluation_step, model=model) if exp_config.mode == "train": - p_training_step = functools.partial(training_step, model=model) - p_evaluation_step = functools.partial(evaluation_step, model=model) _train.run( state=state, datamodule=datamodule, @@ -225,8 +225,8 @@ def main(_: typing.List[str]) -> int: ) elif exp_config.mode == "evaluate": _evaluate.run( - model=model, datamodule=datamodule, + evaluation_step=p_evaluation_step, params=params, writer=writer, work_dir=log_dir, From c4f14b95c0107550d15f611b8e6f5bdb9bd2cfa6 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 19:14:15 -0500 Subject: [PATCH 65/67] hotfix: Fixed typo Signed-off-by: Juanwu Lu --- src/core/train.py | 1 - src/projects/generative/model/unet.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/core/train.py b/src/core/train.py index e8cd287..4a3ca6b 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -61,7 +61,6 @@ def run( training_step (Callable): The training step function. evaluation_step (Callable): The evaluation step function. num_train_steps (int): Number of training steps. - checkpoint_manager (Checkpoint): The checkpoint manager. writer (MetricWriter): The metric writer for logging. work_dir (str): The working directory for saving checkpoints and logs. rng (Any): The random number generator. diff --git a/src/projects/generative/model/unet.py b/src/projects/generative/model/unet.py index 7bd0a27..fbce45f 100644 --- a/src/projects/generative/model/unet.py +++ b/src/projects/generative/model/unet.py @@ -10,7 +10,7 @@ class ResNetBlock(nn.Module): r"""A residual downsampling block with two convolutional layers. Args: - features (int): Dimensionality of the latent feaatures. + features (int): Dimensionality of the latent features. num_groups (int, optional): Number of groups for `GroupNorm`. Default is :math:`32`. epsilon (float, optional): Small float added to variance to avoid From abcf902988382a07e7de9038b3a5bd295b01a3d6 Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 19:20:38 -0500 Subject: [PATCH 66/67] hotfix: Fixed infinite outer loop in training Signed-off-by: Juanwu Lu --- src/core/evaluate.py | 1 - src/core/train.py | 4 ++++ src/utilities/visualization.py | 2 -- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/core/evaluate.py b/src/core/evaluate.py index a247dd2..9e7b4ff 100644 --- a/src/core/evaluate.py +++ b/src/core/evaluate.py @@ -6,7 +6,6 @@ from clu import metric_writers from clu import periodic_actions import jax -from jax import numpy as jnp import jaxtyping from src.core import datamodule as _datamodule diff --git a/src/core/train.py b/src/core/train.py index 4a3ca6b..041cf12 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -243,6 +243,10 @@ def run( ) writer.flush() + # break outer loop if reach max steps + if step >= num_train_steps: + break + except Exception as e: logging.rank_zero_error( "Exception occurred during training: %s", e diff --git a/src/utilities/visualization.py b/src/utilities/visualization.py index ffa8e14..15a540c 100644 --- a/src/utilities/visualization.py +++ b/src/utilities/visualization.py @@ -1,5 +1,3 @@ -import typing - import jax from jax import numpy as jnp From 29296d1162bc4ea64359b54d4d595b1ddd598ffa Mon Sep 17 00:00:00 2001 From: Juanwu Lu Date: Mon, 1 Dec 2025 19:23:38 -0500 Subject: [PATCH 67/67] hotfix: Fixed conflict in naming of `batch` Signed-off-by: Juanwu Lu --- src/core/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/core/train.py b/src/core/train.py index 041cf12..68fe193 100644 --- a/src/core/train.py +++ b/src/core/train.py @@ -111,7 +111,7 @@ def run( try: train_metrics = collections.defaultdict(list) while True: - for batch in datamodule.train_dataloader(): + for train_batch in datamodule.train_dataloader(): # evaluation and sanity check running if ( step % eval_every_n_steps == 0 @@ -120,11 +120,11 @@ def run( logging.rank_zero_info("Running evaluation...") eval_metrics = collections.defaultdict(list) outputs = None - for batch in datamodule.eval_dataloader(): - batch = _shard(batch) + for eval_batch in datamodule.eval_dataloader(): + eval_batch = _shard(eval_batch) outputs = p_evaluation_step( params=state.params, - batch=batch, + batch=eval_batch, ) if not isinstance(outputs, _model.StepOutputs): raise RuntimeError( @@ -164,14 +164,14 @@ def run( ) writer.flush() - batch = _shard(batch) + train_batch = _shard(train_batch) with jax.profiler.StepTraceAnnotation( name="train", step_num=step, ): state, outputs = p_training_step( state=state, - batch=batch, + batch=train_batch, ) if not isinstance(outputs, _model.StepOutputs): raise RuntimeError(