Skip to content
Draft
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,6 @@ cython_debug/
error.*.log
notebooks/semantic_search/output.json
notebooks/semantic_search/config.py

training/movinet_with_hmdb51/input
training/movinet_with_hmdb51/splits
105 changes: 105 additions & 0 deletions training/movinet_with_hmdb51/AHMDB51.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Any, Callable, Dict, Optional
from aperturedb.Videos import Videos
from aperturedb.CommonLibrary import create_connector, execute_query
import torchvision
from AVideoClips import AVideoClips


def get_videos(train: bool, split: int) -> Videos:
    """
    Fetch the HMDB51 videos of one train/test set for a given split.

    HMDB51 stores videos as clips corresponding to 51 categories.
    The videos are classified into a train and a test set (70% : 30%).
    The data set is further stored in 3 ways (splits); this function
    retrieves from ApertureDB the videos of the requested set (train/test)
    for the requested split.

    Args:
        train: True selects videos whose connection to the split has
            ``type == 1``; False selects ``type == 2``.
        split: Value matched against the ``id`` property of the ``Split``
            entity.

    Returns:
        A ``Videos`` collection built from the ``FindVideo`` entities, with
        its ``blobs`` flag set to True.

    Raises:
        RuntimeError: If ``execute_query`` reports a non-zero status.
    """
    client = create_connector()

    # Connection "type" property: 1 is used for train, 2 otherwise.
    connection_type = 1 if train else 2

    query = [{
        "FindEntity": {
            "_ref": 1,
            "with_class": "Split",
            "constraints": {
                "id": ["==", split]
            },
            "results": {
                "all_properties": True
            }
        }
    }, {
        "FindVideo": {
            "is_connected_to": {
                "ref": 1,
                "constraints": {
                    "type": ["==", connection_type]
                }
            },
            "results": {
                "all_properties": True,
                "count": True
            }
        }
    }]
    status, response, _ = execute_query(client, query, [])

    # Fail here with the server response instead of with an opaque
    # TypeError/KeyError when the response is indexed below.
    if status != 0:
        raise RuntimeError(
            f"Query for split {split} (train={train}) failed "
            f"with status {status}: {response}")

    videos = Videos(client=client, response=response[1]["FindVideo"]["entities"])
    videos.blobs = True
    print(f"Retrieved {len(videos)} videos")
    return videos


class AHMDB51(torchvision.datasets.HMDB51):
    """
    Implementation of HMDB51 aware of aperturedb.

    Notice how pytorch's implementation has so much code for local file
    processing. Here the videos are retrieved from ApertureDB through
    ``get_videos`` and wrapped in ``AVideoClips``. The parent constructor
    is not called; the attributes it would normally create are assigned
    directly in ``__init__``.

    Attributes set by the constructor:
        transform: Optional callable stored for use on returned clips.
        classes: Sorted list of the distinct ``category`` values.
        ci: Mapping from category name to integer class index.
        samples: ``(video position, class index)`` pairs, one per video.
        video_clips: The ``AVideoClips`` built over the videos.
        indices: ``[0, 1, ..., len(videos) - 1]``.
    """

    def __init__(self,
                 frames_per_clip: int = 5,
                 step_between_clips: int = 1,
                 frame_rate: Optional[int] = None,
                 fold: int = 1, train: bool = True,
                 transform: Optional[Callable] = None,
                 _precomputed_metadata: Optional[Dict[str, Any]] = None,
                 num_workers: int = 1,
                 _video_width: int = 0,
                 _video_height: int = 0,
                 _video_min_dimension: int = 0,
                 _audio_samples: int = 0,
                 output_format: str = "THWC") -> None:
        """
        Build the dataset for one fold of HMDB51.

        Args:
            frames_per_clip: Number of frames in each clip.
            step_between_clips: Number of frames between consecutive clips.
            frame_rate: Optional target frame rate for the clips.
            fold: Split passed to ``get_videos``.
            train: True for the training set, False for the test set.
            transform: Optional callable stored as ``self.transform``.
            _precomputed_metadata: Forwarded to ``AVideoClips``.
            num_workers: Forwarded to ``AVideoClips``.
            _video_width: Forwarded to ``AVideoClips``.
            _video_height: Forwarded to ``AVideoClips``.
            _video_min_dimension: Forwarded to ``AVideoClips``.
            _audio_samples: Forwarded to ``AVideoClips``.
            output_format: Forwarded to ``AVideoClips``.
        """
        self.video_pts = []
        self.video_fps = []
        self.transform = transform

        videos = get_videos(train=train, split=fold)

        # Only the "category" property is read in this pass.
        # NOTE(review): presumably blobs=False avoids fetching the encoded
        # videos while iterating - confirm against aperturedb.Videos.
        videos.blobs = False
        categories = [v["category"] for v in videos]
        videos.blobs = True

        # Class indices come from the sorted category names so they do not
        # depend on the order in which the database returns the videos.
        # Order-of-appearance numbering could give the train and test
        # datasets different label spaces.
        self.classes = sorted(set(categories))
        self.ci = {name: index for index, name in enumerate(self.classes)}
        self.samples = [
            (position, self.ci[name])
            for position, name in enumerate(categories)
        ]

        video_clips = AVideoClips(
            videos,
            frames_per_clip,
            step_between_clips,
            frame_rate,
            _precomputed_metadata,
            num_workers=num_workers,
            _video_width=_video_width,
            _video_height=_video_height,
            _video_min_dimension=_video_min_dimension,
            _audio_samples=_audio_samples,
            output_format=output_format,
        )

        self.video_clips = video_clips
        self.indices = list(range(len(videos)))

        # Every video must carry its encoded bytes under "preview".
        with_preview = sum(1 for entry in videos if 'preview' in entry)
        assert len(videos) == with_preview
        videos.loaded = True
