MOTFM (Medical Optimal Transport Flow Matching) accelerates medical image generation while preserving, and often improving, quality across 2D/3D and class-/mask-conditional setups.
- Python: 3.9 - 3.12
- Core pinned stack (from `pyproject.toml`): `torch==2.5.1`, `flow_matching==1.0.10`, `pytorch-lightning==2.5.6`, `numpy==1.26.4`, `monai_generative==0.2.3`
To install from `pyproject.toml`, run:

```
pip install -e .
```

Important Note:
- Your training data must be stored in a single `.pkl` file, which must follow the structure below.
Within that `.pkl` file, your data dictionary should look like:

```python
{
    "train": [  # List of training samples
        {
            "image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
            "mask": "Tensor[1, Height, Width, ...] (int32)",
            "class": "Scalar integer (int32)",
            "metadata": "Structured data (dict or other format)"
        },
        ...
    ],
    "valid": [  # List of validation samples
        {
            "image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
            "mask": "Tensor[1, Height, Width, ...] (int32)",
            "class": "Scalar integer (int32)",
            "metadata": "Structured data (dict or other format)"
        },
        ...
    ],
    "test": [  # List of test samples
        {
            "image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
            "mask": "Tensor[1, Height, Width, ...] (int32)",
            "class": "Scalar integer (int32)",
            "metadata": "Structured data (dict or other format)"
        },
        ...
    ]
}
```

Make sure your dataset adheres to this structure, saved in a single `.pkl` file, before running the training or inference pipelines.
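As a starting point, the structure above can be produced with a short script. This is a minimal sketch only: it uses NumPy arrays as stand-ins for tensors, and the shapes, values, and metadata keys are illustrative, not requirements of MOTFM.

```python
import pickle

import numpy as np


def make_sample(h: int = 64, w: int = 64) -> dict:
    """Build one sample dict matching the structure above (illustrative values)."""
    return {
        "image": np.random.rand(1, h, w).astype(np.float32),  # normalized image
        "mask": np.zeros((1, h, w), dtype=np.int32),          # segmentation mask
        "class": np.int32(0),                                 # scalar class label
        "metadata": {"source": "synthetic"},                  # free-form metadata
    }


# Top-level dict with the three required splits.
dataset = {
    "train": [make_sample() for _ in range(4)],
    "valid": [make_sample() for _ in range(2)],
    "test": [make_sample() for _ in range(2)],
}

with open("dataset.pkl", "wb") as f:
    pickle.dump(dataset, f)
```

Swap in your real images, masks, and labels (as tensors with your actual shapes) before pointing the training config at the resulting file.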
You must either create or modify a YAML configuration file to suit your dataset paths, model parameters, and hyperparameters. Some sample configuration files are provided in the configs/ folder. By default, configs/default.yaml is used if no custom path is provided.
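This README references a handful of config fields (`train_args.checkpoint_dir`, `train_args.seed`, `data_args.split_train`, `data_args.split_val`, `model_args.use_flash_attention`). A config might be organized along these lines; any key not named elsewhere in this README (e.g. a dataset path key) is a placeholder, so check `configs/default.yaml` for the real schema:

```yaml
data_args:
  # Placeholder key: the actual name of the dataset-path field may differ.
  data_path: /path/to/dataset.pkl
  split_train: train
  split_val: valid

model_args:
  use_flash_attention: true   # set to false for CPU inference

train_args:
  seed: 42
  checkpoint_dir: mask_class_conditioning_checkpoints
```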
To train the model, run:

```
python trainer.py --config_path configs/default.yaml
```

or (after installation):

```
motfm-train --config_path configs/default.yaml
```

- `--config_path`: Path to your YAML configuration file. Defaults to `configs/default.yaml` if not provided.
Note: Make sure you have prepared your dataset (as a single .pkl file) and configuration file properly before starting training.
Use `inferer.py` to generate synthetic samples from a trained checkpoint and save them as a `.pkl`.
Run with your config and checkpoint directory:
```
python inferer.py \
  --config_path configs/default.yaml \
  --model_path mask_class_conditioning_checkpoints/default \
  --num_samples 200 \
  --num_inference_steps 5 \
  --output_norm clip_0_1
```

or (after installation):

```
motfm-infer \
  --config_path configs/default.yaml \
  --model_path mask_class_conditioning_checkpoints/default \
  --num_samples 200 \
  --num_inference_steps 5 \
  --output_norm clip_0_1
```

- `--config_path` (str, default: `configs/default.yaml`): Config file used for model/data setup.
- `--model_path` (str, optional): Checkpoint `.ckpt` file or directory.
- `--num_samples` (int, optional): Number of samples to save. If omitted, saves all validation samples.
- `--num_inference_steps` (int, default: `5`): Number of solver time points used during sampling.
- `--output_path` (str, optional): Explicit output `.pkl` path.
- `--overwrite` (flag): Overwrite an existing file at `--output_path`.
- `--output_norm` (str, default: `clip_0_1`): One of `clip_0_1`, `per_sample_minmax`, `global_minmax`, `none`.
- `--allow_config_mismatch` (flag): Allow loading a checkpoint whose saved critical model fields differ from the current config.
- `--seed` (int, optional): Override RNG seed for reproducible inference. Defaults to `train_args.seed` if provided.
If `--model_path` is omitted, inferer searches:

- `train_args.checkpoint_dir/<config_basename>`

If `--model_path` is provided, inferer checks (in order):

1. `<model_path>`
2. `<model_path>/<config_basename>`
3. `<model_path>/latest`

If a directory is selected, checkpoint preference is:

- `last.ckpt` (if present)
- otherwise, the most recently modified `*.ckpt`

- If `--output_path` is omitted, output is saved in the resolved checkpoint directory as `samples_<config_basename>_<checkpoint_name>_steps<time_points>.pkl`.
- If the output file exists and `--overwrite` is not set, a timestamp suffix is appended automatically.
- Generated samples are produced from the validation split and saved under `data_args.split_train`, and also under `data_args.split_val` if that key is different.
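The search order above can be sketched as a small resolver. This is an assumption-laden illustration, not the actual code: the function name `resolve_checkpoint` and its error handling are hypothetical, and only the candidate order and the `last.ckpt` / newest-`*.ckpt` preference are taken from the README.

```python
from pathlib import Path
from typing import Optional


def resolve_checkpoint(model_path: Optional[str], checkpoint_dir: str,
                       config_basename: str) -> Path:
    """Hypothetical resolver mirroring the search order described above."""
    if model_path is None:
        candidates = [Path(checkpoint_dir) / config_basename]
    else:
        p = Path(model_path)
        candidates = [p, p / config_basename, p / "latest"]

    for cand in candidates:
        # A direct .ckpt file wins immediately.
        if cand.is_file() and cand.suffix == ".ckpt":
            return cand
        if cand.is_dir():
            # Prefer last.ckpt; otherwise fall back to the newest *.ckpt.
            last = cand / "last.ckpt"
            if last.exists():
                return last
            ckpts = sorted(cand.glob("*.ckpt"), key=lambda f: f.stat().st_mtime)
            if ckpts:
                return ckpts[-1]  # most recently modified
    raise FileNotFoundError("no checkpoint found in any candidate location")
```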
If you run inference on CPU, set `model_args.use_flash_attention: false` in your config; flash attention requires CUDA and will raise an error otherwise.
A dedicated script is available in evaluation_3d/ to compute 3D metrics between two datasets:
- MMD
- MS-SSIM
- 3D-FID (R3D-18 features + MONAI FIDMetric)
```
python evaluation_3d/evaluate_3d.py \
  --generated_path /path/to/generated.pkl \
  --reference_path /path/to/reference.pkl \
  --generated_split train \
  --reference_split valid \
  --num_samples 200
```

Use `--skip_fid` to skip 3D-FID when torchvision video weights are unavailable.
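For intuition on the first metric: MMD measures how far apart two sample distributions are under some kernel. The sketch below is a generic biased Gaussian-kernel MMD² estimator on flattened samples; it is not necessarily the kernel or estimator that `evaluate_3d.py` uses.

```python
import numpy as np


def gaussian_mmd(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Biased MMD^2 estimate with an RBF kernel.

    x, y: arrays of shape (n_samples, n_features), e.g. flattened volumes.
    Returns ~0 when the two sample sets come from the same distribution.
    """
    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Pairwise squared Euclidean distances, then the Gaussian kernel.
        d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2.0 * sigma ** 2))

    return float(kernel(x, x).mean() + kernel(y, y).mean()
                 - 2.0 * kernel(x, y).mean())
```

In practice the features are usually embeddings or downsampled volumes rather than raw voxels, and the bandwidth `sigma` is tuned (e.g. by the median heuristic).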
- 2025-04-09 | Code released.
- 2025-03-29 | The paper became available on arXiv.
- 2025-05-27 | The paper was accepted to MICCAI 2025.
If you find this code or our work useful in your research, please cite:
```bibtex
@inproceedings{yazdani2025flow,
  title={Flow matching for medical image synthesis: Bridging the gap between speed and quality},
  author={Yazdani, Milad and Medghalchi, Yasamin and Ashrafian, Pooria and Hacihaliloglu, Ilker and Shahriari, Dena},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={216--226},
  year={2025},
  organization={Springer}
}
```

Enjoy working with MOTFM! Feel free to open an issue or pull request if you have any questions or suggestions.
