Skip to content

Commit 7758dd9

Browse files
committed
support more evaluation datasets
1 parent f74f08d commit 7758dd9

File tree

9 files changed

+11700
-2
lines changed

9 files changed

+11700
-2
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
name: diode
2+
disp_name: diode_val_all
3+
# dir: diode
4+
dir: diode/diode_val.tar
5+
filenames: datasets/eval/depth/data_split/diode/diode_val_all_filename_list.txt
6+
processing_res: 640

datasets/eval/depth/data_split/diode/diode_val_all_filename_list.txt

Lines changed: 771 additions & 0 deletions
Large diffs are not rendered by default.

datasets/eval/depth/data_split/diode/diode_val_indoor_filename_list.txt

Lines changed: 325 additions & 0 deletions
Large diffs are not rendered by default.

datasets/eval/depth/data_split/diode/diode_val_outdoor_filename_list.txt

Lines changed: 446 additions & 0 deletions
Large diffs are not rendered by default.

eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def gen_normal(img, pipe, prompt="", num_inference_steps=1):
207207
"kitti": "configs/data_kitti_eigen_test.yaml",
208208
"scannet": "configs/data_scannet_val.yaml",
209209
"eth3d": "configs/data_eth3d.yaml",
210+
"diode": "configs/data_diode_all.yaml",
210211
}
211212
for dataset_name, config_path in test_depth_dataset_configs.items():
212213
eval_dir = os.path.join(args.output_dir, args.task_name, dataset_name)
@@ -218,7 +219,7 @@ def gen_normal(img, pipe, prompt="", num_inference_steps=1):
218219
elif args.task_name == 'normal':
219220
test_data_dir = os.path.join(args.base_test_data_dir, args.task_name)
220221
dataset_split_path = "evaluation/dataset_normal"
221-
eval_datasets = [('nyuv2', 'test'), ('scannet', 'test'), ('ibims', 'ibims'), ('sintel', 'sintel')]
222+
eval_datasets = [ ('nyuv2', 'test'), ('scannet', 'test'), ('ibims', 'ibims'), ('sintel', 'sintel'), ('oasis', 'val')]
222223
eval_dir = os.path.join(args.output_dir, args.task_name)
223224
evaluation_normal(eval_dir, test_data_dir, dataset_split_path, eval_mode="generate_prediction",
224225
gen_prediction=gen_normal, pipeline=pipeline, eval_datasets=eval_datasets,

evaluation/dataset_depth/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from .kitti_dataset import KITTIDataset
99
from .nyu_dataset import NYUDataset
1010
from .scannet_dataset import ScanNetDataset
11-
11+
from .diode_dataset import DIODEDataset
1212

1313
dataset_name_class_dict = {
1414
"nyu_v2": NYUDataset,
1515
"kitti": KITTIDataset,
1616
"eth3d": ETH3DDataset,
1717
"scannet": ScanNetDataset,
18+
"diode": DIODEDataset,
1819
}
1920

2021

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Author: Bingxin Ke
2+
# Last modified: 2024-02-26
3+
4+
import os
5+
import tarfile
6+
from io import BytesIO
7+
8+
import numpy as np
9+
import torch
10+
11+
from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode, DatasetMode
12+
13+
14+
class DIODEDataset(BaseDepthDataset):
    """Evaluation dataset for the DIODE depth benchmark.

    Depth maps and validity masks are stored as ``.npy`` files, either inside
    a tar archive or under a plain directory. Unlike the other depth datasets,
    the per-pixel valid mask is read from the dataset itself instead of being
    derived from the depth values.
    """

    def __init__(self, **kwargs) -> None:
        # DIODE-specific depth range and file-name convention.
        super().__init__(
            min_depth=0.6,
            max_depth=350,
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_npy_file(self, rel_path):
        """Load a ``.npy`` array (from tar or directory) as shape (1, H, W)."""
        if self.is_tar:
            # Open the tar archive lazily once and keep the handle for reuse.
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            member = self.tar_obj.extractfile("./" + rel_path)
            source = BytesIO(member.read())
        else:
            source = os.path.join(self.dataset_dir, rel_path)
        # Drop singleton dims, then prepend a channel axis -> (1, H, W).
        return np.load(source).squeeze()[np.newaxis, :, :]

    def _read_depth_file(self, rel_path):
        """Depth maps are plain ``.npy`` files."""
        return self._read_npy_file(rel_path)

    def _get_data_path(self, index):
        # Each entry is (rgb_rel_path, depth_rel_path, mask_rel_path);
        # see the unpacking in _get_data_item.
        return self.filenames[index]

    def _get_data_item(self, index):
        """Assemble one sample; special: the valid mask is read from data."""
        rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data and mask are only loaded outside RGB-only mode.
        if DatasetMode.RGB_ONLY != self.mode:
            rasters.update(
                self._load_depth_data(
                    depth_rel_path=depth_rel_path, filled_rel_path=None
                )
            )

            # Valid mask comes straight from the dataset's mask file.
            mask = torch.from_numpy(
                self._read_npy_file(mask_rel_path).astype(bool)
            ).bool()
            rasters["valid_mask_raw"] = mask.clone()
            rasters["valid_mask_filled"] = mask.clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
""" Get samples from OASIS validation set (https://pvl.cs.princeton.edu/OASIS/)
2+
"""
3+
import os
4+
import cv2
5+
import numpy as np
6+
import pickle
7+
8+
from evaluation.dataset_normal import Sample
9+
10+
11+
def read_normal(path, h, w):
    """Load an OASIS ROI surface-normal pickle and expand it to image size.

    Parameters
    ----------
    path : str
        Path to a ``*_normal.pkl`` file holding a dict with inclusive ROI
        bounds (``min_y``/``max_y``/``min_x``/``max_x``) and the ROI normal
        map under ``'normal'``.
    h, w : int
        Height and width of the corresponding RGB image.

    Returns
    -------
    normal : np.ndarray, shape (h, w, 3), float32
        Normal map converted to LUB (x and y components negated), zero
        outside the annotated ROI.
    mask : np.ndarray, shape (h, w, 1), bool
        True where a valid (non-zero) normal annotation exists.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original `pickle.load(open(path, 'rb'))` leaked it).
    with open(path, 'rb') as f:
        normal_dict = pickle.load(f)

    mask = np.zeros((h, w))
    normal = np.zeros((h, w, 3))

    # Stuff ROI normal into its bounding box (bounds are inclusive).
    min_y = normal_dict['min_y']
    max_y = normal_dict['max_y']
    min_x = normal_dict['min_x']
    max_x = normal_dict['max_x']
    roi_normal = normal_dict['normal']

    # Convert to LUB: flip the x and y components.
    normal[min_y:max_y + 1, min_x:max_x + 1, :] = roi_normal
    normal = normal.astype(np.float32)
    normal[:, :, 0] *= -1
    normal[:, :, 1] *= -1

    # A pixel is valid when any normal component is non-zero.
    roi_mask = np.logical_or(
        np.logical_or(roi_normal[:, :, 0] != 0, roi_normal[:, :, 1] != 0),
        roi_normal[:, :, 2] != 0,
    ).astype(np.float32)
    mask[min_y:max_y + 1, min_x:max_x + 1] = roi_mask
    mask = mask[:, :, None]
    mask = mask > 0.5

    return normal, mask
37+
38+
39+
def get_sample(base_data_dir, sample_path, info):
    """Build a Sample for one OASIS validation image.

    Parameters
    ----------
    base_data_dir : str
        Root directory containing ``dsine_eval/oasis``.
    sample_path : str
        Relative image path, e.g. ``"val/100277_DT_img.png"``.
    info : Any
        Opaque metadata forwarded unchanged into the returned Sample.

    Returns
    -------
    Sample
        Holds the RGB image (float32 in [0, 1]), the GT normal map, its
        validity mask, and the 3x3 camera intrinsics.

    Raises
    ------
    FileNotFoundError
        If the image, normal pickle, or intrinsics file is missing.
    """
    # e.g. sample_path = "val/100277_DT_img.png"
    scene_name = sample_path.split('/')[0]
    img_name, img_ext = sample_path.split('/')[-1].split('_img')

    dataset_path = os.path.join(base_data_dir, 'dsine_eval', 'oasis')
    img_path = '%s/%s' % (dataset_path, sample_path)
    normal_path = img_path.replace('_img' + img_ext, '_normal.pkl')
    intrins_path = img_path.replace('_img' + img_ext, '_intrins.npy')

    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would turn a missing file into a confusing downstream error.
    for required_path in (img_path, normal_path, intrins_path):
        if not os.path.exists(required_path):
            raise FileNotFoundError(required_path)

    # read image (H, W, 3), scaled to float32 in [0, 1]
    img = cv2.cvtColor(cv2.imread(img_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32) / 255.0

    # read normal (H, W, 3), expanded from the annotated ROI to full size
    h = img.shape[0]
    w = img.shape[1]
    normal, normal_mask = read_normal(normal_path, h, w)

    # read intrins (3, 3)
    intrins = np.load(intrins_path)

    sample = Sample(
        img=img,
        normal=normal,
        normal_mask=normal_mask,
        intrins=intrins,

        dataset_name='oasis',
        scene_name=scene_name,
        img_name=img_name,
        info=info
    )

    return sample

0 commit comments

Comments
 (0)