|
23 | 23 | CONFIG_FILENAME, |
24 | 24 | ) |
25 | 25 | from wormpose.config.experiment_config import save_config, ExperimentConfig |
| 26 | +from wormpose.dataset.image_processing.options import ( |
| 27 | + add_image_processing_arguments, |
| 28 | + WORM_IS_LIGHTER, |
| 29 | +) |
26 | 30 | from wormpose.dataset.loader import get_dataset_name |
27 | 31 | from wormpose.dataset.loader import load_dataset |
28 | 32 | from wormpose.dataset.loaders.resizer import add_resizing_arguments, ResizeOptions |
@@ -53,6 +57,8 @@ def _parse_arguments(kwargs: dict): |
53 | 57 | kwargs["video_names"] = None |
54 | 58 | if kwargs.get("random_seed") is None: |
55 | 59 | kwargs["random_seed"] = None |
| 60 | + if kwargs.get(WORM_IS_LIGHTER) is None: |
| 61 | + kwargs[WORM_IS_LIGHTER] = False |
56 | 62 | kwargs["temp_dir"] = tempfile.mkdtemp(dir=kwargs["temp_dir"]) |
57 | 63 | kwargs["resize_options"] = ResizeOptions(**kwargs) |
58 | 64 |
|
@@ -89,8 +95,8 @@ def generate(dataset_loader: str, dataset_path: str, **kwargs): |
89 | 95 | dataset = load_dataset( |
90 | 96 | dataset_loader=dataset_loader, |
91 | 97 | dataset_path=dataset_path, |
92 | | - resize_options=args.resize_options, |
93 | 98 | selected_video_names=args.video_names, |
| 99 | + **vars(args), |
94 | 100 | ) |
95 | 101 |
|
96 | 102 | start = time.time() |
@@ -129,6 +135,7 @@ def generate(dataset_loader: str, dataset_path: str, **kwargs): |
129 | 135 | num_eval_samples=num_eval_samples, |
130 | 136 | resize_factor=args.resize_options.resize_factor, |
131 | 137 | video_names=dataset.video_names, |
| 138 | + worm_is_lighter=getattr(args, WORM_IS_LIGHTER), |
132 | 139 | ), |
133 | 140 | os.path.join(experiment_dir, CONFIG_FILENAME), |
134 | 141 | ) |
@@ -157,6 +164,7 @@ def main(): |
157 | 164 | parser.add_argument("--num_process", type=int, help="How many worker processes") |
158 | 165 | parser.add_argument("--random_seed", type=int, help="Optional random seed for deterministic results") |
159 | 166 | add_resizing_arguments(parser) |
| 167 | + add_image_processing_arguments(parser) |
160 | 168 | args = parser.parse_args() |
161 | 169 |
|
162 | 170 | last_progress = None |
|
0 commit comments