149 changes: 149 additions & 0 deletions training/movinet_with_hmdb51/AVideoClips.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Any, Dict, List, Optional, Tuple
from torchvision.datasets.video_utils import VideoClips
from torchvision.datasets.video_utils import read_video_timestamps
from torchvision.io.video import read_video
import tempfile
import os
import shutil
from torch.utils.data.dataloader import DataLoader
import torch

from aperturedb.Videos import Videos
from tqdm import tqdm

class _VideoTimestampsDataset:
    """
    Dataset used to parallelize the reading of the timestamps of videos
    held in a ``Videos`` collection.

    Why this exists: the timestamp reader works on a file path, whereas
    the items here carry the encoded video bytes under their ``"preview"``
    key. Each item is therefore written to a temporary ``.mp4`` file inside
    a scratch directory, and that file is handed to
    ``read_video_timestamps``.

    Used in VideoClips and defined at top level so it can be
    pickled when forking.
    """

    def __init__(self, videos: "Videos") -> None:
        """
        Store the videos and make sure the scratch directory exists.

        Args:
            videos: Indexable collection whose items expose the encoded
                video bytes under the ``"preview"`` key.

        Raises:
            FileExistsError: If ``scratch`` exists but is not a directory.
        """
        self._videos = videos
        self._tmp_path = "scratch"
        # exist_ok avoids the check-then-create race when several
        # instances are constructed at about the same time.
        os.makedirs(self._tmp_path, exist_ok=True)

    def __len__(self) -> int:
        """Return the number of videos."""
        return len(self._videos)

    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
        """
        Return the result of ``read_video_timestamps`` for video ``idx``.

        Args:
            idx: Position of the video in the collection.

        Returns:
            Whatever ``read_video_timestamps`` returns for the temporary
            file (annotated as a list of timestamps and an optional fps).
        """
        video = self._videos[idx]

        with tempfile.NamedTemporaryFile(dir=self._tmp_path, suffix=".mp4") as ostream:
            ostream.write(video["preview"])
            # The reader reopens the file by name, so buffered bytes must
            # reach the file before it is read.
            ostream.flush()
            return read_video_timestamps(ostream.name)

