Merged
Changes from 64 commits
67 commits
ccfa898
hotfix: Fixed typo in core module
juanwulu Nov 19, 2025
0810627
feat: Updated implementation of RefineNet
juanwulu Nov 19, 2025
b60e746
hotfix: Fixed typo in core module
juanwulu Nov 19, 2025
95be2cb
feat: Updated implementation of HuggingFace datamodule
juanwulu Nov 19, 2025
2e923ce
hotfix: Fixed build target for utility module
juanwulu Nov 19, 2025
7ca6e1e
hotfix: Fixed typo in build targets
juanwulu Nov 19, 2025
35ad184
feat: Updated implementation of MeanFlow
juanwulu Nov 19, 2025
8ca3ed3
feat: Added main entrypoint for training and evaluation of generative…
juanwulu Nov 19, 2025
0ab6cfc
feat: Updated configuration for training U-Net meanflow on CIFAR-10
juanwulu Nov 19, 2025
f36fbeb
hotfix: Fixed issue with version of `chex`
juanwulu Nov 19, 2025
e9816c8
hotfix: Updated main train logic
juanwulu Nov 19, 2025
c2fb1ba
hotfix: Improve the log frequency for meanflow on CIFAR-10
juanwulu Nov 19, 2025
a2f8fbe
feat: Updated the main logic for training step in MeanFlow
juanwulu Nov 19, 2025
91e96b3
feat: Updated implementation of MeanFlow
juanwulu Nov 19, 2025
dd4df9b
feat: Updated implementation of MeanFlow
juanwulu Nov 19, 2025
d585b66
feat: Added checkpoint frequency attribute to trainer config
juanwulu Nov 19, 2025
41223e0
feat: Updated implementation of train logic
juanwulu Nov 19, 2025
4357a02
feat: Updated the main logic for training step in MeanFlow
juanwulu Nov 19, 2025
b7d1967
feat: Updated the model protocol
juanwulu Nov 19, 2025
60db887
feat: Updated the training logic
juanwulu Nov 19, 2025
c435f02
feat: Implemented the new model protocol for MeanFlow
juanwulu Nov 19, 2025
3b12acb
feat: Implemented the new main entrypoint with train logic
juanwulu Nov 19, 2025
cbfaa3c
feat: Updated implementation for huggingface dataset
juanwulu Nov 19, 2025
d368e52
hotfix: Fixed error in huggingface datamodule
juanwulu Nov 19, 2025
d91518f
hotfix: Updated checkpoint frequency
juanwulu Nov 20, 2025
07ab0aa
feat: Added visualization utility to create a grid of images
juanwulu Nov 21, 2025
f462f55
hotfix: Fixed wrong implementation of t\neq{r} in meanflow
juanwulu Nov 21, 2025
3649da1
feat: symmetric mean flow
lan-qing Nov 21, 2025
4bad082
hotfix: Switch back to original meanflow loss
juanwulu Nov 21, 2025
63b4cdf
feat: Added dependencies for running on MPS framework
juanwulu Nov 28, 2025
222a884
feat: Added MPS dependencies to PIP hubs
juanwulu Nov 28, 2025
df11ccb
feat: Added implementation for downsampling residual block in U-Net
juanwulu Nov 28, 2025
e75bc79
hotfix: Rename `DownResNetBlock` to `ResNetBlock`
juanwulu Nov 28, 2025
9fec7d4
feat: Added implementation for downsampling block in U-Net
juanwulu Nov 28, 2025
147427d
feat: Added implementation for upsampling block in U-Net
juanwulu Nov 28, 2025
6ce9152
feat: Added full implementation of U-Net for score-based generative m…
juanwulu Nov 28, 2025
e3e0c4f
feat: Added implementation for scaled dot-product attention block
juanwulu Nov 28, 2025
19964ba
feat: Integrate attention block to score U-Net architecture
juanwulu Nov 28, 2025
76781b3
hotfix: Adds missing attention block in upsampling path of U-Net
juanwulu Nov 28, 2025
6716140
feat: Integrates score-based U-Net for meanflow experiment on CIFAR-10
juanwulu Nov 28, 2025
70bb67c
hotfix: Fixes issue of `dropout_rate` in U-Net for meanflow
juanwulu Nov 28, 2025
27bc1ec
hotfix: Fixes issue of missing dropout rng in U-Net for meanflow
juanwulu Nov 28, 2025
d0e352a
hotfix: Fixes issue of missing dropout rng in U-Net for meanflow
juanwulu Nov 28, 2025
9d45b9c
hotfix: Updated configurations for training U-Net on CIFAR-10
juanwulu Nov 29, 2025
03655ed
hotfix: Fixed implementation of logit-normal timestamp sampler
juanwulu Nov 29, 2025
f307e8c
hotfix: Fixed implementation of logit-normal timestamp sampler
juanwulu Nov 29, 2025
90b5d4e
hotfix: Fixed implementation for JAX in MacOS
juanwulu Nov 29, 2025
12ee102
feat: Implements sinusoidal positional encoding for U-Net
juanwulu Nov 29, 2025
6f80f07
hotfix: Updated step output to contain model output array
juanwulu Nov 30, 2025
a42758a
feat: Updated grid visualization function to use jax array
juanwulu Nov 30, 2025
cef9f5a
feat: Implements the evaluation step for meanflow with visualization
juanwulu Nov 30, 2025
de1f95c
feat: Moved evaluation to before the training inner loop
juanwulu Nov 30, 2025
74e69ce
feat: Added random left-right flip in training loop
juanwulu Nov 30, 2025
000f1a6
feat: Updated implementation for MeanFlow network and remove label co…
juanwulu Nov 30, 2025
d8e4b76
hotfix: Fixed error raised by wrong shape checking
juanwulu Nov 30, 2025
a50c12e
feat: Fixed training collapse by adding fc layers for timestamp condi…
juanwulu Dec 1, 2025
1c4f404
hotfix: Fixed wrong implementation of timestamp conditioning in forwa…
juanwulu Dec 1, 2025
551c366
feat: Added histogram attribute to the model step output
juanwulu Dec 1, 2025
a2060a7
feat: Added histogram logging for training and evaluation
juanwulu Dec 1, 2025
06ee4c5
feat: Updated implementation for meanflow model to take arbitrary tup…
juanwulu Dec 1, 2025
50df9aa
hotfix: Fixed error in logging histograms
juanwulu Dec 1, 2025
ef2bcac
feat: Updated implementation for U-Net model in MeanFlow
juanwulu Dec 1, 2025
2b97614
hotfix: Increased data loading batch size to 1024 for CIFAR-10
juanwulu Dec 1, 2025
0635517
feat: Updated implementation for evaluation step
juanwulu Dec 1, 2025
c4f14b9
hotfix: Fixed typo
juanwulu Dec 2, 2025
abcf902
hotfix: Fixed infinite outer loop in training
juanwulu Dec 2, 2025
29296d1
hotfix: Fixed conflict in naming of `batch`
juanwulu Dec 2, 2025
2 changes: 2 additions & 0 deletions .gitignore
@@ -212,3 +212,5 @@ cython_debug/
/data/
/logs/
requirements_*.txt

.specstory/
6 changes: 6 additions & 0 deletions MODULE.bazel
@@ -61,9 +61,15 @@ pip.parse(
python_version = "3.10",
requirements_lock = "//third_party:requirements_3_10_tpu_lock.txt",
)
pip.parse(
hub_name = "ml_infra_mps_3_10",
python_version = "3.10",
requirements_lock = "//third_party:requirements_3_10_mps_lock.txt",
)
use_repo(
pip,
ml_infra_cpu_3_10 = "ml_infra_cpu_3_10",
ml_infra_cuda_3_10 = "ml_infra_cuda_3_10",
ml_infra_mps_3_10 = "ml_infra_mps_3_10",
ml_infra_tpu_3_10 = "ml_infra_tpu_3_10",
)
8 changes: 4 additions & 4 deletions src/core/BUILD
@@ -8,7 +8,7 @@ ml_py_library(
deps = [
"fiddle",
"optax",
":data",
":datamodule",
":model",
],
)
@@ -24,7 +24,7 @@ ml_py_library(
deps = [
"clu",
"jax",
":data",
":datamodule",
":model",
"//src/utilities:logging",
],
@@ -36,8 +36,8 @@ ml_py_library(
deps = [
"chex",
"flax",
"jax",
"jaxtyping",
":train_state",
],
)

@@ -60,7 +60,7 @@ ml_py_library(
"flax",
"jax",
"jaxtyping",
":data",
":datamodule",
":model",
":train_state",
"//src/utilities:logging",
7 changes: 5 additions & 2 deletions src/core/config.py
@@ -4,7 +4,7 @@
import fiddle as fdl
import optax

from src.core import data as _data
from src.core import datamodule as _datamodule
from src.core import model as _model


@@ -20,7 +20,7 @@ class DataConfig:
drop_remainder (bool): Whether to drop the last incomplete batch.
"""

module: fdl.Partial[_data.DataModule]
module: fdl.Partial[_datamodule.DataModule]
batch_size: int = 32
num_workers: int = 4
deterministic: bool = True
@@ -45,6 +45,8 @@ class TrainerConfig:

Attributes:
num_train_steps (int): Total number of training steps.
checkpoint_every_n_steps (Optional[int]): Frequency of checkpointing.
If `None`, defaults to `eval_every_n_steps`.
log_every_n_steps (int): Frequency of logging training metrics.
eval_every_n_steps (int): Frequency of evaluation during training.
checkpoint_dir (Optional[str]): Directory of checkpoint to resume from.
@@ -53,6 +55,7 @@
"""

num_train_steps: int = 10_000
checkpoint_every_n_steps: typing.Optional[int] = None
log_every_n_steps: int = 50
eval_every_n_steps: int = 1_000
checkpoint_dir: typing.Optional[str] = None
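The new `checkpoint_every_n_steps` attribute is documented to fall back to `eval_every_n_steps` when left as `None`. A minimal pure-Python sketch of that fallback; the `checkpoint_interval` helper is hypothetical and only illustrates the documented default, it is not part of this PR:

```python
import dataclasses
import typing


@dataclasses.dataclass
class TrainerConfig:
    """Trimmed-down stand-in for the PR's TrainerConfig."""

    num_train_steps: int = 10_000
    checkpoint_every_n_steps: typing.Optional[int] = None
    log_every_n_steps: int = 50
    eval_every_n_steps: int = 1_000


def checkpoint_interval(cfg: TrainerConfig) -> int:
    # Fall back to the evaluation frequency when no explicit
    # checkpoint frequency is configured, as the docstring describes.
    if cfg.checkpoint_every_n_steps is None:
        return cfg.eval_every_n_steps
    return cfg.checkpoint_every_n_steps


print(checkpoint_interval(TrainerConfig()))  # → 1000 (falls back to eval)
print(checkpoint_interval(TrainerConfig(checkpoint_every_n_steps=500)))  # → 500
```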
54 changes: 35 additions & 19 deletions src/core/evaluate.py
@@ -1,20 +1,22 @@
import collections
import functools
import traceback
import typing

from clu import metric_writers
from clu import periodic_actions
import jax
from jax import numpy as jnp
import jaxtyping

from src.core import data as _data
from src.core import datamodule as _datamodule
from src.core import model as _model
from src.utilities import logging


def run(
model: _model.Model,
datamodule: _data.DataModule,
datamodule: _datamodule.DataModule,
evaluation_step: typing.Callable[..., _model.StepOutputs],
params: jaxtyping.PyTree,
writer: metric_writers.MetricWriter,
work_dir: str,
@@ -24,8 +26,8 @@ def run(
"""Runs evaluation loop with the given model and datamodule.

Args:
model (Model): The model to evaluate.
datamodule (DataModule): The datamodule providing the evaluation data.
evaluation_step (Callable): The pmapped evaluation step function.
params (PyTree): The model parameters to use for evaluation.
writer (MetricWriter): The metric writer for logging evaluation metrics.
work_dir (str): The working directory for saving outputs.
@@ -36,11 +38,11 @@ def run(
Integer status code (0 for success).
"""
_status = 0
logging.rank_zero_debug(f"running {model.__class__.__name__} eval...")

eval_rng = jax.random.fold_in(rng, jax.process_index())
p_evaluation_step = functools.partial(model.evaluation_step, rng=eval_rng)
logging.rank_zero_info("Compiling evaluation step...")
p_evaluation_step = functools.partial(evaluation_step, rng=rng)
p_evaluation_step = jax.pmap(p_evaluation_step, axis_name="batch")
logging.rank_zero_info("Compiling evaluation step...DONE!")

hooks = []
if jax.process_index() == 0:
@@ -69,7 +71,7 @@ def run(
batch,
)
with jax.profiler.StepTraceAnnotation(
name="train",
name="evaluation",
step_num=step,
):
outputs = p_evaluation_step(
@@ -85,38 +87,52 @@

# logging at the end of batch
if outputs.scalars is not None:
_scalars = {}
for k, v in outputs.scalars.items():
eval_metrics[k].append(jax.device_get(v).mean())
_scalars[
f"eval/{k.replace('_', ' ')}"
] = jax.device_get(v).mean()
writer.write_scalars(
step=step + 1,
scalars=_scalars,
step=step,
scalars={
f"eval/{k}_step": sum(v) / len(v)
for k, v in outputs.scalars.items()
},
)
if outputs.images is not None:
writer.write_images(
step=step + 1,
images=outputs.images,
step=step,
images={
f"eval/{k}_step": v
for k, v in outputs.images.items()
},
)
if outputs.histograms is not None:
writer.write_histograms(
step=step,
arrays={
f"eval/{k}_step": v
for k, v in outputs.histograms.items()
},
)
writer.flush()

# logging at the end of evaluation
logging.rank_zero_info("Evaluation done.")
scalar_output = {
f"eval/{k.replace('_', ' ')}": sum(v) / len(v)
f"eval/{k.replace('_', ' ')}_epoch": sum(v) / len(v)
for k, v in eval_metrics.items()
}
writer.write_scalars(
step=step,
scalars=scalar_output,
)
writer.flush()

except Exception as e:
logging.rank_zero_error(
"Exception occurred during evaluation: %s", e
)
error_trace = traceback.format_exc()
logging.rank_zero_error("Stack trace:\n%s", error_trace)
_status = 1
finally:
writer.close()
logging.rank_zero_info(
"Evaluation done. Exit with code %d.",
_status,
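The rewritten loop logs per-step scalars under `eval/<name>_step` keys while accumulating each metric, then averages the accumulated values under `eval/<name>_epoch` keys at the end of evaluation. A minimal sketch of that accumulate-then-average pattern, with hypothetical metric values standing in for the device arrays returned by the pmapped evaluation step:

```python
import collections

# Hypothetical per-step scalar outputs, standing in for the arrays
# returned by the pmapped evaluation step in the PR.
step_outputs = [
    {"loss": 0.9, "mse": 0.5},
    {"loss": 0.7, "mse": 0.3},
]

eval_metrics = collections.defaultdict(list)
for outputs in step_outputs:
    # During the loop, each step's scalars are both logged (as
    # "eval/<name>_step") and appended for the end-of-run summary.
    for k, v in outputs.items():
        eval_metrics[k].append(v)

# After the loop, the summary averages the accumulated values and
# switches to an "eval/<name>_epoch" key, as in the PR's diff.
scalar_output = {
    f"eval/{k.replace('_', ' ')}_epoch": sum(v) / len(v)
    for k, v in eval_metrics.items()
}
print(scalar_output)
```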
78 changes: 30 additions & 48 deletions src/core/model.py
@@ -2,23 +2,27 @@
import typing

import chex
from flax import struct
from flax.core import frozen_dict
import jax
import jaxtyping

from src.core import train_state as _train_state


@chex.dataclass
class StepOutputs:
"""A base container for outputs from a single step.

Attributes:
output (Optional[jax.Array]): The main output of the model.
scalars (Optional[Dict[str, Any]]): A dictionary of scalar metrics.
images (Optional[Dict[str, Any]]): A dictionary of image outputs.
histograms (Optional[Dict[str, Array]]): A dictionary of array to
plot as histograms.
"""

output: typing.Optional[jax.Array] = None
scalars: typing.Optional[typing.Dict[str, typing.Any]] = None
images: typing.Optional[typing.Dict[str, typing.Any]] = None
histograms: typing.Optional[typing.Dict[str, jax.Array]] = None


class Model(abc.ABC):
@@ -51,67 +55,45 @@ def init(
pass

@abc.abstractmethod
def training_step(
def compute_loss(
self,
*,
state: _train_state.TrainState,
batch: typing.Any,
rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]],
rngs: typing.Any,
deterministic: bool = False,
params: frozen_dict.FrozenDict,
**kwargs,
) -> typing.Tuple[struct.PyTreeNode, StepOutputs]:
r"""Performs a single training step.
) -> typing.Tuple[jax.Array, StepOutputs]:
"""Computes the loss given parameters and model inputs.

Args:
state (TrainState): The current training state.
batch (Any): A batch of data.
rngs (Union[Any, Dict[str, Any]]): Random generators.
**kwargs: Additional keyword arguments.
deterministic (bool): Whether to run the model in deterministic
mode (e.g., disable dropout). Default is `False`.
params (FrozenDict): The model parameters.
**kwargs: Keyword arguments consumed by the model.

Returns:
A tuple containing the updated state and step outputs.
A tuple containing the loss and the step outputs.
"""
pass
raise NotImplementedError

@abc.abstractmethod
def evaluation_step(
def forward(
self,
*,
params: jaxtyping.PyTree,
batch: typing.Any,
rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]],
rngs: typing.Any,
deterministic: bool = True,
params: frozen_dict.FrozenDict,
**kwargs,
) -> StepOutputs:
r"""Performs a single evaluation step.

Args:
params (PyTree): The model parameters.
batch (Any): A batch of data.
rngs (Union[Any, Dict[str, Any]]): Random generators.
**kwargs: Additional keyword arguments.

Returns:
The step outputs containing evaluation metrics.
"""
pass

@abc.abstractmethod
def predict_step(
self,
*,
params: jaxtyping.PyTree,
batch: typing.Any,
rngs: typing.Union[typing.Any, typing.Dict[str, typing.Any]],
**kwargs,
) -> typing.Any:
r"""Performs a single prediction step during inference.
"""Forward pass the model and returns the output tree structure.

Args:
params (PyTree): The model parameters.
batch (Any): A batch of data.
rngs (Union[Any, Dict[str, Any]]): Random generators.
**kwargs: Additional keyword arguments.
deterministic (bool): Whether to run the model in deterministic
mode (e.g., disable dropout). Default is `True`.
params (FrozenDict): The model parameters.
**kwargs: Keyword arguments consumed by the model.

Returns:
The model's predictions.
The model outputs.
"""
pass
raise NotImplementedError
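The protocol change replaces `training_step`/`evaluation_step`/`predict_step` with `compute_loss` and `forward`, so the training loop, rather than the model, owns the optimization step. A pure-Python sketch of a model implementing that shape; the `ToyModel` and its MSE loss are illustrative only, and the real protocol uses Flax/JAX types (`FrozenDict` params, `jax.Array` outputs):

```python
import dataclasses
import typing


@dataclasses.dataclass
class StepOutputs:
    """Mirror of the PR's StepOutputs container (a chex.dataclass upstream)."""

    output: typing.Optional[typing.Any] = None
    scalars: typing.Optional[typing.Dict[str, typing.Any]] = None
    images: typing.Optional[typing.Dict[str, typing.Any]] = None
    histograms: typing.Optional[typing.Dict[str, typing.Any]] = None


class ToyModel:
    """Hypothetical model following the new compute_loss/forward protocol."""

    def forward(self, *, batch, rngs=None, deterministic=True, params=None):
        # A trivial "prediction": scale each input by a learned weight.
        return [params["w"] * x for x in batch["inputs"]]

    def compute_loss(self, *, batch, rngs=None, deterministic=False, params=None):
        preds = self.forward(
            batch=batch, rngs=rngs, deterministic=deterministic, params=params
        )
        # Mean squared error against the targets; compute_loss returns the
        # (loss, StepOutputs) pair the new protocol specifies, so a training
        # loop can differentiate it and apply optimizer updates itself.
        loss = sum((p - t) ** 2 for p, t in zip(preds, batch["targets"]))
        loss /= len(preds)
        return loss, StepOutputs(output=preds, scalars={"loss": loss})


model = ToyModel()
batch = {"inputs": [1.0, 2.0], "targets": [2.0, 4.0]}
loss, outputs = model.compute_loss(batch=batch, params={"w": 2.0})
print(loss)  # → 0.0, since w=2.0 maps the inputs exactly onto the targets
```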