Skip to content

Commit e20855a

Browse files
committed
more changes
1 parent 067c422 commit e20855a

File tree

4 files changed

+41
-60
lines changed

4 files changed

+41
-60
lines changed

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,23 @@ python src/emma_perception/commands/run_server.py
5454

5555
### Extracting features
5656

57+
For training things, we need to extract the features for each image.
5758

58-
#### For the pretrained datasets
59+
Here's the command you can use to extract features from images. Obviously, you can change the paths to the folder of images, and the output dir, and whatever else you want.
5960

6061
```bash
62+
python src/emma_perception/commands/extract_visual_features.py --images_dir <path_to_images> --output_dir <path to output dir>
6163
```
6264

65+
<details>
66+
<summary>`argparse` arguments for the command</summary>
6367

64-
#### For the Alexa Arena
6568

69+
</details>
6670

67-
```bash
68-
```
71+
#### Extracting features for the Alexa Arena
6972

73+
If you want to use the fine-tuned model to extract features with the model we trained on the Alexa Arena, just add `--is_arena` onto the above command.
7074

7175
### Developer tooling
7276

src/emma_perception/commands/download_checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def download_arena_checkpoint(
1919

2020

2121
def download_vinvl_checkpoint(
22-
*, hf_repo_id: str = HF_REPO_ID, file_name: str = ARENA_CHECKPOINT_NAME
22+
*, hf_repo_id: str = HF_REPO_ID, file_name: str = VINVL_CHECKPOINT_NAME
2323
) -> Path:
2424
"""Download the pre-trained VinVL checkpoint."""
2525
file_path = download_file(repo_id=hf_repo_id, repo_type="model", filename=file_name)

src/emma_perception/commands/extract_visual_features.py

Lines changed: 30 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,45 @@
11
import argparse
2-
from typing import Union
2+
from ast import arg
33

44
from maskrcnn_benchmark.config import cfg
55
from pytorch_lightning import Trainer
66
from scene_graph_benchmark.config import sg_cfg
77

88
from emma_perception.callbacks.callbacks import VisualExtractionCacheCallback
9-
from emma_perception.datamodules.visual_extraction_dataset import (
10-
ImageDataset,
11-
PredictDataModule,
12-
VideoFrameDataset,
9+
from emma_perception.commands.download_checkpoints import (
10+
download_arena_checkpoint,
11+
download_vinvl_checkpoint,
1312
)
13+
from emma_perception.datamodules.visual_extraction_dataset import ImageDataset, PredictDataModule
1414
from emma_perception.models.vinvl_extractor import VinVLExtractor, VinVLTransform
1515

1616

1717
def parse_args() -> argparse.Namespace:
1818
"""Defines arguments."""
19-
parser = argparse.ArgumentParser(prog="PROG")
20-
19+
parser = argparse.ArgumentParser()
2120
parser = Trainer.add_argparse_args(parser) # type: ignore[assignment]
22-
parser.add_argument("-i", "--input_path", required=True, help="Path to input dataset")
23-
parser.add_argument("-b", "--batch_size", type=int, default=2)
24-
parser.add_argument("-w", "--num_workers", type=int, default=0)
25-
parser.add_argument("-cs", "--cache_suffix", default=".pt", help="Extension of cached files")
26-
parser.add_argument("--config_file", metavar="FILE", help="path to VinVL config file")
27-
parser.add_argument("--return_predictions", action="store_true")
28-
29-
parser.add_argument(
30-
"--downsample",
31-
type=int,
32-
default=0,
33-
help="Downsampling factor for videos. If 0 then no downsampling is performed",
34-
)
35-
3621
parser.add_argument(
37-
"-c", "--cache", default="storage/data/cache", help="Path to store visual features"
22+
"-i",
23+
"--images_dir",
24+
required=True,
25+
help="Path to a folder of images to extract features from",
3826
)
39-
4027
parser.add_argument(
41-
"-d", "--dataset", required=True, choices=["images", "frames"], help="Dataset type"
28+
"--is_arena",
29+
action="store_true",
30+
help="If we are extracting features from the Arena images, use the Arena checkpoint",
4231
)
32+
parser.add_argument("-b", "--batch_size", type=int, default=2)
33+
parser.add_argument("-w", "--num_workers", type=int, default=0)
4334
parser.add_argument(
44-
"-a",
45-
"--ann_csv",
46-
help="Path to annotation csv file. Used for video datasets to select only the frames that have annotations",
35+
"-c", "--output_dir", default="storage/data/cache", help="Path to store visual features"
4736
)
48-
4937
parser.add_argument(
50-
"-at",
51-
"--ann_type",
52-
choices=["epic_kitchens"],
53-
default="epic_kitchens",
54-
help="Annotation parser for video datasets",
38+
"--num_gpus",
39+
type=int,
40+
default=None,
41+
help="Number of GPUs to use for visual feature extraction",
5542
)
56-
5743
parser.add_argument(
5844
"opts",
5945
default=None,
@@ -75,34 +61,23 @@ def main() -> None:
7561
cfg.merge_from_list(args.opts)
7662
cfg.freeze()
7763

78-
extractor = VinVLExtractor(cfg=cfg)
79-
transform = VinVLTransform(cfg=cfg)
80-
81-
dataset: Union[ImageDataset, VideoFrameDataset]
82-
if args.dataset == "images":
83-
dataset = ImageDataset(input_path=args.input_path, preprocess_transform=transform)
84-
elif args.dataset == "frames":
85-
dataset = VideoFrameDataset(
86-
input_path=args.input_path,
87-
ann_csv=args.ann_csv,
88-
ann_type=args.ann_type,
89-
preprocess_transform=transform,
90-
downsample=args.downsample,
91-
)
64+
if args.is_arena:
65+
cfg.MODEL.WEIGHT = download_arena_checkpoint().as_posix()
9266
else:
93-
raise OSError(f"Unsupported dataset type {args.dataset}")
67+
cfg.MODEL.WEIGHT = download_vinvl_checkpoint().as_posix()
9468

69+
dataset = ImageDataset(
70+
input_path=args.images_dir, preprocess_transform=VinVLTransform(cfg=cfg)
71+
)
9572
dm = PredictDataModule(
9673
dataset=dataset, batch_size=args.batch_size, num_workers=args.num_workers
9774
)
75+
extractor = VinVLExtractor(cfg=cfg)
9876
trainer = Trainer(
99-
gpus=args.gpus,
100-
callbacks=[
101-
VisualExtractionCacheCallback(cache_dir=args.cache, cache_suffix=args.cache_suffix)
102-
],
103-
profiler="advanced",
77+
gpus=args.num_gpus,
78+
callbacks=[VisualExtractionCacheCallback(cache_dir=args.output_dir, cache_suffix=".pt")],
10479
)
105-
trainer.predict(extractor, dm, return_predictions=args.return_predictions)
80+
trainer.predict(extractor, dm)
10681

10782

10883
if __name__ == "__main__":

src/emma_perception/constants/vinvl_x152c4.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ MODEL:
1717
SCORE_THRESH: 0.2 # 0.0001
1818
DETECTIONS_PER_IMG: 36 # 600
1919
MIN_DETECTIONS_PER_IMG: 10
20+
NMS_FILTER: 1
2021
ROI_BOX_HEAD:
2122
NUM_CLASSES: 1595
2223
ROI_ATTRIBUTE_HEAD:
@@ -52,6 +53,7 @@ TEST:
5253
TSV_SAVE_SUBSET: ["rect", "class", "conf", "feature"]
5354
OUTPUT_FEATURE: True
5455
GATHER_ON_CPU: True
56+
IGNORE_BOX_REGRESSION: False
5557
OUTPUT_DIR: "./output/X152C5_test"
5658
DATA_DIR: "./datasets"
5759
DISTRIBUTED_BACKEND: "gloo"

0 commit comments

Comments
 (0)