
Commit 9d52a41

Add projects package (#3)
1 parent 7f287cb commit 9d52a41

15 files changed: +172 additions, -145 deletions


.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
@@ -58,3 +58,5 @@ repos:
 entry: python3 -m pytest -m "not integration_test"
 pass_filenames: false
 always_run: true
+
+exclude: "projects"
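
For context: pre-commit interprets `exclude` as a Python regular expression and applies it with `re.search` against each file's repository-relative path, so the pattern `"projects"` keeps the new `projects/` tree out of the checks it applies to (whether the key sits at the top level of the config or on a single hook is not visible in this rendering). An illustrative sketch of the matching behaviour, not part of the commit:

```python
import re

# pre-commit applies `exclude` with re.search on each candidate path.
exclude = re.compile(r"projects")
paths = [
    "mmlearn/datasets/__init__.py",
    "projects/med_benchmarking/datasets/medvqa.py",
]
print([p for p in paths if not exclude.search(p)])
# -> ['mmlearn/datasets/__init__.py']
```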

mmlearn/datasets/__init__.py

Lines changed: 0 additions & 10 deletions
@@ -5,13 +5,8 @@
 from mmlearn.datasets.imagenet import ImageNet
 from mmlearn.datasets.librispeech import LibriSpeech
 from mmlearn.datasets.llvip import LLVIPDataset
-from mmlearn.datasets.medvqa import MedVQA
-from mmlearn.datasets.mimiciv_cxr import MIMICIVCXR
 from mmlearn.datasets.nihcxr import NIHCXR
 from mmlearn.datasets.nyuv2 import NYUv2Dataset
-from mmlearn.datasets.pmcoa import PMCOA
-from mmlearn.datasets.quilt import Quilt
-from mmlearn.datasets.roco import ROCO
 from mmlearn.datasets.sunrgbd import SUNRGBDDataset


@@ -21,12 +16,7 @@
     "ImageNet",
     "LibriSpeech",
     "LLVIPDataset",
-    "MedVQA",
-    "MIMICIVCXR",
     "NIHCXR",
     "NYUv2Dataset",
-    "PMCOA",
-    "Quilt",
-    "ROCO",
     "SUNRGBDDataset",
 ]

mmlearn/datasets/processors/__init__.py

Lines changed: 1 addition & 7 deletions
@@ -5,18 +5,12 @@
     RandomMaskGenerator,
 )
 from mmlearn.datasets.processors.tokenizers import HFTokenizer
-from mmlearn.datasets.processors.transforms import (
-    MedVQAProcessor,
-    TrimText,
-    med_clip_vision_transform,
-)
+from mmlearn.datasets.processors.transforms import TrimText


 __all__ = [
     "BlockwiseImagePatchMaskGenerator",
     "HFTokenizer",
-    "MedVQAProcessor",
     "RandomMaskGenerator",
     "TrimText",
-    "med_clip_vision_transform",
 ]
mmlearn/datasets/processors/transforms.py

Lines changed: 1 addition & 72 deletions

@@ -1,10 +1,8 @@
 """Custom transforms for datasets."""

-from typing import List, Literal, Union
+from typing import List, Union

 from hydra_zen import store
-from timm.data.transforms import ResizeKeepRatio
-from torchvision import transforms


 @store(group="datasets/transforms", provider="mmlearn")
@@ -30,72 +28,3 @@ def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
             sentence[i] = s[: self.trim_size]

         return sentence
-
-
-class MedVQAProcessor:
-    """Preprocessor for textual reports of MedVQA datasets."""
-
-    def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
-        """Process the textual captions."""
-        if not isinstance(sentence, (list, str)):
-            raise TypeError(
-                f"Expected sentence to be a string or list of strings, got {type(sentence)}"
-            )
-
-        def _preprocess_sentence(sentence: str) -> str:
-            sentence = sentence.lower()
-            if "? -yes/no" in sentence:
-                sentence = sentence.replace("? -yes/no", "")
-            if "? -open" in sentence:
-                sentence = sentence.replace("? -open", "")
-            if "? - open" in sentence:
-                sentence = sentence.replace("? - open", "")
-            return (
-                sentence.replace(",", "")
-                .replace("?", "")
-                .replace("'s", " 's")
-                .replace("...", "")
-                .replace("x ray", "x-ray")
-                .replace(".", "")
-            )
-
-        if isinstance(sentence, str):
-            return _preprocess_sentence(sentence)
-
-        for i, s in enumerate(sentence):
-            sentence[i] = _preprocess_sentence(s)
-
-        return sentence
-
-
-@store(group="datasets/transforms", provider="mmlearn")  # type: ignore[misc]
-def med_clip_vision_transform(
-    image_crop_size: int = 224, job_type: Literal["train", "eval"] = "train"
-) -> transforms.Compose:
-    """Return transforms for training/evaluating CLIP with medical images.
-
-    Parameters
-    ----------
-    image_crop_size : int, default=224
-        Size of the image crop.
-    job_type : {"train", "eval"}, default="train"
-        Type of the job (training or evaluation) for which the transforms are needed.
-
-    Returns
-    -------
-    transforms.Compose
-        Composed transforms for training CLIP with medical images.
-    """
-    return transforms.Compose(
-        [
-            ResizeKeepRatio(512, interpolation="bicubic"),
-            transforms.RandomCrop(image_crop_size)
-            if job_type == "train"
-            else transforms.CenterCrop(image_crop_size),
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=[0.48145466, 0.4578275, 0.40821073],
-                std=[0.26862954, 0.26130258, 0.27577711],
-            ),
-        ]
-    )

projects/__init__.py

Whitespace-only changes.

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+## Benchmarking CLIP-style Methods on Medical Data
+Prior to running any experiments under this project, please install the required dependencies by running the following command:
+```bash
+pip install -r requirements.txt
+```
+**NOTE**: It is assumed that the requirements for the `mmlearn` package have already been installed in a virtual environment.
+If not, please refer to the README file in the `mmlearn` package for installation instructions.
+
+Also, please make sure to set the following environment variables:
+```bash
+export MIMICIVCXR_ROOT_DIR=/path/to/mimic-cxr/data
+export PMCOA_ROOT_DIR=/path/to/pmc_oa/data
+export QUILT_ROOT_DIR=/path/to/quilt/data
+export ROCO_ROOT_DIR=/path/to/roco/data
+```
+
+If you are running an experiment with the MedVQA dataset, please also set the following environment variables:
+```bash
+export PATHVQA_ROOT_DIR=/path/to/pathvqa/data
+export VQARAD_ROOT_DIR=/path/to/vqarad/data
+```
+
+To run an experiment, use the following command:
+
+**To Run Locally**:
+```bash
+mmlearn_run 'hydra.searchpath=[pkg://projects.med_benchmarking.configs]' +experiment=baseline experiment_name=test
+```
+
+**To Run on a SLURM Cluster**:
+```bash
+mmlearn_run --multirun hydra.launcher.mem_gb=32 hydra.launcher.qos=your_qos hydra.launcher.partition=your_partition hydra.launcher.gres=gpu:4 hydra.launcher.cpus_per_task=8 hydra.launcher.tasks_per_node=4 hydra.launcher.nodes=1 hydra.launcher.stderr_to_stdout=true hydra.launcher.timeout_min=60 '+hydra.launcher.additional_parameters={export: ALL}' 'hydra.searchpath=[pkg://projects.med_benchmarking.configs]' +experiment=baseline experiment_name=test
+```
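
The dataset configs added in this commit read the MedVQA paths with `os.getenv("PATHVQA_ROOT_DIR", MISSING)` and `os.getenv("VQARAD_ROOT_DIR", MISSING)`, and the other dataset modules presumably resolve their `*_ROOT_DIR` variables the same way, so an unset variable only surfaces as a missing value at config-resolution time. A small, hypothetical pre-flight check (the variable names come from the README above; the helper itself is not part of the commit):

```python
import os

# Root directories the README asks for; PATHVQA_ROOT_DIR and VQARAD_ROOT_DIR
# are only needed when running MedVQA experiments.
required = ["MIMICIVCXR_ROOT_DIR", "PMCOA_ROOT_DIR", "QUILT_ROOT_DIR", "ROCO_ROOT_DIR"]

missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise RuntimeError(f"Set these environment variables before running: {missing}")
```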
Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+import os
+from typing import Literal
+
+from hydra_zen import builds, store
+from omegaconf import MISSING
+from timm.data.transforms import ResizeKeepRatio
+from torchvision import transforms
+
+from mmlearn.conf import external_store
+from projects.med_benchmarking.datasets.medvqa import MedVQA, MedVQAProcessor
+from projects.med_benchmarking.datasets.mimiciv_cxr import MIMICIVCXR
+from projects.med_benchmarking.datasets.pmcoa import PMCOA
+from projects.med_benchmarking.datasets.quilt import Quilt
+from projects.med_benchmarking.datasets.roco import ROCO
+
+
+_MedVQAConf = builds(
+    MedVQA,
+    split="train",
+    encoder={"image_size": 224, "feat_dim": 512, "images_filename": "images_clip.pkl"},
+    autoencoder={
+        "available": True,
+        "image_size": 128,
+        "feat_dim": 64,
+        "images_filename": "images128x128.pkl",
+    },
+    num_ans_candidates=MISSING,
+)
+_PathVQAConf = builds(
+    MedVQA,
+    root_dir=os.getenv("PATHVQA_ROOT_DIR", MISSING),
+    num_ans_candidates=3974,
+    autoencoder={"available": False},
+    builds_bases=(_MedVQAConf,),
+)
+_VQARADConf = builds(
+    MedVQA,
+    root_dir=os.getenv("VQARAD_ROOT_DIR", MISSING),
+    num_ans_candidates=458,
+    autoencoder={"available": False},
+    builds_bases=(_MedVQAConf,),
+)
+external_store(_MedVQAConf, name="MedVQA", group="datasets")
+external_store(_PathVQAConf, name="PathVQA", group="datasets")
+external_store(_VQARADConf, name="VQARAD", group="datasets")
+
+external_store(MedVQAProcessor, name="MedVQAProcessor", group="datasets/transforms")
+
+
+@external_store(group="datasets/transforms")
+def med_clip_vision_transform(
+    image_crop_size: int = 224, job_type: Literal["train", "eval"] = "train"
+) -> transforms.Compose:
+    """Return transforms for training/evaluating CLIP with medical images.
+
+    Parameters
+    ----------
+    image_crop_size : int, default=224
+        Size of the image crop.
+    job_type : {"train", "eval"}, default="train"
+        Type of the job (training or evaluation) for which the transforms are needed.
+
+    Returns
+    -------
+    transforms.Compose
+        Composed transforms for training CLIP with medical images.
+    """
+    return transforms.Compose(
+        [
+            ResizeKeepRatio(
+                512 if job_type == "train" else image_crop_size, interpolation="bicubic"
+            ),
+            transforms.RandomCrop(image_crop_size)
+            if job_type == "train"
+            else transforms.CenterCrop(image_crop_size),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.48145466, 0.4578275, 0.40821073],
+                std=[0.26862954, 0.26130258, 0.27577711],
+            ),
+        ]
+    )
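
For readers unfamiliar with hydra-zen: `builds` generates a dataclass-backed config whose target is the wrapped class, `builds_bases` lets a child config inherit a parent's fields, and `external_store` registers the result under the given group so experiment configs can reference it by name. A minimal sketch of the same `builds`/`builds_bases` mechanics with a toy target class (all names below are illustrative, not from the repo):

```python
from dataclasses import dataclass

from hydra_zen import builds, instantiate, to_yaml


@dataclass
class ToyDataset:
    root_dir: str = "/data"
    split: str = "train"
    num_ans_candidates: int = 0


# The parent config fixes split; children override only what differs,
# mirroring how _PathVQAConf and _VQARADConf extend _MedVQAConf.
_ToyConf = builds(ToyDataset, split="train", populate_full_signature=True)
_ToyPathVQAConf = builds(
    ToyDataset,
    root_dir="/data/pathvqa",
    num_ans_candidates=3974,
    builds_bases=(_ToyConf,),
)

print(to_yaml(_ToyPathVQAConf))        # inherited split plus the overrides
dataset = instantiate(_ToyPathVQAConf) # ToyDataset(root_dir='/data/pathvqa', ...)
print(dataset)
```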

configs/experiment/med_contrastive_pretraining.yaml renamed to projects/med_benchmarking/configs/experiment/baseline.yaml

Lines changed: 4 additions & 4 deletions
@@ -69,7 +69,7 @@ task:
 eps: 1.0e-6
 lr_scheduler:
 scheduler:
-T_max: 537_775
+T_max: 107_555 # make sure to change this if max_epochs or accumulate_grad_batches is changed
 extras:
 interval: step
 loss:
@@ -80,15 +80,15 @@ task:
 task_specs:
 - query_modality: text
 target_modality: rgb
-top_k: [200]
+top_k: [10, 200]
 - query_modality: rgb
 target_modality: text
-top_k: [200]
+top_k: [10, 200]
 run_on_validation: false
 run_on_test: true

 trainer:
-max_epochs: 100
+max_epochs: 20
 precision: 16-mixed
 deterministic: False
 benchmark: True
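
On the `T_max` change: with the scheduler stepping once per optimizer step (`interval: step`), `T_max` is the number of steps the cosine schedule anneals over, so it scales linearly with `max_epochs`; cutting epochs from 100 to 20 cuts `T_max` by the same factor of five, assuming dataset size, batch size, and `accumulate_grad_batches` stay fixed. A quick arithmetic check:

```python
old_t_max, old_epochs, new_epochs = 537_775, 100, 20

# Scheduler steps scale linearly with the number of training epochs.
new_t_max = old_t_max * new_epochs // old_epochs
print(new_t_max)  # 107555 -> matches the updated T_max of 107_555
```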

projects/med_benchmarking/datasets/__init__.py

Whitespace-only changes.

mmlearn/datasets/medvqa.py renamed to projects/med_benchmarking/datasets/medvqa.py

Lines changed: 35 additions & 31 deletions
@@ -3,11 +3,10 @@
 import json
 import os
 import warnings
-from typing import Any, Callable, Dict, Literal, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

 import numpy as np
 import torch
-from hydra_zen import MISSING, builds, store
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision.transforms import CenterCrop, Compose, Grayscale, Resize, ToTensor
@@ -205,32 +204,37 @@ def __len__(self) -> int:
         return len(self.entries)


-_MedVQAConf = builds(
-    MedVQA,
-    split="train",
-    encoder={"image_size": 224, "feat_dim": 512, "images_filename": "images_clip.pkl"},
-    autoencoder={
-        "available": True,
-        "image_size": 128,
-        "feat_dim": 64,
-        "images_filename": "images128x128.pkl",
-    },
-    num_ans_candidates=MISSING,
-)
-_PathVQAConf = builds(
-    MedVQA,
-    root_dir=os.getenv("PATHVQA_ROOT_DIR", MISSING),
-    num_ans_candidates=3974,
-    autoencoder={"available": False},
-    builds_bases=(_MedVQAConf,),
-)
-_VQARADConf = builds(
-    MedVQA,
-    root_dir=os.getenv("VQARAD_ROOT_DIR", MISSING),
-    num_ans_candidates=458,
-    autoencoder={"available": False},
-    builds_bases=(_MedVQAConf,),
-)
-store(_MedVQAConf, name="MedVQA", group="datasets", provider="mmlearn")
-store(_PathVQAConf, name="PathVQA", group="datasets", provider="mmlearn")
-store(_VQARADConf, name="VQARAD", group="datasets", provider="mmlearn")
+class MedVQAProcessor:
+    """Preprocessor for textual reports of MedVQA datasets."""
+
+    def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
+        """Process the textual captions."""
+        if not isinstance(sentence, (list, str)):
+            raise TypeError(
+                f"Expected sentence to be a string or list of strings, got {type(sentence)}"
+            )
+
+        def _preprocess_sentence(sentence: str) -> str:
+            sentence = sentence.lower()
+            if "? -yes/no" in sentence:
+                sentence = sentence.replace("? -yes/no", "")
+            if "? -open" in sentence:
+                sentence = sentence.replace("? -open", "")
+            if "? - open" in sentence:
+                sentence = sentence.replace("? - open", "")
+            return (
+                sentence.replace(",", "")
+                .replace("?", "")
+                .replace("'s", " 's")
+                .replace("...", "")
+                .replace("x ray", "x-ray")
+                .replace(".", "")
+            )
+
+        if isinstance(sentence, str):
+            return _preprocess_sentence(sentence)
+
+        for i, s in enumerate(sentence):
+            sentence[i] = _preprocess_sentence(s)
+
+        return sentence
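
With `MedVQAProcessor` relocated into the projects package, a short usage sketch, assuming the repository root is on `PYTHONPATH`; the example questions are made up:

```python
from projects.med_benchmarking.datasets.medvqa import MedVQAProcessor

processor = MedVQAProcessor()

# Question-type suffixes are stripped and light normalization is applied.
print(processor("Is there evidence of pneumonia in this x ray? -yes/no"))
# -> 'is there evidence of pneumonia in this x-ray'

# Lists are processed element by element (in place) and returned.
print(processor(["What does the chest x ray show? -open", "Where is the lesion?"]))
# -> ['what does the chest x-ray show', 'where is the lesion']
```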