class AVideoClips(VideoClips):
    """
    Pytorch VideoClips with aperturedb.

    Instead of a list of file paths, this class works on a ``Videos``
    collection whose items carry the encoded video bytes under the
    ``"preview"`` key. Whenever a decoder needs a file, the bytes are
    written to a temporary ``.mp4`` file in the ``scratch`` directory.

    The parent constructor is not called; the attributes are assigned
    directly in ``__init__``. Only the ``pyav`` video backend is supported
    by ``get_clip``.
    """

    def __init__(self, videos: Videos, clip_length_in_frames: int = 16, frames_between_clips: int = 1,
                 frame_rate: Optional[int] = None, _precomputed_metadata: Optional[Dict[str, Any]] = None, num_workers: int = 0,
                 _video_width: int = 0, _video_height: int = 0, _video_min_dimension: int = 0, _video_max_dimension: int = 0,
                 _audio_samples: int = 0, _audio_channels: int = 0, output_format: str = "THWC") -> None:
        """
        Read the timestamps of every video and compute the clips.

        Args:
            videos: Collection of videos; every item must contain a
                ``"preview"`` entry.
            clip_length_in_frames: Number of frames per clip.
            frames_between_clips: Step, in frames, between clips.
            frame_rate: Optional frame rate passed to ``compute_clips``.
            _precomputed_metadata: Accepted for signature compatibility;
                this constructor does not use it and always recomputes the
                timestamps.
            num_workers: Worker count for the timestamp ``DataLoader``.
            _video_width: Must be 0 with the pyav backend.
            _video_height: Must be 0 with the pyav backend.
            _video_min_dimension: Must be 0 with the pyav backend.
            _video_max_dimension: Must be 0 with the pyav backend.
            _audio_samples: Must be 0 with the pyav backend.
            _audio_channels: Stored only.
            output_format: ``"THWC"`` or ``"TCHW"`` (case-insensitive).
        """
        self._videos = videos
        self._num_workers = num_workers

        # these options are not valid for pyav backend
        self._video_width = _video_width
        self._video_height = _video_height
        self._video_min_dimension = _video_min_dimension
        self._video_max_dimension = _video_max_dimension
        self._audio_samples = _audio_samples
        self._audio_channels = _audio_channels
        self.output_format = output_format.upper()

        self._compute_frame_pts()
        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)

        # Every video must carry its encoded bytes under "preview".
        with_preview = sum(1 for entry in self._videos if 'preview' in entry)
        assert len(self._videos) == with_preview

    @staticmethod
    def _collate_fn(batch):
        """Return the batch unchanged.

        A named function is used instead of a lambda so the DataLoader can
        pickle it when worker processes are not started by forking.
        """
        return batch

    def _compute_frame_pts(self) -> None:
        """Fill ``video_pts`` and ``video_fps`` with one entry per video."""
        loader: DataLoader = DataLoader(
            _VideoTimestampsDataset(self._videos),
            batch_size=16,
            num_workers=self._num_workers,
            collate_fn=self._collate_fn,
        )

        self.video_fps = []
        self.video_pts = []

        with tqdm(total=len(loader)) as pbar:
            for batch in loader:
                pbar.update(1)
                # batch is a list of (timestamps, fps) pairs
                clips, fps = list(zip(*batch))
                clips = [torch.as_tensor(c, dtype=torch.long) for c in clips]
                self.video_pts.extend(clips)
                self.video_fps.extend(fps)

    def __len__(self) -> int:
        """Return the number of videos (not the number of clips)."""
        return len(self._videos)

    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
        """
        Gets a subclip from a list of videos.

        Args:
            idx (int): index of the subclip. Must be between 0 and num_clips().

        Returns:
            video (Tensor)
            audio (Tensor)
            info (Dict)
            video_idx (int): index of the video in the ``Videos`` collection

        Raises:
            IndexError: If ``idx`` is not smaller than ``num_clips()``.
            NotImplementedError: If the active video backend is not pyav.
            ValueError: If an option unsupported by pyav is non-zero.
        """
        if idx >= self.num_clips():
            raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
        video_idx, clip_idx = self.get_clip_location(idx)
        clip_pts = self.clips[video_idx][clip_idx]

        from torchvision import get_video_backend

        backend = get_video_backend()

        # Only the pyav path is implemented. Without this guard any other
        # backend falls through to an UnboundLocalError on `video`.
        if backend != "pyav":
            raise NotImplementedError(
                f"AVideoClips only supports the 'pyav' video backend, got {backend!r}")

        # check for invalid options
        if self._video_width != 0:
            raise ValueError("pyav backend doesn't support _video_width != 0")
        if self._video_height != 0:
            raise ValueError("pyav backend doesn't support _video_height != 0")
        if self._video_min_dimension != 0:
            raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
        if self._video_max_dimension != 0:
            raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
        if self._audio_samples != 0:
            raise ValueError("pyav backend doesn't support _audio_samples != 0")

        start_pts = clip_pts[0].item()
        end_pts = clip_pts[-1].item()
        with tempfile.NamedTemporaryFile(dir="scratch", suffix=".mp4") as ostream:
            ostream.write(self._videos[video_idx]["preview"])
            # read_video reopens the file by name, so buffered bytes must
            # reach the file before decoding starts.
            ostream.flush()
            video, audio, info = read_video(ostream.name, start_pts, end_pts)

        if self.frame_rate is not None:
            resampling_idx = self.resampling_idxs[video_idx][clip_idx]
            if isinstance(resampling_idx, torch.Tensor):
                # make the indices relative to the first frame of the clip
                resampling_idx = resampling_idx - resampling_idx[0]
            video = video[resampling_idx]
            info["video_fps"] = self.frame_rate
        assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"

        if self.output_format == "TCHW":
            # [T,H,W,C] --> [T,C,H,W]
            video = video.permute(0, 3, 1, 2)

        return video, audio, info, video_idx


Loading