-
Notifications
You must be signed in to change notification settings - Fork 0
Train movinet using hmdb51, and show it's predictions in classifications. #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
gsaluja9
wants to merge
11
commits into
main
Choose a base branch
from
train_movinet
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
c5a12b6
Working ingestion.
gsaluja9 71a96cc
Fix ingesting. Create more edges between same 2 nodes depending on th…
gsaluja9 a3f6ac7
original train script
gsaluja9 cb0491c
Preserve the indexes to class mapping after model training.
gsaluja9 453b908
refactored
gsaluja9 3d946cf
tested - 1
gsaluja9 878e926
Inference using trained model vs off the shelf.
gsaluja9 1b6b032
Conditional add and indexes in connection.
gsaluja9 e28549f
use main thread for dataloader.
gsaluja9 dce5ef2
handle edge cases in classification.
gsaluja9 f767c4c
Added notes and instructions for aqquiring helper scripts.
gsaluja9 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| from typing import Any, Callable, Dict, Optional | ||
| from aperturedb.Videos import Videos | ||
| from aperturedb.CommonLibrary import create_connector, execute_query | ||
| import torchvision | ||
| from AVideoClips import AVideoClips | ||
|
|
||
|
|
||
| def get_videos(train:bool, split:int) -> Videos: | ||
| """ | ||
| HMDB51 stores videos in clips corresponding to 51 categories. | ||
| They videos are classified as a test and train set (70% : 30%) | ||
|
|
||
| The data set is further stored in 3 ways, | ||
| Get videos from aperturedb based on type (Train/Test) | ||
| and split. | ||
|
|
||
| Fetch the appropriate set. | ||
| """ | ||
|
|
||
| client = create_connector() | ||
|
|
||
| query = [{ | ||
| "FindEntity": { | ||
| "_ref": 1, | ||
| "with_class": "Split", | ||
| "constraints": { | ||
| "id": ["==", split] | ||
| }, | ||
| "results": { | ||
| "all_properties": True | ||
| } | ||
| } | ||
| }, { | ||
| "FindVideo":{ | ||
| "is_connected_to": { | ||
| "ref": 1, | ||
| "constraints": { | ||
| "type": ["==", 1 if train else 2] | ||
| } | ||
| }, | ||
| "results":{ | ||
| "all_properties": True, | ||
| "count": True | ||
| } | ||
| } | ||
| }] | ||
| _, r, b = execute_query(client, query, []) | ||
|
|
||
|
|
||
| videos = Videos(client=client, response=r[1]["FindVideo"]["entities"]) | ||
| videos.blobs = True | ||
| print(f"Retrieved {len(videos)} videos") | ||
| return videos | ||
|
|
||
|
|
||
| class AHMDB51(torchvision.datasets.HMDB51): | ||
| """ | ||
| Implementation of HMDB51 aware of aperturedb. | ||
| Notice how pytorch's implementation has so much code for local file processing. | ||
| """ | ||
| def __init__(self, | ||
| frames_per_clip: int = 5, | ||
| step_between_clips: int = 1, | ||
| frame_rate: Optional[int] = None, | ||
| fold: int = 1, train: bool = True, | ||
| transform: Optional[Callable] = None, | ||
| _precomputed_metadata: Optional[Dict[str, Any]] = None, | ||
| num_workers: int = 1, | ||
| _video_width: int = 0, | ||
| _video_height: int = 0, | ||
| _video_min_dimension: int = 0, | ||
| _audio_samples: int = 0, | ||
| output_format: str = "THWC") -> None: | ||
| self.video_pts = [] | ||
| self.video_fps = [] | ||
| self.transform = transform | ||
|
|
||
| videos = get_videos(train=train, split=fold) | ||
| self.ci = {} | ||
| videos.blobs = False | ||
| for v in videos: | ||
| if v["category"] not in self.ci: | ||
| self.ci[v["category"]] = len(self.ci) | ||
| self.samples = [(i, self.ci[v["category"]]) for i, v in enumerate(videos)] | ||
| videos.blobs = True | ||
|
|
||
|
|
||
| video_clips = AVideoClips( | ||
| videos, | ||
| frames_per_clip, | ||
| step_between_clips, | ||
| frame_rate, | ||
| _precomputed_metadata, | ||
| num_workers=num_workers, | ||
| _video_width=_video_width, | ||
| _video_height=_video_height, | ||
| _video_min_dimension=_video_min_dimension, | ||
| _audio_samples=_audio_samples, | ||
| output_format=output_format, | ||
| ) | ||
|
|
||
| self.video_clips = video_clips | ||
| self.indices = [i for i in range(len(videos))] | ||
| assert len(videos) == len(list(filter(lambda e: 'preview' in e, videos))) | ||
| videos.loaded = True |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| from typing import Any, Dict, List, Optional, Tuple | ||
| from torchvision.datasets.video_utils import VideoClips | ||
| from torchvision.datasets.video_utils import read_video_timestamps | ||
| from torchvision.io.video import read_video | ||
| import tempfile | ||
| import os | ||
| import shutil | ||
| from torch.utils.data.dataloader import DataLoader | ||
| import torch | ||
|
|
||
| from aperturedb.Videos import Videos | ||
| from tqdm import tqdm | ||
|
|
||
| class _VideoTimestampsDataset: | ||
| """ | ||
| Dataset used to parallelize the reading of the timestamps | ||
| of a list of videos, given their paths in the filesystem. | ||
|
|
||
| Used in VideoClips and defined at top level so it can be | ||
| pickled when forking. | ||
| """ | ||
| def __init__(self, videos: Videos) -> None: | ||
| self._videos = videos | ||
| self._tmp_path = "scratch" | ||
| if os.path.exists(self._tmp_path) and os.path.isdir(self._tmp_path): | ||
| pass | ||
| else: | ||
| shutil.rmtree(self._tmp_path, ignore_errors=True) | ||
| os.makedirs(self._tmp_path) | ||
|
|
||
|
|
||
| def __len__(self) -> int: | ||
| return len(self._videos) | ||
|
|
||
| def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]: | ||
| video = self._videos[idx] | ||
|
|
||
| with tempfile.NamedTemporaryFile(dir=self._tmp_path, suffix=".mp4") as ostream: | ||
| ostream.write(video["preview"]) | ||
| x = read_video_timestamps(ostream.name) | ||
| return x | ||
| raise Exception("Should not be here") | ||
|
|
||
| class AVideoClips(VideoClips): | ||
| """ | ||
| Pytorch VideoClips with aperturedb. | ||
| """ | ||
| def __init__(self, videos: Videos, clip_length_in_frames: int = 16, frames_between_clips: int = 1, | ||
| frame_rate: Optional[int] = None, _precomputed_metadata: Optional[Dict[str, Any]] = None, num_workers: int = 0, | ||
| _video_width: int = 0, _video_height: int = 0, _video_min_dimension: int = 0, _video_max_dimension: int = 0, | ||
| _audio_samples: int = 0, _audio_channels: int = 0, output_format: str = "THWC") -> None: | ||
| self._videos = videos | ||
| self._num_workers = num_workers | ||
|
|
||
| # these options are not valid for pyav backend | ||
| self._video_width = _video_width | ||
| self._video_height = _video_height | ||
| self._video_min_dimension = _video_min_dimension | ||
| self._video_max_dimension = _video_max_dimension | ||
| self._audio_samples = _audio_samples | ||
| self._audio_channels = _audio_channels | ||
| self.output_format = output_format.upper() | ||
|
|
||
| self._compute_frame_pts() | ||
| self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate) | ||
| assert len(self._videos) == len(list(filter(lambda e: 'preview' in e, self._videos))) | ||
|
|
||
|
|
||
|
|
||
|
|
||
| def _compute_frame_pts(self) -> None: | ||
| dl: DataLoader = DataLoader( | ||
| _VideoTimestampsDataset(self._videos), | ||
| batch_size=16, | ||
| num_workers=self._num_workers, | ||
| collate_fn=lambda x: x | ||
| ) | ||
|
|
||
| self.video_fps = [] | ||
| self.video_pts = [] | ||
|
|
||
| with tqdm(total=len(dl)) as pbar: | ||
| for batch in dl: | ||
| pbar.update(1) | ||
| clips, fps = list(zip(*batch)) | ||
| clips = [torch.as_tensor(c, dtype=torch.long) for c in clips] | ||
| self.video_pts.extend(clips) | ||
| self.video_fps.extend(fps) | ||
|
|
||
| def __len__(self) -> int: | ||
| return len(self._videos) | ||
|
|
||
| def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]: | ||
| """ | ||
| Gets a subclip from a list of videos. | ||
|
|
||
| Args: | ||
| idx (int): index of the subclip. Must be between 0 and num_clips(). | ||
|
|
||
| Returns: | ||
| video (Tensor) | ||
| audio (Tensor) | ||
| info (Dict) | ||
| video_idx (int): index of the video in `video_paths` | ||
| """ | ||
| if idx >= self.num_clips(): | ||
| raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)") | ||
| video_idx, clip_idx = self.get_clip_location(idx) | ||
| clip_pts = self.clips[video_idx][clip_idx] | ||
|
|
||
| from torchvision import get_video_backend | ||
|
|
||
| backend = get_video_backend() | ||
|
|
||
| if backend == "pyav": | ||
| # check for invalid options | ||
| if self._video_width != 0: | ||
| raise ValueError("pyav backend doesn't support _video_width != 0") | ||
| if self._video_height != 0: | ||
| raise ValueError("pyav backend doesn't support _video_height != 0") | ||
| if self._video_min_dimension != 0: | ||
| raise ValueError("pyav backend doesn't support _video_min_dimension != 0") | ||
| if self._video_max_dimension != 0: | ||
| raise ValueError("pyav backend doesn't support _video_max_dimension != 0") | ||
| if self._audio_samples != 0: | ||
| raise ValueError("pyav backend doesn't support _audio_samples != 0") | ||
|
|
||
| if backend == "pyav": | ||
| start_pts = clip_pts[0].item() | ||
| end_pts = clip_pts[-1].item() | ||
| with tempfile.NamedTemporaryFile(dir="scratch", suffix=".mp4") as ostream: | ||
| ostream.write(self._videos[video_idx]["preview"]) | ||
| video, audio, info = read_video(ostream.name, start_pts, end_pts) | ||
|
|
||
| if self.frame_rate is not None: | ||
| resampling_idx = self.resampling_idxs[video_idx][clip_idx] | ||
| if isinstance(resampling_idx, torch.Tensor): | ||
| resampling_idx = resampling_idx - resampling_idx[0] | ||
| video = video[resampling_idx] | ||
| info["video_fps"] = self.frame_rate | ||
| assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}" | ||
|
|
||
| if self.output_format == "TCHW": | ||
| # [T,H,W,C] --> [T,C,H,W] | ||
| video = video.permute(0, 3, 1, 2) | ||
|
|
||
| return video, audio, info, video_idx | ||
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment is not explaining much. but you did say you plan to add more comments. so that's good. if we need a class hierarchy diagram, maybe that helps
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh actually, this is the exact doc string of the pytorch version which we have overriden. Let me change it to why this is needed instead.