Signed-off-by: Juanwu Lu <juanwu@purdue.edu>
Pull request overview
This pull request introduces a comprehensive refactoring of the ML training infrastructure to support MeanFlow generative models with U-Net architecture on CIFAR-10. The changes improve modularity by decoupling training/evaluation loops from model implementations and add support for Apple Silicon (MPS) hardware.
Key Changes
- Module refactoring: Renamed `data` to `datamodule` throughout the codebase for clarity
- Training interface improvements: Replaced model-bound `training_step`/`evaluation_step` methods with explicit callable functions, enabling more flexible training pipelines
- New U-Net implementation: Added a complete U-Net architecture for score-based generative modeling with attention mechanisms and skip connections
- MPS support: Added infrastructure for Apple Silicon GPU acceleration through the `jax-metal` dependency
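The decoupled training interface described above can be sketched as follows. This is a minimal illustration of the pattern, not the actual code: the names `make_training_step` and `train` are hypothetical, and only the idea of passing an explicit step callable into the loop comes from the PR.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical illustration: the loop receives an explicit `training_step`
# callable instead of calling a method bound to the model class.
Batch = Dict[str, Any]
TrainingStep = Callable[[Any, Batch], Tuple[float, Dict[str, float]]]


def make_training_step(model: Any) -> TrainingStep:
    """Close over the model so the training loop never touches it directly."""

    def training_step(params: Any, batch: Batch) -> Tuple[float, Dict[str, float]]:
        return model.compute_loss(params, batch)

    return training_step


def train(training_step: TrainingStep, params: Any, batches: List[Batch]) -> List[float]:
    """A minimal loop that depends only on the step callable."""
    losses = []
    for batch in batches:
        loss, _metrics = training_step(params, batch)
        losses.append(loss)
    return losses
```

Because the loop only sees a callable, the same `train` function can drive any model that exposes a loss computation, which is the flexibility the refactor is after.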
Reviewed changes
Copilot reviewed 25 out of 26 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| MODULE.bazel | Added MPS pip dependency target for Apple Silicon support |
| third_party/requirements.in | Downgraded chex version; added jax-metal requirements |
| third_party/defs.bzl | Added MPS platform selection logic and conditional jax-metal dependency |
| third_party/BUILD | Added MPS config setting and requirements compilation target |
| src/core/model.py | Refactored interface: replaced step methods with compute_loss and forward |
| src/core/train.py | Accepts explicit training/evaluation step functions; improved error handling and logging |
| src/core/evaluate.py | Updated to use explicit evaluation step function; enhanced metric logging |
| src/core/config.py | Updated imports from data to datamodule; added checkpoint frequency config |
| src/data/huggingface.py | Major refactor: improved dataset caching, removed seed in favor of rng, simplified data loading |
| src/projects/generative/meanflow.py | Complete rewrite using U-Net backbone; improved timestamp conditioning and loss computation |
| src/projects/generative/model/unet.py | New U-Net implementation with ResNet blocks, attention, and up/downsampling |
| src/projects/generative/model/refinenet.py | Updated type hints and documentation; fixed deprecated JAX APIs |
| src/projects/generative/main.py | New training entry point with configuration support and proper device setup |
| src/utilities/visualization.py | New utility for creating image grids for visualization |
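The new visualization utility presumably tiles a batch of sampled images into a single grid for logging. A plausible sketch, assuming NumPy arrays of shape `(N, H, W, C)` — the function name and padding behavior are assumptions, not the actual implementation:

```python
import numpy as np


def make_image_grid(images: np.ndarray, ncols: int) -> np.ndarray:
    """Tile a batch of (N, H, W, C) images into one grid image.

    Hypothetical helper illustrating what src/utilities/visualization.py
    might provide; pads with blank images when N is not a multiple of ncols.
    """
    n, h, w, c = images.shape
    nrows = -(-n // ncols)  # ceiling division
    pad = nrows * ncols - n
    if pad:
        images = np.concatenate(
            [images, np.zeros((pad, h, w, c), dtype=images.dtype)]
        )
    # (nrows, ncols, H, W, C) -> (nrows*H, ncols*W, C)
    grid = images.reshape(nrows, ncols, h, w, c)
    grid = grid.transpose(0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, c)
    return grid
```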
Comments suppressed due to low confidence (1)
src/core/evaluate.py:140
- The condition in line 93 checks for `outputs.scalars is not None`, but the code inside at line 136 accesses `outputs.scalars.items()` without checking if `outputs` itself is None. If `outputs` is None (which could happen if no batches were processed), this will raise an `AttributeError`. Consider adding a check `if outputs is not None and outputs.scalars is not None:` before accessing the scalars.
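A minimal sketch of the suggested guard, using a simplified stand-in for the real `StepOutputs` container (the real one lives in src/core/model.py and carries more fields):

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StepOutputs:
    # Simplified stand-in for the real container in src/core/model.py.
    scalars: Optional[Dict[str, List[float]]] = None


def collect_scalars(outputs: Optional[StepOutputs], step: int) -> Dict[str, float]:
    # Check `outputs` itself before touching `outputs.scalars`: if no
    # batches were processed, `outputs` may be None.
    if outputs is not None and outputs.scalars is not None:
        return {
            f"eval/{k}_step": sum(v) / len(v)
            for k, v in outputs.scalars.items()
        }
    return {}
```

With the combined check, the empty-evaluation case falls through to an empty dict instead of raising `AttributeError`.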
```python
        if outputs.scalars is not None:
            writer.write_scalars(
                step=step,
                scalars={
                    f"eval/{k}_step": sum(v) / len(v)
                    for k, v in outputs.scalars.items()
                },
            )
        if outputs.images is not None:
            writer.write_images(
                step=step,
                images={
                    f"eval/{k}_step": v
                    for k, v in outputs.images.items()
                },
            )
        if outputs.histograms is not None:
            writer.write_histograms(
                step=step,
                arrays={
                    f"eval/{k}_step": v
                    for k, v in outputs.histograms.items()
                },
            )
        writer.flush()
        # logging at the end of evaluation
        logging.rank_zero_info("Evaluation done.")
        scalar_output = {
            f"eval/{k.replace('_', ' ')}_epoch": sum(v) / len(v)
            for k, v in eval_metrics.items()
        }
        writer.write_scalars(
            step=step,
            scalars=scalar_output,
        )
        writer.flush()
    except Exception as e:
        logging.rank_zero_error(
            "Exception occurred during evaluation: %s", e
        )
        error_trace = traceback.format_exc()
        logging.rank_zero_error("Stack trace:\n%s", error_trace)
        _status = 1
    finally:
        writer.close()
        logging.rank_zero_info(
            "Evaluation done. Exit with code %d.",
            _status,
        )
    return _status
```
This was referenced Dec 2, 2025
Description
This pull request introduces several important refactorings and improvements to the core training and evaluation pipeline, with a focus on modularity, extensibility, and improved logging. The main changes include replacing references to the old `data` module with `datamodule`, updating the evaluation and training step interfaces, enhancing logging and error handling, and adding support for new dependency targets. Key changes are as follows:

Refactor to use `datamodule` instead of `data`
All references to the old `data` module have been updated to use `datamodule` throughout the codebase, including imports, type annotations, and Bazel dependencies. This improves clarity and modularity in the data handling pipeline. [1] [2] [3] [4] [5] [6] [7]

Training and evaluation loop improvements
The training and evaluation loops (`train.py` and `evaluate.py`) now accept explicit `training_step` and `evaluation_step` callables instead of relying on model methods, allowing for more flexible and decoupled step function definitions. [1] [2] [3] [4] [5] [6]

Model interface refactor
The `Model` class interface has been refactored: the `training_step`, `evaluation_step`, and `predict_step` methods have been replaced with more generic `compute_loss` and `forward` methods. The `StepOutputs` container now also supports histogram outputs. This makes the model API more extensible and explicit. [1] [2]

Evaluation and logging enhancements
Evaluation metrics are now logged with explicit `_step` and `_epoch` suffixes.

Bazel and dependency updates
Added an `ml_infra_mps_3_10` pip dependency target to `MODULE.bazel` for MPS (Apple Silicon) support, and updated Bazel build dependencies for clarity and correctness. [1] [2]

These changes collectively improve the maintainability, extensibility, and robustness of the core ML infrastructure.
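Based on the description above, the refactored model interface might look roughly like this. Only the method and container names (`compute_loss`, `forward`, `StepOutputs`, histogram support) come from the PR; the signatures and field types are assumptions for illustration:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class StepOutputs:
    # Per the description, the container now also supports histogram outputs.
    scalars: Optional[Dict[str, Any]] = None
    images: Optional[Dict[str, Any]] = None
    histograms: Optional[Dict[str, Any]] = None


class Model:
    """Sketch of the refactored interface (signatures are assumed)."""

    def compute_loss(self, params: Any, batch: Any) -> Tuple[Any, StepOutputs]:
        # Replaces the old training_step/evaluation_step methods.
        raise NotImplementedError

    def forward(self, params: Any, inputs: Any) -> Any:
        # Replaces the old predict_step method.
        raise NotImplementedError
```

Concrete models subclass this and implement the two methods; the training and evaluation loops then wrap `compute_loss` in the explicit step callables they receive.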