Skip to content

Commit d8babea

Browse files
committed
Lightning Pose task
1 parent a14029f commit d8babea

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

ibllib/pipes/video_tasks.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,119 @@ def _run(self, overwrite=True, run_qc=True, plot_qc=True):
610610
self.status = -1
611611

612612
return output_files
613+
614+
615+
class LightningPose(base_tasks.VideoTask):
616+
# TODO: make one task per cam?
617+
gpu = 1
618+
io_charge = 100
619+
level = 2
620+
force = True
621+
job_size = 'large'
622+
623+
env = Path.home().joinpath('Documents', 'PYTHON', 'envs', 'litpose', 'bin', 'activate')
624+
scripts = Path.home().joinpath('Documents', 'PYTHON', 'iblscripts', 'deploy', 'serverpc', 'litpose')
625+
626+
@property
627+
def signature(self):
628+
signature = {
629+
'input_files': [(f'_iblrig_{cam}Camera.raw.mp4', self.device_collection, True) for cam in self.cameras],
630+
'output_files': [(f'_ibl_{cam}Camera.lightningPose.pqt', 'alf', True) for cam in self.cameras]
631+
}
632+
633+
return signature
634+
635+
@staticmethod
636+
def _video_intact(file_mp4):
637+
"""Checks that the downloaded video can be opened and is not empty"""
638+
cap = cv2.VideoCapture(str(file_mp4))
639+
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
640+
intact = True if frame_count > 0 else False
641+
cap.release()
642+
return intact
643+
644+
def _check_env(self):
645+
"""Check that scripts are present, env can be activated and get iblvideo version"""
646+
assert len(list(self.scripts.rglob('run_litpose.*'))) == 2, \
647+
f'Scripts run_litpose.sh and run_litpose.py do not exist in {self.scripts}'
648+
assert self.env.exists(), f"environment does not exist in assumed location {self.env}"
649+
command2run = f"source {self.env}; python -c 'import iblvideo; print(iblvideo.__version__)'"
650+
process = subprocess.Popen(
651+
command2run,
652+
shell=True,
653+
stdout=subprocess.PIPE,
654+
stderr=subprocess.PIPE,
655+
executable="/bin/bash"
656+
)
657+
info, error = process.communicate()
658+
if process.returncode != 0:
659+
raise AssertionError(f"environment check failed\n{error.decode('utf-8')}")
660+
version = info.decode("utf-8").strip().split('\n')[-1]
661+
return version
662+
663+
def _run(self, overwrite=True, **kwargs):
664+
665+
# Gather video files
666+
self.session_path = Path(self.session_path)
667+
mp4_files = [
668+
self.session_path.joinpath(self.device_collection, f'_iblrig_{cam}Camera.raw.mp4') for cam in self.cameras
669+
if self.session_path.joinpath(self.device_collection, f'_iblrig_{cam}Camera.raw.mp4').exists()
670+
]
671+
672+
labels = [label_from_path(x) for x in mp4_files]
673+
_logger.info(f'Running on {labels} videos')
674+
675+
# Check the environment
676+
self.version = self._check_env()
677+
_logger.info(f'iblvideo version {self.version}')
678+
679+
# If all results exist and overwrite is False, skip computation
680+
expected_outputs_present, expected_outputs = self.assert_expected(self.output_files, silent=True)
681+
if overwrite is False and expected_outputs_present is True:
682+
actual_outputs = expected_outputs
683+
return actual_outputs
684+
685+
# Else, loop over videos
686+
actual_outputs = []
687+
for label, mp4_file in zip(labels, mp4_files):
688+
# Catch exceptions so that the other cams can still run but set status to Errored
689+
try:
690+
# Check that the GPU is (still) accessible
691+
check_nvidia_driver()
692+
# Check that the video can be loaded
693+
if not self._video_intact(mp4_file):
694+
_logger.error(f"Corrupt raw video file {mp4_file}")
695+
self.status = -1
696+
continue
697+
t0 = time.time()
698+
_logger.info(f'Running Ligthning Pose on {label}Camera.')
699+
command2run = f"{self.scripts.joinpath('run_litpose.sh')} {str(self.env)} {mp4_file} {overwrite}"
700+
_logger.info(command2run)
701+
process = subprocess.Popen(
702+
command2run,
703+
shell=True,
704+
stdout=subprocess.PIPE,
705+
stderr=subprocess.PIPE,
706+
executable="/bin/bash",
707+
)
708+
info, error = process.communicate()
709+
if process.returncode != 0:
710+
error_str = error.decode("utf-8").strip()
711+
_logger.error(f'Lightning pose failed for {label}Camera.\n\n'
712+
f'++++++++ Output of subprocess for debugging ++++++++\n\n'
713+
f'{error_str}\n'
714+
f'++++++++++++++++++++++++++++++++++++++++++++\n')
715+
self.status = -1
716+
continue
717+
else:
718+
_logger.info(f'{label} camera took {(time.time() - t0)} seconds')
719+
result = next(self.session_path.joinpath('alf').glob(f'_ibl_{label}Camera.lightningPose*.pqt'))
720+
actual_outputs.append(result)
721+
722+
except BaseException:
723+
_logger.error(traceback.format_exc())
724+
self.status = -1
725+
continue
726+
727+
return actual_outputs
728+

0 commit comments

Comments
 (0)