Skip to content

Commit 57d9a50

Browse files
committed
add benchmark for optical flow Raft model
1 parent fc30423 commit 57d9a50

File tree

9 files changed

+82
-2
lines changed

9 files changed

+82
-2
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ Some examples are listed below. You can find more in the directory of each model
107107

108108
![crnn_demo](./models/text_recognition_crnn/example_outputs/CRNNCTC.gif)
109109

110+
### Optical Estimation with [RAFT](./models/optical_flow_estimation_raft/)
111+
112+
![raft_demo](./models/optical_flow_estimation_raft/example_outputs/result.jpg)
113+
110114
## License
111115

112116
OpenCV Zoo is licensed under the [Apache 2.0 license](./LICENSE). Please refer to licenses of different models.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Benchmark:
2+
name: "Optical Flow Estimation Benchmark"
3+
type: "OpticalFlow"
4+
data:
5+
path: "data/optical_flow_estimation"
6+
files: [["driving0.png", "driving1.png"], ["flyingThings3D0.png", "flyingThings3D1.png"], ["monkaa0.png", "monkaa1.png"]]
7+
sizes: # [[w1, h1], ...], Omit to run at original scale
8+
- [360, 480]
9+
metric:
10+
warmup: 30
11+
repeat: 10
12+
backend: "default"
13+
target: "cpu"
14+
15+
Model:
16+
name: "Raft"

benchmark/download_data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ def get_confirm_token(response): # in case of large files
217217
url='https://drive.google.com/u/0/uc?id=1RbLyetgqFUTt0IHaVmu6c_b7KeXJgKbc&export=download',
218218
sha='fbae2fb0a47fe65e316bbd0ec57ba21461967550',
219219
filename='person_detection.zip'),
220+
optical_flow_estimation=Downloader(name='optical_flow_estimation',
221+
url='https://drive.google.com/u/0/uc?id=1_fvN7cgc-j92MeI_wHKGkhWxbXeML_gR&export=download',
222+
sha='96b75eaef250efdde62184b07707827d76bd336c',
223+
filename='optical_flow_estimation.zip'),
220224
)
221225

222226
if __name__ == '__main__':

benchmark/utils/dataloaders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .classification import ClassificationImageLoader
33
from .recognition import RecognitionImageLoader
44
from .tracking import TrackingVideoLoader
5+
from .optical_flow import OpticalFlowImageLoader
56

6-
__all__ = ['BaseImageLoader', 'BaseVideoLoader', 'ClassificationImageLoader', 'RecognitionImageLoader', 'TrackingVideoLoader']
7+
__all__ = ['BaseImageLoader', 'BaseVideoLoader', 'ClassificationImageLoader', 'RecognitionImageLoader', 'TrackingVideoLoader', 'OpticalFlowImageLoader']
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
3+
import numpy as np
4+
import cv2 as cv
5+
6+
from .base_dataloader import _BaseImageLoader
7+
from ..factory import DATALOADERS
8+
9+
@DATALOADERS.register
10+
class OpticalFlowImageLoader(_BaseImageLoader):
11+
def __init__(self, **kwargs):
12+
super().__init__(**kwargs)
13+
14+
def __iter__(self):
15+
for case in self._files:
16+
image0 = cv.imread(os.path.join(self._path, case[0]))
17+
image1 = cv.imread(os.path.join(self._path, case[1]))
18+
if [0, 0] in self._sizes:
19+
yield "{}, {}".format(case[0], case[1]), image0, image1
20+
else:
21+
for size in self._sizes:
22+
image0_r = cv.resize(image0, size)
23+
image1_r = cv.resize(image1, size)
24+
yield "{}, {}".format(case[0], case[1]), image0_r, image1_r

benchmark/utils/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .detection import Detection
33
from .recognition import Recognition
44
from .tracking import Tracking
5+
from .optical_flow import OpticalFlow
56

6-
__all__ = ['Base', 'Detection', 'Recognition', 'Tracking']
7+
__all__ = ['Base', 'Detection', 'Recognition', 'Tracking', 'OpticalFlow']
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import cv2 as cv
2+
3+
from .base_metric import BaseMetric
4+
from ..factory import METRICS
5+
6+
@METRICS.register
7+
class OpticalFlow(BaseMetric):
8+
def __init__(self, **kwargs):
9+
super().__init__(**kwargs)
10+
11+
def forward(self, model, *args, **kwargs):
12+
img0, img1 = args
13+
14+
self._timer.reset()
15+
for _ in range(self._warmup):
16+
model.infer(img0, img1)
17+
for _ in range(self._repeat):
18+
self._timer.start()
19+
model.infer(img0, img1)
20+
self._timer.stop()
21+
22+
return self._timer.getRecords()

models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .facial_expression_recognition.facial_fer_model import FacialExpressionRecog
2121
from .object_tracking_vittrack.vittrack import VitTrack
2222
from .text_detection_ppocr.ppocr_det import PPOCRDet
23+
from .optical_flow_estimation_raft.raft import Raft
2324

2425
class ModuleRegistery:
2526
def __init__(self, name):
@@ -94,3 +95,4 @@ def register(self, item):
9495
MODELS.register(FacialExpressionRecog)
9596
MODELS.register(VitTrack)
9697
MODELS.register(PPOCRDet)
98+
MODELS.register(Raft)

models/optical_flow_estimation_raft/raft.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def _preprocess(self, image):
2929
img_input = img_input.astype(np.float32)
3030
return img_input
3131

32+
def setBackendAndTarget(self, backendId, targetId):
33+
self.backend_id = backendId
34+
self.target_id = targetId
35+
self.model.setPreferableBackend(self.backend_id)
36+
self.model.setPreferableTarget(self.target_id)
37+
3238
def infer(self, image1, image2):
3339

3440
# Preprocess

0 commit comments

Comments
 (0)