Flow Matching for Medical Image Synthesis: Bridging the Gap Between Speed and Quality

MOTFM (Medical Optimal Transport Flow Matching)


MOTFM (Medical Optimal Transport Flow Matching) accelerates medical image generation while preserving, and often improving, quality across 2D/3D and class-/mask-conditional setups.



Requirements

  • Python: 3.9 - 3.12
  • Core pinned stack (from pyproject.toml):
    • torch==2.5.1
    • flow_matching==1.0.10
    • pytorch-lightning==2.5.6
    • numpy==1.26.4
    • monai_generative==0.2.3

To install from pyproject.toml, run:

pip install -e .

Data Preparation

Important Note:

  • Your training data must be stored in a single .pkl file that follows the structure below.

Within that .pkl file, your data dictionary should look like:

{
  "train": [  # List of training samples
    {
      "image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
      "mask":  "Tensor[1, Height, Width, ...] (int32)",
      "class": "Scalar integer (int32)",
      "metadata": "Structured data (dict or other format)"
    },
    ...
  ],

  "valid": [  # List of validation samples
    {
      "image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
      "mask":  "Tensor[1, Height, Width, ...] (int32)",
      "class": "Scalar integer (int32)",
      "metadata": "Structured data (dict or other format)"
    },
    ...
  ],

  "test": [  # List of test samples
    {
      "image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
      "mask":  "Tensor[1, Height, Width, ...] (int32)",
      "class": "Scalar integer (int32)",
      "metadata": "Structured data (dict or other format)"
    },
    ...
  ]
}

Make sure your dataset adheres to this structure and is saved as a single .pkl file before running the training or inference pipelines.
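For illustration, a file with this layout can be produced as follows. This is a minimal sketch using NumPy arrays as stand-ins for tensors; the shapes, dtypes, and metadata fields are placeholders, not requirements of the pipeline beyond the schema above.

```python
import pickle

import numpy as np

def make_sample():
    """Build one placeholder sample matching the expected schema."""
    return {
        "image": np.zeros((1, 64, 64), dtype=np.float32),  # normalized image
        "mask": np.zeros((1, 64, 64), dtype=np.int32),     # segmentation mask
        "class": np.int32(0),                              # scalar class label
        "metadata": {"patient_id": "demo"},                # arbitrary structured data
    }

dataset = {
    "train": [make_sample() for _ in range(4)],
    "valid": [make_sample() for _ in range(2)],
    "test": [make_sample() for _ in range(2)],
}

with open("dataset.pkl", "wb") as f:
    pickle.dump(dataset, f)
```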


Configuration Files

You must either create or modify a YAML configuration file to suit your dataset paths, model parameters, and hyperparameters. Some sample configuration files are provided in the configs/ folder. By default, configs/default.yaml is used if no custom path is provided.


Training

To train the model, run:

python trainer.py --config_path configs/default.yaml

or (after installation):

motfm-train --config_path configs/default.yaml

  • --config_path: Path to your YAML configuration file. Defaults to configs/default.yaml if not provided.

Note: Make sure you have prepared your dataset (as a single .pkl file) and configuration file properly before starting training.


Inference

Use inferer.py to generate synthetic samples from a trained checkpoint and save them as a .pkl.

Quick start

Run with your config and checkpoint directory:

python inferer.py \
    --config_path configs/default.yaml \
    --model_path mask_class_conditioning_checkpoints/default \
    --num_samples 200 \
    --num_inference_steps 5 \
    --output_norm clip_0_1

or (after installation):

motfm-infer \
    --config_path configs/default.yaml \
    --model_path mask_class_conditioning_checkpoints/default \
    --num_samples 200 \
    --num_inference_steps 5 \
    --output_norm clip_0_1

Arguments

  • --config_path (str, default: configs/default.yaml): Config file used for model/data setup.
  • --model_path (str, optional): Path to a checkpoint .ckpt file, or a directory containing checkpoints.
  • --num_samples (int, optional): Number of samples to save. If omitted, saves all validation samples.
  • --num_inference_steps (int, default: 5): Number of solver time points used during sampling.
  • --output_path (str, optional): Explicit output .pkl path.
  • --overwrite (flag): Overwrite an existing file at --output_path.
  • --output_norm (str, default: clip_0_1): One of clip_0_1, per_sample_minmax, global_minmax, none.
  • --allow_config_mismatch (flag): Allow loading a checkpoint whose saved critical model fields differ from current config.
  • --seed (int, optional): Override RNG seed for reproducible inference. Defaults to train_args.seed if provided.
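The --output_norm modes can be read as follows. This is a plausible NumPy interpretation of each option for intuition only; it is an assumption, not the repository's exact implementation.

```python
import numpy as np

def normalize(samples: np.ndarray, mode: str) -> np.ndarray:
    """Apply one of the output normalization modes to a batch of shape (N, ...)."""
    if mode == "clip_0_1":
        return np.clip(samples, 0.0, 1.0)          # clamp values into [0, 1]
    if mode == "per_sample_minmax":
        flat = samples.reshape(len(samples), -1)
        shape = (-1,) + (1,) * (samples.ndim - 1)   # broadcast per-sample extremes
        lo = flat.min(axis=1).reshape(shape)
        hi = flat.max(axis=1).reshape(shape)
        return (samples - lo) / (hi - lo + 1e-8)   # rescale each sample to [0, 1]
    if mode == "global_minmax":
        # rescale the whole batch with a single min/max
        return (samples - samples.min()) / (samples.max() - samples.min() + 1e-8)
    if mode == "none":
        return samples
    raise ValueError(f"unknown mode: {mode}")
```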

Checkpoint resolution behavior

If --model_path is omitted, inferer searches:

  • train_args.checkpoint_dir/<config_basename>

If --model_path is provided, inferer checks (in order):

  • <model_path>
  • <model_path>/<config_basename>
  • <model_path>/latest

If a directory is selected, checkpoint preference is:

  • last.ckpt (if present)
  • otherwise, the most recently modified *.ckpt
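The directory-selection rule above can be sketched with the standard library alone. This mirrors the stated preference order (last.ckpt, then newest *.ckpt) and is an illustration, not the repository's code.

```python
from pathlib import Path
from typing import Optional

def pick_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Prefer last.ckpt; otherwise return the most recently modified *.ckpt."""
    d = Path(ckpt_dir)
    last = d / "last.ckpt"
    if last.is_file():
        return last
    # sort remaining checkpoints by modification time, newest last
    candidates = sorted(d.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None
```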

Output behavior

  • If --output_path is omitted, output is saved in the resolved checkpoint directory as:
    • samples_<config_basename>_<checkpoint_name>_steps<time_points>.pkl
  • If output file exists and --overwrite is not set, a timestamp suffix is appended automatically.
  • Generated samples are produced from the validation split and saved under the key data_args.split_train (and also under data_args.split_val if that key differs).
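The collision handling can be pictured as follows; a stdlib sketch of the behavior described above, where the exact timestamp format is an assumption for illustration.

```python
import time
from pathlib import Path

def resolve_output_path(path: str, overwrite: bool) -> Path:
    """Append a timestamp suffix if the file exists and --overwrite is not set."""
    p = Path(path)
    if p.exists() and not overwrite:
        stamp = time.strftime("%Y%m%d_%H%M%S")
        p = p.with_name(f"{p.stem}_{stamp}{p.suffix}")  # e.g. samples_..._20250101_120000.pkl
    return p
```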

CPU-only note

If you run inference on CPU, set model_args.use_flash_attention: false in your config.
Flash attention requires CUDA and will raise an error otherwise.


3D Evaluation

A dedicated script is available in evaluation_3d/ to compute 3D metrics between two datasets:

  • MMD
  • MS-SSIM
  • 3D-FID (R3D-18 features + MONAI FIDMetric)

python evaluation_3d/evaluate_3d.py \
    --generated_path /path/to/generated.pkl \
    --reference_path /path/to/reference.pkl \
    --generated_split train \
    --reference_split valid \
    --num_samples 200

Use --skip_fid to skip 3D-FID when torchvision video weights are unavailable.
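For intuition, MMD with a Gaussian (RBF) kernel between two feature sets can be sketched in NumPy as below. The kernel choice, bandwidth, and biased V-statistic form are illustrative assumptions, not necessarily what evaluate_3d.py computes.

```python
import numpy as np

def mmd_rbf(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Biased (V-statistic) MMD^2 between x of shape (n, d) and y of shape (m, d)."""
    def kernel(a, b):
        # pairwise squared distances, then Gaussian kernel values
        sq = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
        return np.exp(-sq / (2 * sigma**2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()
```

A lower value indicates the two sample sets are closer in distribution; identical sets give exactly zero.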


News

  • 2025-05-27 | The paper was accepted to MICCAI 2025.
  • 2025-04-09 | Code released.
  • 2025-03-29 | The paper became available on arXiv.

Citation

If you find this code or our work useful in your research, please cite:

@inproceedings{yazdani2025flow,
  title={Flow matching for medical image synthesis: Bridging the gap between speed and quality},
  author={Yazdani, Milad and Medghalchi, Yasamin and Ashrafian, Pooria and Hacihaliloglu, Ilker and Shahriari, Dena},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={216--226},
  year={2025},
  organization={Springer}
}

Enjoy working with MOTFM! Feel free to open an issue or pull request if you have any questions or suggestions.
