
Commit 74776f2

Merge pull request #46 from MichiganCOG/dev
Dev
2 parents 326c42c + a76fe1e

34 files changed: 2,636 additions, 159 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ runs/*
 models/HGC3D
 *.json
 pbs/*
+*.pt

README.md

Lines changed: 31 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Video Platform for Recognition and Detection in Pytorch
+# [Video Platform for Recognition and Detection in Pytorch](https://arxiv.org/abs/1910.02793)
 
 A platform for quick and easy development of deep learning networks for recognition and detection in videos. Includes popular models like C3D and SSD.
 
@@ -9,13 +9,39 @@ Check out our [wiki!](https://github.com/MichiganCOG/ViP/wiki)
 ### Recognition
 | Model Architecture | Dataset | ViP Accuracy (%) |
 |:--------------------:|:------------------:|:---------------------:|
+| I3D | HMDB51 (Split 1) | 72.75 |
 | C3D | HMDB51 (Split 1) | 50.14 ± 0.777 |
 | C3D | UCF101 (Split 1) | 80.40 ± 0.399 |
 
 ### Object Detection
 | Model Architecture | Dataset | ViP Accuracy (%) |
 |:--------------------:|:------------------:|:---------------------:|
 | SSD300 | VOC2007 | 76.58 |
+
+### Video Object Grounding
+| Model Architecture | Dataset | ViP Accuracy (%) |
+|:--------------------:|:------------------:|:---------------------:|
+| DVSA (+fw, obj) | YC2-BB (Validation) | 30.09 |
+
+**fw**: framewise weighting, **obj**: object interaction
+
+
+## Citation
+
+Please cite ViP when releasing any work that used this platform: https://arxiv.org/abs/1910.02793
+
+```
+@article{ganesh2019vip,
+  title={ViP: Video Platform for PyTorch},
+  author={Ganesh, Madan Ravi and Hofesmann, Eric and Louis, Nathan and Corso, Jason},
+  journal={arXiv preprint arXiv:1910.02793},
+  year={2019}
+}
+
+```
+
+
+
 ## Table of Contents
 
 * [Datasets](#configured-datasets)
@@ -38,12 +64,16 @@ Check out our [wiki!](https://github.com/MichiganCOG/ViP/wiki)
 |[ImageNetVID](http://bvisionweb1.cs.unc.edu/ilsvrc2015/download-videos-3j16.php) | Video Object Detection |
 |[MSCOCO 2014](http://cocodataset.org/#download) | Object Detection, Keypoints|
 |[VOC2007](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/) | Object Detection, Classification|
+|[YC2-BB](http://youcook2.eecs.umich.edu/download)| Video Object Grounding|
+|[DHF1K](https://github.com/wenguanwang/DHF1K) | Video Saliency Prediction|
 
 ## Models
 | Model | Task(s) |
 |:------------------------------------------------:|:--------------------:|
 |[C3D](https://github.com/jfzhang95/pytorch-video-recognition/blob/master/network/C3D_model.py) | Activity Recognition |
+|[I3D](https://github.com/piergiaj/pytorch-i3d) | Activity Recognition |
 |[SSD300](https://github.com/amdegroot/ssd.pytorch) | Object Detection |
+|[DVSA (+fw, obj)](https://github.com/MichiganCOG/Video-Grounding-from-Text)| Video Object Grounding|
 
 ## Requirements
 

config_default_example.yaml

Lines changed: 4 additions & 4 deletions
@@ -1,25 +1,24 @@
 # Preprocessing
 clip_length: 16 # Number of frames within a clip
 clip_offset: 0 # Frame offset between beginning of video and clip (1st clip only)
-clip_stride: 0 # Frame offset between successive frames
+clip_stride: 1 # Frame offset between successive clips, must be >= 1
 crop_shape: [112,112] # (Height, Width) of frame
 crop_type: Random # Type of cropping operation (Random, Central and None)
 final_shape: [112,112] # (Height, Width) of input to be given to CNN
 num_clips: -1 # Number clips to be generated from a video (<0: uniform sampling, 0: Divide entire video into clips, >0: Defines number of clips)
 random_offset: 0 # Boolean switch to generate a clip length sized clip from a video
 resize_shape: [128,171] # (Height, Width) to resize original data
-sample_duration: 16 # Temporal size of video to be provided as input to the model
-sample_size: 112 # Height of frame to be provided as input to the model
 subtract_mean: '' # Subtract mean (R,G,B) from all frames during preprocessing
 
 # Experiment Setup
 acc_metric: Accuracy # Accuracy metric
-batch_size: 3 # Numbers of videos in a mini-batch
+batch_size: 15 # Numbers of videos in a mini-batch
 dataset: HMDB51 # Name of dataset
 debug: 0 # If True, do not plot, save, or create data files
 epoch: 30 # Total number of epochs
 exp: exp # Experiment name
 gamma: 0.1 # Multiplier with which to change learning rate
+grad_max_norm: 0 # Norm for gradient clipping
 json_path: /z/dat/HMDB51/ # Path to the json file for the given dataset
 labels: 51 # Number of total classes in the dataset
 load_type: train # Environment selection, to include only training/training and validation/testing dataset
@@ -37,3 +36,4 @@ rerun: 1 # Number of trials to repeat an experim
 save_dir: './results' # Path to results directory
 seed: 999 # Seed for reproducibility
 weight_decay: 0.0005 # Weight decay
+resume: 0 # Flag to resume training or switch to alternate objective after loading
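As a quick illustration of how the updated clip_stride interacts with clip_length and clip_offset, here is a minimal sketch. It is not the repository's clip-extraction code; it assumes clip_stride is the start-to-start frame offset between successive clips (the in-repo semantics may differ slightly), but it shows why a stride of 0 would stall extraction and the value must now be >= 1.

```
# Illustrative sketch only -- not ViP's actual clip-extraction code.
# Assumes clip_stride is the start-to-start frame offset between clips.
def extract_clip_starts(num_frames, clip_length=16, clip_offset=0, clip_stride=1):
    starts = []
    start = clip_offset
    while start + clip_length <= num_frames:
        starts.append(start)
        start += clip_stride  # with a stride of 0 this loop would never advance
    return starts

# Example: a 40-frame video, 16-frame clips, stride of 8 between clip starts
print(extract_clip_starts(40, clip_length=16, clip_stride=8))  # [0, 8, 16, 24]
```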

datasets/DHF1K.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
import torch
try:
    from .abstract_datasets import DetectionDataset
except:
    from abstract_datasets import DetectionDataset
import cv2
import os
import numpy as np
import json
try:
    import datasets.preprocessing_transforms as pt
except:
    import preprocessing_transforms as pt

class DHF1K(DetectionDataset):
    def __init__(self, *args, **kwargs):
        super(DHF1K, self).__init__(*args, **kwargs)

        # Get model object in case preprocessing other than default is used
        self.model_object = kwargs['model_obj']
        self.load_type = kwargs['load_type']

        print(self.load_type)
        if self.load_type=='train':
            self.transforms = kwargs['model_obj'].train_transforms

        else:
            self.transforms = kwargs['model_obj'].test_transforms


    def __getitem__(self, idx):
        vid_info = self.samples[idx]

        base_path = vid_info['base_path']
        vid_size = vid_info['frame_size']

        input_data = []
        map_data = []
        bin_data = []

        for frame_ind in range(len(vid_info['frames'])):
            frame = vid_info['frames'][frame_ind]
            frame_path = frame['img_path']
            map_path = frame['map_path']
            bin_path = frame['bin_path']

            # Load frame, convert to RGB from BGR and normalize from 0 to 1
            input_data.append(cv2.imread(os.path.join(base_path, frame_path))[...,::-1]/255.)

            # Load frame, normalize from 0 to 1
            # All frame channels have repeated values
            map_data.append(cv2.imread(map_path)/255.)
            bin_data.append(cv2.imread(bin_path)/255.)


        vid_data = self.transforms(input_data)

        # Annotations must be resized in the loss/metric
        map_data = torch.Tensor(map_data)
        bin_data = torch.Tensor(bin_data)

        # Permute the PIL dimensions (Frame, Height, Width, Chan) to pytorch (Chan, frame, height, width)
        vid_data = vid_data.permute(3, 0, 1, 2)
        map_data = map_data.permute(3, 0, 1, 2)
        bin_data = bin_data.permute(3, 0, 1, 2)
        # All channels are repeated so remove the unnecessary channels
        map_data = map_data[0].unsqueeze(0)
        bin_data = bin_data[0].unsqueeze(0)

        ret_dict = dict()
        ret_dict['data'] = vid_data

        annot_dict = dict()
        annot_dict['map'] = map_data
        annot_dict['bin'] = bin_data
        annot_dict['input_shape'] = vid_data.size()
        annot_dict['name'] = base_path
        ret_dict['annots'] = annot_dict

        return ret_dict


if __name__=='__main__':

    class tts():
        def __call__(self, x):
            return pt.ToTensorClip()(x)
    class debug_model():
        def __init__(self):
            self.train_transforms = tts()


    json_path = '/path/to/DHF1K' #### Change this when testing ####

    dataset = DHF1K(model_obj=debug_model(), json_path=json_path, load_type='train', clip_length=16, clip_offset=0, clip_stride=1, num_clips=0, random_offset=0, resize_shape=0, crop_shape=0, crop_type='Center', final_shape=0, batch_size=1)
    train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False)

    import matplotlib.pyplot as plt
    for x in enumerate(train_loader):
        dat = x[1]['data'][0,:,0].permute(1,2,0).numpy()
        bin = x[1]['annots']['bin'][0,:,0].permute(1,2,0).numpy().repeat(3,axis=2)
        map = x[1]['annots']['map'][0,:,0].permute(1,2,0).numpy().repeat(3, axis=2)
        img = np.concatenate([dat,bin,map], axis=0)
        plt.imshow(img)
        plt.show()
        import pdb; pdb.set_trace()
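For readers unfamiliar with the layout conversion performed in __getitem__ above, here is a minimal, self-contained sketch (assumed shapes, not repository code) of the (Frame, Height, Width, Channel) to (Channel, Frame, Height, Width) permute and the single-channel squeeze applied to the saliency maps:

```
import torch

# Illustrative only: assumed clip shape (16 frames of 112x112 RGB).
T, H, W, C = 16, 112, 112, 3
clip = torch.rand(T, H, W, C)      # frames stacked as (Frame, Height, Width, Channel)
clip = clip.permute(3, 0, 1, 2)    # -> (Channel, Frame, Height, Width), the layout PyTorch 3D CNNs expect
print(clip.shape)                  # torch.Size([3, 16, 112, 112])

# Saliency/fixation maps are grayscale stored as 3 identical channels, so keep only one.
sal = torch.rand(C, T, H, W)
sal = sal[0].unsqueeze(0)          # -> (1, Frame, Height, Width)
print(sal.shape)                   # torch.Size([1, 16, 112, 112])
```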

datasets/HMDB51.py

Lines changed: 3 additions & 2 deletions
@@ -40,8 +40,9 @@ def __getitem__(self, idx):
         base_path = vid_info['base_path']
 
         input_data = []
-        vid_data = np.zeros((self.clip_length, self.final_shape[0], self.final_shape[1], 3))-1
-        labels = np.zeros((self.clip_length))-1
+        vid_length = len(vid_info['frames'])
+        vid_data = np.zeros((vid_length, self.final_shape[0], self.final_shape[1], 3))-1
+        labels = np.zeros((vid_length))-1
         input_data = []
 
         for frame_ind in range(len(vid_info['frames'])):
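The same change is applied in ImageNetVID.py, MSCOCO.py, and the new KTH.py below: per-frame buffers are now sized to the number of frames actually present in the sample rather than to clip_length, with -1 acting as a fill value for entries that never get annotated. A minimal sketch under assumed values (not repository code):

```
import numpy as np

# Illustrative only: assumed values, not repository code.
final_shape = [112, 112]
num_frames_in_video = 40      # e.g. a sample longer than one 16-frame clip

# Buffers sized to the real video length so every frame fits before clips are sampled.
vid_length = num_frames_in_video
vid_data = np.zeros((vid_length, final_shape[0], final_shape[1], 3)) - 1   # -1 marks unfilled entries
labels = np.zeros((vid_length)) - 1

print(vid_data.shape, labels[:3])   # (40, 112, 112, 3) [-1. -1. -1.]
```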

datasets/ImageNetVID.py

Lines changed: 6 additions & 4 deletions
@@ -42,10 +42,12 @@ def __getitem__(self, idx):
         vid_size = vid_info['frame_size']
 
         input_data = []
-        vid_data = np.zeros((self.clip_length, self.final_shape[0], self.final_shape[1], 3))-1
-        bbox_data = np.zeros((self.clip_length, self.max_objects, 4))-1
-        labels = np.zeros((self.clip_length, self.max_objects))-1
-        occlusions = np.zeros((self.clip_length, self.max_objects))-1
+
+        vid_length = len(vid_info['frames'])
+        vid_data = np.zeros((vid_length, self.final_shape[0], self.final_shape[1], 3))-1
+        bbox_data = np.zeros((vid_length, self.max_objects, 4))-1
+        labels = np.zeros((vid_length, self.max_objects))-1
+        occlusions = np.zeros((vid_length, self.max_objects))-1
 
 
 

datasets/KTH.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import torch
from .abstract_datasets import RecognitionDataset
from PIL import Image
import cv2
import os
import numpy as np
from torchvision import transforms

class KTH(RecognitionDataset):
    def __init__(self, *args, **kwargs):
        """
        Initialize KTH class
        Args:
            load_type    (String): Select training or testing set
            resize_shape (Int):    [Int, Int] Array indicating desired height and width to resize input
            crop_shape   (Int):    [Int, Int] Array indicating desired height and width to crop input
            final_shape  (Int):    [Int, Int] Array indicating desired height and width of input to deep network
            preprocess   (String): Keyword to select different preprocessing types

        Return:
            None
        """
        super(KTH, self).__init__(*args, **kwargs)

        self.load_type = kwargs['load_type']
        self.resize_shape = kwargs['resize_shape']
        self.crop_shape = kwargs['crop_shape']
        self.final_shape = kwargs['final_shape']
        self.preprocess = kwargs['preprocess']

        if self.load_type=='train':
            self.transforms = kwargs['model_obj'].train_transforms

        else:
            self.transforms = kwargs['model_obj'].test_transforms


    def __getitem__(self, idx):
        vid_info = self.samples[idx]
        base_path = vid_info['base_path']

        input_data = []

        vid_length = len(vid_info['frames'])
        vid_data = np.zeros((vid_length, self.final_shape[0], self.final_shape[1], 3))-1
        labels = np.zeros((vid_length))-1
        input_data = []

        for frame_ind in range(len(vid_info['frames'])):
            frame_path = os.path.join(base_path, vid_info['frames'][frame_ind]['img_path'])

            for frame_labels in vid_info['frames'][frame_ind]['actions']:
                labels[frame_ind] = frame_labels['action_class']

            # Load frame image data and preprocess image accordingly
            input_data.append(cv2.imread(frame_path)[...,::-1]/1.)


        # Preprocess data
        vid_data = self.transforms(input_data)
        labels = torch.from_numpy(labels).float()

        # Permute the PIL dimensions (Frame, Height, Width, Chan) to pytorch (Chan, frame, height, width)
        vid_data = vid_data.permute(3, 0, 1, 2)

        ret_dict = dict()
        ret_dict['data'] = vid_data

        annot_dict = dict()
        annot_dict['labels'] = labels

        ret_dict['annots'] = annot_dict

        return ret_dict


#dataset = HMDB51(json_path='/z/dat/HMDB51', dataset_type='train', clip_length=100, num_clips=0)
#dat = dataset.__getitem__(0)
#import pdb; pdb.set_trace()
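To make the per-frame labelling loop in KTH.__getitem__ concrete, here is a small, self-contained sketch using a made-up vid_info entry (the real annotation JSON may contain more fields); frames with an empty 'actions' list keep the -1 fill value:

```
import numpy as np

# Hypothetical annotation entry, shaped like the dicts KTH.__getitem__ reads.
vid_info = {'frames': [
    {'img_path': '0001.png', 'actions': [{'action_class': 2}]},
    {'img_path': '0002.png', 'actions': [{'action_class': 2}]},
    {'img_path': '0003.png', 'actions': []},   # unannotated frame stays -1
]}

labels = np.zeros(len(vid_info['frames'])) - 1
for frame_ind, frame in enumerate(vid_info['frames']):
    for frame_labels in frame['actions']:
        labels[frame_ind] = frame_labels['action_class']

print(labels)   # [ 2.  2. -1.]
```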

datasets/MSCOCO.py

Lines changed: 7 additions & 6 deletions
@@ -1,6 +1,6 @@
 import torch
 from .abstract_datasets import DetectionDataset
-from PIL import Image
+import cv2
 import os
 import numpy as np
 import datasets.preprocessing_transforms as pt
@@ -34,10 +34,11 @@ def __getitem__(self, idx):
         vid_size = vid_info['frame_size']
 
         input_data = []
-        vid_data = np.zeros((self.clip_length, self.final_shape[0], self.final_shape[1], 3))-1
-        bbox_data = np.zeros((self.clip_length, self.max_objects, 4))-1
-        labels = np.zeros((self.clip_length, self.max_objects))-1
-        iscrowds = np.zeros((self.clip_length, self.max_objects))-1
+        vid_length = len(vid_info['frames'])
+        vid_data = np.zeros((vid_length, self.final_shape[0], self.final_shape[1], 3))-1
+        bbox_data = np.zeros((vid_length, self.max_objects, 4))-1
+        labels = np.zeros((vid_length, self.max_objects))-1
+        iscrowds = np.zeros((vid_length, self.max_objects))-1
 
 
 
@@ -62,7 +63,7 @@ def __getitem__(self, idx):
             iscrowds[frame_ind, trackid] = iscrowd
 
 
-            input_data.append(Image.open(os.path.join(base_path, frame_path)))
+            input_data.append(cv2.imread(os.path.join(base_path, frame_path))[...,::-1])
 
         vid_data, bbox_data = self.transforms(input_data, bbox_data)
 
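The switch from PIL to OpenCV above is the reason for the [..., ::-1] slice: cv2.imread returns images in BGR channel order, and reversing the last axis restores RGB. A minimal sketch (illustrative only, with a random stand-in image rather than a real file):

```
import cv2
import numpy as np

# Stand-in for cv2.imread(path): a random uint8 image in BGR order.
img_bgr = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)

rgb_slice = img_bgr[..., ::-1]                          # reverse the channel axis
rgb_cvt = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)      # explicit conversion

assert np.array_equal(rgb_slice, rgb_cvt)               # same result either way
```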
