55import os
66import sys
77from pathlib import Path
8+ from textwrap import dedent
89from typing import TYPE_CHECKING
910
11+ from omegaconf import OmegaConf
12+
13+ # Don't import anything from torch or lightning_pose until needed.
14+ # These imports are slow and delay CLI help text outputs.
15+ # if TYPE_CHECKING allows use of imports for type annotations, without
16+ # actually invoking the import at runtime.
1017if TYPE_CHECKING :
1118 from lightning_pose .model import Model
1219
@@ -45,6 +52,11 @@ def _build_parser():
4552 "If not specified, defaults to "
4653 "./outputs/{YYYY-MM-DD}/{HH:MM:SS}/" ,
4754 )
55+ train_parser .add_argument (
56+ "--detector_model" ,
57+ type = types .existing_model_dir ,
58+ help = "If specified, uses cropped training data in the detector model's directory." ,
59+ )
4860 train_parser .add_argument (
4961 "--overrides" ,
5062 nargs = "*" ,
@@ -94,13 +106,86 @@ def _build_parser():
94106 " uses the labels to compute pixel error.\n "
95107 " saves outputs to `image_preds/<csv_file_name>`\n " ,
96108 )
109+ predict_parser .add_argument (
110+ "--overrides" ,
111+ nargs = "*" ,
112+ metavar = "KEY=VALUE" ,
113+ help = "overrides attributes of the config file. Uses hydra syntax:\n "
114+ "https://hydra.cc/docs/advanced/override_grammar/basic/" ,
115+ )
97116
98117 post_prediction_args = predict_parser .add_argument_group ("post-prediction" )
99118 post_prediction_args .add_argument (
100119 "--skip_viz" ,
101120 action = "store_true" ,
102121 help = "skip generating prediction-annotated images/videos" ,
103122 )
123+
124+ # Crop command
125+ crop_parser = subparsers .add_parser (
126+ "crop" ,
127+ description = dedent (
128+ """\
129+ Crops a video or labeled frames based on model predictions.
130+ Requires model predictions to already have been generated using `litpose predict`.
131+
132+ Cropped videos are saved to:
133+ <model_dir>/
134+ └── video_preds/
135+ ├── <video_filename>.csv (predictions)
136+ ├── <video_filename>_bbox.csv (bbox)
137+ └── remapped_<video_filename>.csv (TODO move to remap command)
138+ └── cropped_videos/
139+ └── cropped_<video_filename>.mp4 (cropped video)
140+
141+ Cropped images are saved to:
142+ <model_dir>/
143+ └── image_preds/
144+ └── <csv_file_name>/
145+ ├── predictions.csv
146+ ├── bbox.csv (bbox)
147+ └── cropped_<csv_file_name>.csv (cropped labels)
148+ └── cropped_images/
149+ └── a/b/c/<image_name>.png (cropped images)\
150+ """
151+ ),
152+ usage = "litpose crop <model_dir> <input_path:video|csv>... --crop_ratio=CROP_RATIO --anchor_keypoints=x,y,z" ,
153+ )
154+ crop_parser .add_argument (
155+ "model_dir" , type = types .existing_model_dir , help = "path to a model directory"
156+ )
157+
158+ crop_parser .add_argument (
159+ "input_path" , type = Path , nargs = "+" , help = "one or more files"
160+ )
161+ crop_parser .add_argument (
162+ "--crop_ratio" ,
163+ type = float ,
164+ default = 2.0 ,
165+ help = "Crop a bounding box this much larger than the animal. Default is 2." ,
166+ )
167+ crop_parser .add_argument (
168+ "--anchor_keypoints" ,
169+ type = str ,
170+ default = "" , # Or a reasonable default like "0,0,0" if appropriate
171+ help = "Comma-separated list of anchor keypoint names, defaults to all keypoints" ,
172+ )
173+
174+ remap_parser = subparsers .add_parser (
175+ "remap" ,
176+ description = dedent (
177+ """\
178+ Remaps predictions from cropped to original coordinate space.
179+ Requires model predictions to already have been generated using `litpose predict`.
180+
181+ Remapped predictions are saved as "remapped_{preds_file}" in the same folder as preds_file.
182+ """
183+ ),
184+ usage = "litpose remap <preds_file> <bbox_file>" ,
185+ )
186+ remap_parser .add_argument ("preds_file" , type = Path , help = "path to a prediction file" )
187+ remap_parser .add_argument ("bbox_file" , type = Path , help = "path to a bbox file" )
188+
104189 return parser
105190
106191
@@ -120,6 +205,84 @@ def main():
120205 elif args .command == "predict" :
121206 _predict (args )
122207
208+ elif args .command == "crop" :
209+ _crop (args )
210+
211+ elif args .command == "remap" :
212+ _remap_preds (args )
213+
214+
def _crop(args: argparse.Namespace):
    """Crop videos (.mp4) or labeled frames (.csv) using a model's predictions.

    Args:
        args: parsed CLI namespace with ``model_dir``, ``input_path`` (one or
            more mp4/csv files), ``crop_ratio``, and ``anchor_keypoints``
            (comma-separated string; empty string means all keypoints).

    Raises:
        ValueError: if ``crop_ratio`` is not greater than 1.
        NotImplementedError: if an input file is neither .mp4 nor .csv.
    """
    # Delay slow imports until the command actually runs (keeps --help fast).
    import lightning_pose.utils.cropzoom as cz
    from lightning_pose.model import Model

    model_dir = args.model_dir
    model = Model.from_dir(model_dir)

    # Make both cropped_images and cropped_videos dirs. Reason: After this, the
    # user will train a pose model, and current code in io utils checks that
    # both data_dir and videos_dir are present. If we just create one or the
    # other, the check will fail.
    model.cropped_data_dir().mkdir(parents=True, exist_ok=True)
    model.cropped_videos_dir().mkdir(parents=True, exist_ok=True)

    input_paths = [Path(p) for p in args.input_path]

    detector_cfg = OmegaConf.create(
        {
            "crop_ratio": args.crop_ratio,
            "anchor_keypoints": (
                args.anchor_keypoints.split(",") if args.anchor_keypoints else []
            ),
        }
    )
    # Validate user input with a real exception: a bare `assert` is stripped
    # under `python -O` and produces an unhelpful error message.
    if detector_cfg.crop_ratio <= 1:
        raise ValueError(
            f"--crop_ratio must be greater than 1, got {detector_cfg.crop_ratio}"
        )

    for input_path in input_paths:
        if input_path.suffix == ".mp4":
            input_preds_file = model.video_preds_dir() / (input_path.stem + ".csv")
            output_bbox_file = model.video_preds_dir() / (
                input_path.stem + "_bbox.csv"
            )
            output_file = model.cropped_videos_dir() / ("cropped_" + input_path.name)

            cz.generate_cropped_video(
                input_video_file=input_path,
                input_preds_file=input_preds_file,
                detector_cfg=detector_cfg,
                output_bbox_file=output_bbox_file,
                output_file=output_file,
            )
        elif input_path.suffix == ".csv":
            # Predictions for labeled frames live under image_preds/<csv name>.
            preds_dir = model.image_preds_dir() / input_path.name
            input_data_dir = Path(model.config.cfg.data.data_dir)
            cropped_data_dir = model.cropped_data_dir()

            output_bbox_file = preds_dir / "bbox.csv"
            output_csv_file_path = preds_dir / ("cropped_" + input_path.name)
            input_preds_file = preds_dir / "predictions.csv"
            cz.generate_cropped_labeled_frames(
                input_data_dir=input_data_dir,
                input_csv_file=input_path,
                input_preds_file=input_preds_file,
                detector_cfg=detector_cfg,
                output_data_dir=cropped_data_dir,
                output_bbox_file=output_bbox_file,
                output_csv_file=output_csv_file_path,
            )
        else:
            # Name the offending file so multi-input invocations are debuggable.
            raise NotImplementedError(
                f"Only mp4 and csv files are supported, got: {input_path}"
            )
