MeanFlow with U-Net on CIFAR-10 #8

Merged
juanwulu merged 67 commits into master from meanflow on Dec 2, 2025
@juanwulu juanwulu commented Dec 1, 2025

Description

This pull request refactors the core training and evaluation pipeline, with a focus on modularity, extensibility, and better logging. It replaces references to the old data module with datamodule, changes the training and evaluation step interfaces, improves logging and error handling, and adds a new pip dependency target for MPS. Key changes are as follows:

Refactor to use datamodule instead of data

  • All references to the old data module have been updated to use datamodule throughout the codebase, including imports, type annotations, and Bazel dependencies. This improves clarity and modularity in the data handling pipeline.

Training and evaluation loop improvements

  • The training and evaluation loops (train.py and evaluate.py) now accept explicit training_step and evaluation_step callables instead of relying on model methods, allowing for more flexible and decoupled step function definitions.
  • Enhanced logging and error handling have been added, including stack traces on exceptions and more informative status messages during compilation and evaluation.
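As a rough illustration of the decoupling described above, the following minimal sketch shows a loop that receives its step function as an argument and reports failures with a stack trace. The names run_training and sgd_step, the (state, batch) -> (state, metrics) signature, and the use of print in place of the project's logger are all assumptions for illustration; they are not taken from train.py.

```python
import traceback
from typing import Any, Callable, Dict, Iterable, Tuple

# Hypothetical step signature: (state, batch) -> (new_state, metrics).
TrainingStep = Callable[[Any, Any], Tuple[Any, Dict[str, float]]]


def run_training(
    state: Any,
    batches: Iterable[Any],
    training_step: TrainingStep,
) -> Tuple[Any, int]:
    """Drive a caller-supplied step function; return (state, exit status)."""
    status = 0
    try:
        for step, batch in enumerate(batches):
            # The loop knows nothing about the model, only the callable.
            state, metrics = training_step(state, batch)
            print(f"step={step} metrics={metrics}")
    except Exception as e:
        # Report the error and the full stack trace, then signal failure.
        print(f"Exception occurred during training: {e}")
        print("Stack trace:\n" + traceback.format_exc())
        status = 1
    return state, status


def sgd_step(w: float, batch: float) -> Tuple[float, Dict[str, float]]:
    """Toy step: one gradient-descent update on the loss (w - batch) ** 2."""
    grad = 2.0 * (w - batch)
    new_w = w - 0.1 * grad
    return new_w, {"loss": (w - batch) ** 2}


final_w, status = run_training(0.0, [1.0, 1.0, 1.0], sgd_step)
```

Because the loop only depends on the callable's signature, a different step function can be swapped in without touching the loop or the model class.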

Model interface refactor

  • The Model class interface has been refactored: the training_step, evaluation_step, and predict_step methods have been replaced with more generic compute_loss and forward methods. The StepOutputs container now also supports histogram outputs. This makes the model API more extensible and explicit.
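One way to picture the refactored interface is the sketch below. Only the method names compute_loss and forward and the StepOutputs fields scalars, images, and histograms come from this pull request; the argument lists, the field types, the ConstantModel class, and the standalone evaluation_step function are hypothetical and chosen only to keep the example small.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class StepOutputs:
    """Hypothetical container; the field types are assumptions."""

    scalars: Optional[Dict[str, Any]] = None
    images: Optional[Dict[str, Any]] = None
    histograms: Optional[Dict[str, Any]] = None


class Model(ABC):
    """Assumed shape of the interface: no step methods, only these two."""

    @abstractmethod
    def forward(self, batch: List[float]) -> List[float]:
        ...

    @abstractmethod
    def compute_loss(self, batch: List[float]) -> float:
        ...


class ConstantModel(Model):
    """Toy model that predicts the same value for every target."""

    def __init__(self, value: float) -> None:
        self.value = value

    def forward(self, batch: List[float]) -> List[float]:
        return [self.value for _ in batch]

    def compute_loss(self, batch: List[float]) -> float:
        preds = self.forward(batch)
        # Mean squared error between predictions and targets.
        return sum((p - y) ** 2 for p, y in zip(preds, batch)) / len(batch)


def evaluation_step(model: Model, batch: List[float]) -> StepOutputs:
    """Step logic lives outside the model and reports a histogram too."""
    residuals = [p - y for p, y in zip(model.forward(batch), batch)]
    return StepOutputs(
        scalars={"loss": model.compute_loss(batch)},
        histograms={"residual": residuals},
    )


outputs = evaluation_step(ConstantModel(1.0), [0.0, 2.0])
```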

Evaluation and logging enhancements

  • The evaluation loop now logs histograms, flushes the writer more frequently, and improves metric naming consistency (e.g., using _step and _epoch suffixes).
  • Improved error reporting with stack traces and ensured proper resource cleanup (writer closure) in evaluation.
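The naming and cleanup behavior described in these two points might look roughly like the sketch below. ListWriter is a hypothetical in-memory stand-in for the real metric writer, and the evaluate function here is a simplified assumption, not the code in evaluate.py; it only shows the _step and _epoch suffixes, the per-step flush, and closing the writer in a finally block.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple


class ListWriter:
    """Hypothetical stand-in for a metric writer; records calls in memory."""

    def __init__(self) -> None:
        self.records: List[Tuple[int, Dict[str, float]]] = []
        self.closed = False

    def write_scalars(self, step: int, scalars: Dict[str, float]) -> None:
        self.records.append((step, dict(scalars)))

    def flush(self) -> None:
        pass

    def close(self) -> None:
        self.closed = True


def evaluate(batches: Iterable[Any], writer: ListWriter) -> int:
    """Log per-step and per-epoch scalars; always close the writer."""
    status = 0
    history: Dict[str, List[float]] = defaultdict(list)
    last_step = 0
    try:
        for last_step, metrics in enumerate(batches):
            # Per-batch values get a "_step" suffix and are flushed at once.
            writer.write_scalars(
                step=last_step,
                scalars={f"eval/{k}_step": v for k, v in metrics.items()},
            )
            writer.flush()
            for k, v in metrics.items():
                history[k].append(v)
        if history:
            # Averages over the whole pass get an "_epoch" suffix.
            writer.write_scalars(
                step=last_step,
                scalars={
                    f"eval/{k}_epoch": sum(v) / len(v)
                    for k, v in history.items()
                },
            )
            writer.flush()
    except Exception:
        status = 1
    finally:
        # Runs on success and on failure, so the writer is never leaked.
        writer.close()
    return status


writer = ListWriter()
status = evaluate([{"loss": 1.0}, {"loss": 3.0}], writer)
```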

Bazel and dependency updates

  • Added a new ml_infra_mps_3_10 pip dependency target to MODULE.bazel for MPS (Apple Silicon) support, and updated Bazel build dependencies for clarity and correctness.

These changes collectively improve the maintainability, extensibility, and robustness of the core ML infrastructure.

juanwulu and others added 30 commits November 19, 2025 04:40
Signed-off-by: Juanwu Lu <juanwu@purdue.edu>
… models

Signed-off-by: Juanwu Lu <juanwu@purdue.edu>
…nditions

Signed-off-by: Juanwu <juanwu@purdue.edu>
…tions

Signed-off-by: Juanwu Lu <juanwu@purdue.edu>
…rd pass of meanflow model

Signed-off-by: Juanwu Lu <juanwu@purdue.edu>
…le of timestamps

Signed-off-by: Juanwu Lu <juanwu@purdue.edu>
@juanwulu juanwulu added this to the 2025.10 milestone Dec 1, 2025
@juanwulu juanwulu self-assigned this Dec 1, 2025
@juanwulu juanwulu added the enhacements New features or enhancements to existing ones. label Dec 1, 2025
@juanwulu juanwulu requested a review from Copilot December 1, 2025 22:06
Copilot AI left a comment

Pull request overview

This pull request introduces a comprehensive refactoring of the ML training infrastructure to support MeanFlow generative models with U-Net architecture on CIFAR-10. The changes improve modularity by decoupling training/evaluation loops from model implementations and add support for Apple Silicon (MPS) hardware.

Key Changes

  • Module refactoring: Renamed data module to datamodule throughout the codebase for clarity
  • Training interface improvements: Replaced model-bound training_step/evaluation_step methods with explicit callable functions, enabling more flexible training pipelines
  • New U-Net implementation: Added complete U-Net architecture for score-based generative modeling with attention mechanisms and skip connections
  • MPS support: Added infrastructure for Apple Silicon GPU acceleration through jax-metal dependency

Reviewed changes

Copilot reviewed 25 out of 26 changed files in this pull request and generated 11 comments.

| File | Description |
| --- | --- |
| MODULE.bazel | Added MPS pip dependency target for Apple Silicon support |
| third_party/requirements.in | Downgraded chex version; added jax-metal requirements |
| third_party/defs.bzl | Added MPS platform selection logic and conditional jax-metal dependency |
| third_party/BUILD | Added MPS config setting and requirements compilation target |
| src/core/model.py | Refactored interface: replaced step methods with compute_loss and forward |
| src/core/train.py | Accepts explicit training/evaluation step functions; improved error handling and logging |
| src/core/evaluate.py | Updated to use explicit evaluation step function; enhanced metric logging |
| src/core/config.py | Updated imports from data to datamodule; added checkpoint frequency config |
| src/data/huggingface.py | Major refactor: improved dataset caching, removed seed in favor of rng, simplified data loading |
| src/projects/generative/meanflow.py | Complete rewrite using U-Net backbone; improved timestamp conditioning and loss computation |
| src/projects/generative/model/unet.py | New U-Net implementation with ResNet blocks, attention, and up/downsampling |
| src/projects/generative/model/refinenet.py | Updated type hints and documentation; fixed deprecated JAX APIs |
| src/projects/generative/main.py | New training entry point with configuration support and proper device setup |
| src/utilities/visualization.py | New utility for creating image grids for visualization |

Comments suppressed due to low confidence (1)

src/core/evaluate.py:140

  • The condition at line 93 checks that outputs.scalars is not None, but the code at line 136 accesses outputs.scalars.items() without first checking whether outputs itself is None. If outputs is None (which could happen if no batches were processed), this raises an AttributeError. Consider guarding with if outputs is not None and outputs.scalars is not None: before accessing the scalars.
                if outputs.scalars is not None:
                    writer.write_scalars(
                        step=step,
                        scalars={
                            f"eval/{k}_step": sum(v) / len(v)
                            for k, v in outputs.scalars.items()
                        },
                    )
                if outputs.images is not None:
                    writer.write_images(
                        step=step,
                        images={
                            f"eval/{k}_step": v
                            for k, v in outputs.images.items()
                        },
                    )
                if outputs.histograms is not None:
                    writer.write_histograms(
                        step=step,
                        arrays={
                            f"eval/{k}_step": v
                            for k, v in outputs.histograms.items()
                        },
                    )
                writer.flush()

            # logging at the end of evaluation
            logging.rank_zero_info("Evaluation done.")
            scalar_output = {
                f"eval/{k.replace('_', ' ')}_epoch": sum(v) / len(v)
                for k, v in eval_metrics.items()
            }
            writer.write_scalars(
                step=step,
                scalars=scalar_output,
            )
            writer.flush()

        except Exception as e:
            logging.rank_zero_error(
                "Exception occurred during evaluation: %s", e
            )
            error_trace = traceback.format_exc()
            logging.rank_zero_error("Stack trace:\n%s", error_trace)
            _status = 1
        finally:
            writer.close()
            logging.rank_zero_info(
                "Evaluation done. Exit with code %d.",
                _status,
            )
    return _status
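
The guard suggested in this comment could be factored into a small helper, sketched below under assumptions. The step_scalars function and this reduced StepOutputs definition are hypothetical; the eval/{k}_step key format and the sum(v) / len(v) average mirror the excerpt above. The extra check for empty lists is an addition of this sketch that avoids a ZeroDivisionError.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StepOutputs:
    """Reduced, hypothetical version of the container with scalars only."""

    scalars: Optional[Dict[str, List[float]]] = None


def step_scalars(outputs: Optional[StepOutputs]) -> Dict[str, float]:
    """Average each scalar list, tolerating missing outputs."""
    # Check the container first, then the field, as the comment suggests.
    if outputs is None or outputs.scalars is None:
        return {}
    return {
        f"eval/{k}_step": sum(v) / len(v)
        for k, v in outputs.scalars.items()
        if len(v) > 0  # skip empty lists rather than divide by zero
    }


no_batches = step_scalars(None)
averaged = step_scalars(StepOutputs(scalars={"loss": [1.0, 3.0]}))
```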

Signed-off-by: Juanwu Lu <juanwu@purdue.edu>
@juanwulu juanwulu merged commit 2b93386 into master Dec 2, 2025

2 participants