diff --git a/setup.cfg b/setup.cfg index 46c2807..dfd4834 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,7 @@ install_requires = dask-image matplotlib>=3.3 napari==0.4.18 + decord natsort numpy opencv-python-headless diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index 337f528..df9044d 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -14,6 +14,8 @@ from natsort import natsorted from napari_deeplabcut import misc +from napari_plugin_engine import napari_hook_implementation +from decord import VideoReader, cpu SUPPORTED_IMAGES = ".jpg", ".jpeg", ".png" SUPPORTED_VIDEOS = ".mp4", ".mov", ".avi" @@ -227,73 +229,104 @@ def read_hdf(filename: str) -> List[LayerData]: return layers -class Video: +class VideoReaderDecord(VideoReader): def __init__(self, video_path): - if not os.path.isfile(video_path): - raise ValueError(f'Video path "{video_path}" does not point to a file.') - - self.path = video_path - self.stream = cv2.VideoCapture(video_path) - if not self.stream.isOpened(): - raise OSError("Video could not be opened.") - - self._n_frames = int(self.stream.get(cv2.CAP_PROP_FRAME_COUNT)) - self._width = int(self.stream.get(cv2.CAP_PROP_FRAME_WIDTH)) - self._height = int(self.stream.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self._frame = cv2.UMat(self._height, self._width, cv2.CV_8UC3) - - def __len__(self): - return self._n_frames + super().__init__(video_path, ctx=cpu(0)) + + def __getitem__(self, index): + + # The following __getitem__ code comes from napari-video, which is an + # OpenCV-based video reader that relies on pyvideoreader and allows + # napari to run videos. + # https://github.com/janclemenslab/napari-video + # https://github.com/postpop/videoreader + # This has been modified, so that the Decord video player can be used + # to read videos within napari. + + frames = None + if isinstance(index, int): # single frame + # MODIFIED LINES + self.seek_accurate(index) + frames = self.next().asnumpy() + # ret, frames = self.read(index) + # frames = cv2.cvtColor(frames, cv2.COLOR_BGR2RGB) + elif isinstance(index, slice): # slice of frames + frames = np.stack([self[ii] for ii in range(*index.indices(len(self)))]) + elif isinstance(index, range): # range of frames + frames = np.stack([self[ii] for ii in index]) + elif isinstance(index, tuple): # unpack tuple of indices + if isinstance(index[0], slice): + indices = range(*index[0].indices(len(self))) + # ADDED LINE + frames = self.get_batch(indices) + elif isinstance(index[0], (np.integer, int)): + indices = int(index[0]) + # ADDED LINE + frames = self[indices] + else: + indices = None + + if indices is not None: + # REMOVED LINE + # frames = self[indices] + # ADDED LINES + if isinstance(frames, np.ndarray) == False: + frames = frames.asnumpy() + + # index into pixels and channels + for cnt, idx in enumerate(index[1:]): + if isinstance(idx, slice): + ix = range(*idx.indices(self.shape[cnt+1])) + elif isinstance(idx, int): + ix = range(idx-1, idx) + else: + continue + + if frames.ndim==4: # ugly indexing from the back (-1,-2 etc) + cnt = cnt+1 + frames = np.take(frames, ix, axis=cnt) + + if frames is not None: + if frames.shape[0] == 1: + frames = frames[0] + return frames @property - def width(self): - return self._width + def dtype(self): + return np.uint8 + + # MODIFIED + @property + def shape(self): + return (self._num_frame,) + self[0].shape + + # MODIFIED + @property + def ndim(self): + return len(self[0].shape)+1 @property - def height(self): - return self._height + def size(self): + return np.product(self.shape) + - def set_to_frame(self, ind): - ind = min(ind, len(self) - 1) - ind += 1 # Unclear why this is needed at all - self.stream.set(cv2.CAP_PROP_POS_FRAMES, ind) +def video_file_reader(path): + array = VideoReaderDecord(path) + return [(array, {'name': path}, 'image')] - def read_frame(self): - self.stream.retrieve(self._frame) - cv2.cvtColor(self._frame, cv2.COLOR_BGR2RGB, self._frame, 3) - return self._frame.get() - def close(self): - self.stream.release() +@napari_hook_implementation +def napari_get_reader(path): + # remember, path can be a list, so we check it's type first... + if isinstance(path, str) and any([path.endswith(ext) for ext in [".mp4", ".mov", ".avi"]]): + # If we recognize the format, we return the actual reader function + return video_file_reader + # otherwise we return None. + return None def read_video(filename: str, opencv: bool = True): - if opencv: - stream = Video(filename) - shape = stream.width, stream.height, 3 - - def _read_frame(ind): - stream.set_to_frame(ind) - return stream.read_frame() - - lazy_imread = delayed(_read_frame) - else: # pragma: no cover - from pims import PyAVReaderIndexed - - try: - stream = PyAVReaderIndexed(filename) - except ImportError: - raise ImportError("`pip install av` to use the PyAV video reader.") - - shape = stream.frame_shape - lazy_imread = delayed(stream.get_frame) - - movie = da.stack( - [ - da.from_delayed(lazy_imread(i), shape=shape, dtype=np.uint8) - for i in range(len(stream)) - ] - ) + movie = VideoReaderDecord(filename) elems = list(Path(filename).parts) elems[-2] = "labeled-data" elems[-1] = elems[-1].split(".")[0]