274+
def _remap_preds(args: argparse.Namespace):
    """Remap predictions from cropped back to the original coordinate space.

    Writes the remapped predictions next to ``preds_file`` with a
    ``remapped_`` prefix on the filename.
    """
    # Delay this import because it's slow.
    import lightning_pose.utils.cropzoom as cz

    preds_file = args.preds_file
    remapped_file = preds_file.with_name(f"remapped_{preds_file.name}")

    cz.generate_cropped_csv_file(
        input_csv_file=preds_file,
        input_bbox_file=args.bbox_file,
        output_csv_file=remapped_file,
    )
285+
123286
124287def _train (args : argparse .Namespace ):
125288 import hydra
@@ -142,11 +305,32 @@ def _train(args: argparse.Namespace):
142305 cfg = hydra .compose (config_name = args .config_file .stem , overrides = args .overrides )
143306
144307 # Delay this import because it's slow.
308+ from lightning_pose .model import Model
145309 from lightning_pose .train import train
146310
147311 # TODO: Move some aspects of directory mgmt to the train function.
148312 output_dir .mkdir (parents = True , exist_ok = True )
149313 # Maintain legacy hydra chdir until downstream no longer depends on it.
314+
315+ if args .detector_model :
316+ # create detector model object before chdir so that relative path is resolved correctly
317+ detector_model = Model .from_dir (args .detector_model )
318+ import copy
319+
320+ cfg = copy .deepcopy (cfg )
321+ cfg .data .data_dir = str (detector_model .cropped_data_dir ())
322+ cfg .data .video_dir = str (detector_model .cropped_videos_dir ())
323+ if isinstance (cfg .data .csv_file , str ):
324+ cfg .data .csv_file = str (
325+ detector_model .cropped_csv_file_path (cfg .data .csv_file )
326+ )
327+ else :
328+ cfg .data .csv_file = [
329+ str (detector_model .cropped_csv_file_path (f ))
330+ for f in cfg .data .csv_file
331+ ]
332+ cfg .eval .test_videos_directory = cfg .data .video_dir
333+
150334 os .chdir (output_dir )
151335 train (cfg )
152336
@@ -155,7 +339,7 @@ def _predict(args: argparse.Namespace):
155339 # Delay this import because it's slow.
156340 from lightning_pose .model import Model
157341
158- model = Model .from_dir (args .model_dir )
342+ model = Model .from_dir2 (args .model_dir , hydra_overrides = args . overrides )
159343 input_paths = [Path (p ) for p in args .input_path ]
160344
161345 for p in input_paths :
0 commit comments