11import argparse
2- from typing import Union
2+ from ast import arg
33
44from maskrcnn_benchmark .config import cfg
55from pytorch_lightning import Trainer
66from scene_graph_benchmark .config import sg_cfg
77
88from emma_perception .callbacks .callbacks import VisualExtractionCacheCallback
9- from emma_perception .datamodules .visual_extraction_dataset import (
10- ImageDataset ,
11- PredictDataModule ,
12- VideoFrameDataset ,
9+ from emma_perception .commands .download_checkpoints import (
10+ download_arena_checkpoint ,
11+ download_vinvl_checkpoint ,
1312)
13+ from emma_perception .datamodules .visual_extraction_dataset import ImageDataset , PredictDataModule
1414from emma_perception .models .vinvl_extractor import VinVLExtractor , VinVLTransform
1515
1616
1717def parse_args () -> argparse .Namespace :
1818 """Defines arguments."""
19- parser = argparse .ArgumentParser (prog = "PROG" )
20-
19+ parser = argparse .ArgumentParser ()
2120 parser = Trainer .add_argparse_args (parser ) # type: ignore[assignment]
22- parser .add_argument ("-i" , "--input_path" , required = True , help = "Path to input dataset" )
23- parser .add_argument ("-b" , "--batch_size" , type = int , default = 2 )
24- parser .add_argument ("-w" , "--num_workers" , type = int , default = 0 )
25- parser .add_argument ("-cs" , "--cache_suffix" , default = ".pt" , help = "Extension of cached files" )
26- parser .add_argument ("--config_file" , metavar = "FILE" , help = "path to VinVL config file" )
27- parser .add_argument ("--return_predictions" , action = "store_true" )
28-
29- parser .add_argument (
30- "--downsample" ,
31- type = int ,
32- default = 0 ,
33- help = "Downsampling factor for videos. If 0 then no downsampling is performed" ,
34- )
35-
3621 parser .add_argument (
37- "-c" , "--cache" , default = "storage/data/cache" , help = "Path to store visual features"
22+ "-i" ,
23+ "--images_dir" ,
24+ required = True ,
25+ help = "Path to a folder of images to extract features from" ,
3826 )
39-
4027 parser .add_argument (
41- "-d" , "--dataset" , required = True , choices = ["images" , "frames" ], help = "Dataset type"
28+ "--is_arena" ,
29+ action = "store_true" ,
30+ help = "If we are extracting features from the Arena images, use the Arena checkpoint" ,
4231 )
32+ parser .add_argument ("-b" , "--batch_size" , type = int , default = 2 )
33+ parser .add_argument ("-w" , "--num_workers" , type = int , default = 0 )
4334 parser .add_argument (
44- "-a" ,
45- "--ann_csv" ,
46- help = "Path to annotation csv file. Used for video datasets to select only the frames that have annotations" ,
35+ "-c" , "--output_dir" , default = "storage/data/cache" , help = "Path to store visual features"
4736 )
48-
4937 parser .add_argument (
50- "-at" ,
51- "--ann_type" ,
52- choices = ["epic_kitchens" ],
53- default = "epic_kitchens" ,
54- help = "Annotation parser for video datasets" ,
38+ "--num_gpus" ,
39+ type = int ,
40+ default = None ,
41+ help = "Number of GPUs to use for visual feature extraction" ,
5542 )
56-
5743 parser .add_argument (
5844 "opts" ,
5945 default = None ,
@@ -75,34 +61,23 @@ def main() -> None:
7561 cfg .merge_from_list (args .opts )
7662 cfg .freeze ()
7763
78- extractor = VinVLExtractor (cfg = cfg )
79- transform = VinVLTransform (cfg = cfg )
80-
81- dataset : Union [ImageDataset , VideoFrameDataset ]
82- if args .dataset == "images" :
83- dataset = ImageDataset (input_path = args .input_path , preprocess_transform = transform )
84- elif args .dataset == "frames" :
85- dataset = VideoFrameDataset (
86- input_path = args .input_path ,
87- ann_csv = args .ann_csv ,
88- ann_type = args .ann_type ,
89- preprocess_transform = transform ,
90- downsample = args .downsample ,
91- )
64+ if args .is_arena :
65+ cfg .MODEL .WEIGHT = download_arena_checkpoint ().as_posix ()
9266 else :
93- raise OSError ( f"Unsupported dataset type { args . dataset } " )
67+ cfg . MODEL . WEIGHT = download_vinvl_checkpoint (). as_posix ( )
9468
69+ dataset = ImageDataset (
70+ input_path = args .images_dir , preprocess_transform = VinVLTransform (cfg = cfg )
71+ )
9572 dm = PredictDataModule (
9673 dataset = dataset , batch_size = args .batch_size , num_workers = args .num_workers
9774 )
75+ extractor = VinVLExtractor (cfg = cfg )
9876 trainer = Trainer (
99- gpus = args .gpus ,
100- callbacks = [
101- VisualExtractionCacheCallback (cache_dir = args .cache , cache_suffix = args .cache_suffix )
102- ],
103- profiler = "advanced" ,
77+ gpus = args .num_gpus ,
78+ callbacks = [VisualExtractionCacheCallback (cache_dir = args .output_dir , cache_suffix = ".pt" )],
10479 )
105- trainer .predict (extractor , dm , return_predictions = args . return_predictions )
80+ trainer .predict (extractor , dm )
10681
10782
10883if __name__ == "__main__" :
0 commit comments