Commit 7b2236c

Author: Fangchang Ma
Commit message: uploaded codes
Parent: c5cb360

File tree: 9 files changed, +1486 / -2 lines

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,6 @@
+results
+data
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

README.md

Lines changed: 116 additions & 2 deletions

@@ -1,2 +1,116 @@
-# sparse-to-dense.pytorch
-PyTorch Version of ICRA 2018 "Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image"

sparse-to-dense.pytorch
============================

This repo implements the training and testing of deep regression neural networks for ["Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image"](https://arxiv.org/pdf/1709.07492.pdf) by [Fangchang Ma](http://www.mit.edu/~fcma) and [Sertac Karaman](http://karaman.mit.edu/) at MIT. A video demonstration is available on [YouTube](https://youtu.be/vNIIT_M7x7Y).

<p align="center">
<img src="http://www.mit.edu/~fcma/images/ICRA2018.png" alt="photo not available" width="50%" height="50%">
<img src="https://j.gifs.com/Z4qDow.gif" alt="photo not available" height="50%">
</p>
This repo can be used for training and testing of
- RGB (or grayscale image) based depth prediction
- sparse depth based depth prediction
- RGBd (i.e., both RGB and sparse depth) based depth prediction (see the input sketch below)
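
The sketch below illustrates the third input mode: a hypothetical helper that samples a fixed number of valid pixels from a dense depth map uniformly at random and concatenates the resulting sparse depth channel with the RGB image. It is only an illustration of the idea, not the repo's data-loading code; the function name, array shapes, and sampling details are assumptions.

```python
import numpy as np

def make_rgbd(rgb, depth, num_samples=100):
    """Hypothetical helper: build an RGBd input from an RGB image and a dense depth map.

    rgb:   (H, W, 3) float array
    depth: (H, W) dense depth map; 0 marks missing measurements
    """
    sparse = np.zeros_like(depth)
    valid = np.flatnonzero(depth > 0)                         # usable depth pixels
    keep = np.random.choice(valid, size=min(num_samples, valid.size), replace=False)
    sparse.flat[keep] = depth.flat[keep]                      # retain only the sampled depths
    return np.concatenate([rgb, sparse[..., None]], axis=-1)  # (H, W, 4) RGBd input
```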

The original Torch implementation of the paper can be found [here](https://github.com/fangchangma/sparse-to-dense). This PyTorch version is under development and is subject to major modifications in the future.

## Contents
0. [Requirements](#requirements)
0. [Training](#training)
0. [Testing](#testing)
0. [Trained Models](#trained-models)
0. [Benchmark](#benchmark)
0. [Citation](#citation)

## Requirements
- Install [PyTorch](http://pytorch.org/) on a machine with a CUDA GPU.
- Install [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) and other dependencies (the files in our pre-processed datasets are in HDF5 format).
```bash
sudo apt-get update
sudo apt-get install -y libhdf5-serial-dev hdf5-tools
pip install h5py matplotlib imageio scikit-image
```
- Download the preprocessed [NYU Depth V2](http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) dataset in HDF5 format and place it under the `data` folder. The download might take an hour or so, and the NYU dataset requires 32 GB of storage space.
```bash
mkdir data
cd data
wget http://datasets.lids.mit.edu/sparse-to-dense/data/nyudepthv2.tar.gz
tar -xvf nyudepthv2.tar.gz && rm -f nyudepthv2.tar.gz
cd ..
```
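
After downloading, each training example is stored as a separate `.h5` file. The snippet below is a quick, illustrative way to inspect one sample with `h5py`; the directory layout and the dataset key names (`rgb`, `depth`) are assumptions about the preprocessed files, so adjust them to whatever the archive actually contains.

```python
import glob
import h5py

# Assumed layout: data/nyudepthv2/<split>/<scene>/*.h5 (an assumption, not documented here).
sample_path = sorted(glob.glob("data/nyudepthv2/**/*.h5", recursive=True))[0]
with h5py.File(sample_path, "r") as f:
    print(list(f.keys()))        # list the arrays stored in this sample
    rgb = f["rgb"][()]           # color image (key name and channel layout are assumptions)
    depth = f["depth"][()]       # depth map in meters (key name is an assumption)
    print(rgb.shape, depth.shape)
```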

## Training
The training scripts come with several options, which can be listed with the `--help` flag. Currently this repo supports training only on the NYU dataset, and only decoders built from deconvolutions with different kernel sizes (no `upconv` or `upproj`, since we found them to be inefficient compared with a simple `deconv` using a larger kernel size).
```bash
python3 main.py --help
```

For instance, run the following command to train a network with ResNet50 as the encoder, deconvolutions of kernel size 3 as the decoder, and both RGB and 100 random sparse depth samples as the input to the network:
```bash
python3 main.py -a resnet50 -d deconv3 -m rgbd -s 100
```

Training results will be saved under the `results` folder.
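
For intuition about the `deconv3` option, here is a minimal, hypothetical sketch of a single decoder stage built from a kernel-size-3 transposed convolution that doubles the spatial resolution. It is not the repo's actual decoder module, just an illustration of the "deconv with a chosen kernel size" idea.

```python
import torch
import torch.nn as nn

class DeconvStage(nn.Module):
    """Hypothetical decoder stage: upsample by 2x with a kernel-size-3 deconvolution."""
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                               stride=2, padding=padding, output_padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)

# Doubles the spatial resolution: (1, 512, 8, 10) -> (1, 256, 16, 20)
out = DeconvStage(512, 256)(torch.randn(1, 512, 8, 10))
print(out.shape)
```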

## Testing
To test the performance of a trained model, simply run `main.py` with the `-e` option, along with the other model options. For instance:
```bash
python3 main.py -e
```

## Trained Models
Trained models will be released later.

## Benchmark
The following numbers are from the original Torch repo.
- Error metrics on NYU Depth v2:

| RGB | rms | rel | delta1 | delta2 | delta3 |
|-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
| [Roy & Todorovic](http://web.engr.oregonstate.edu/~sinisa/research/publications/cvpr16_NRF.pdf) (_CVPR 2016_) | 0.744 | 0.187 | - | - | - |
| [Eigen & Fergus](http://cs.nyu.edu/~deigen/dnl/) (_ICCV 2015_) | 0.641 | 0.158 | 76.9 | 95.0 | 98.8 |
| [Laina et al](https://arxiv.org/pdf/1606.00373.pdf) (_3DV 2016_) | 0.573 | **0.127** | **81.1** | 95.3 | 98.8 |
| Ours-RGB | **0.514** | 0.143 | 81.0 | **95.9** | **98.9** |

| RGBd-#samples | rms | rel | delta1 | delta2 | delta3 |
|-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
| [Liao et al](https://arxiv.org/abs/1611.02174) (_ICRA 2017_)-225 | 0.442 | 0.104 | 87.8 | 96.4 | 98.9 |
| Ours-20 | 0.351 | 0.078 | 92.8 | 98.4 | 99.6 |
| Ours-50 | 0.281 | 0.059 | 95.5 | 99.0 | 99.7 |
| Ours-200 | **0.230** | **0.044** | **97.1** | **99.4** | **99.8** |

<img src="http://www.mit.edu/~fcma/images/ICRA18/acc_vs_samples_nyu.png" alt="photo not available" width="50%" height="50%">

- Error metrics on the KITTI dataset:

| RGB | rms | rel | delta1 | delta2 | delta3 |
|-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
| [Make3D](http://papers.nips.cc/paper/5539-depth-map-prediction-from-a-single-image-using-a-multi-scale-deep-network.pdf) | 8.734 | 0.280 | 60.1 | 82.0 | 92.6 |
| [Mancini et al](https://arxiv.org/pdf/1607.06349.pdf) (_IROS 2016_) | 7.508 | - | 31.8 | 61.7 | 81.3 |
| [Eigen et al](http://papers.nips.cc/paper/5539-depth-map-prediction-from-a-single-image-using-a-multi-scale-deep-network.pdf) (_NIPS 2014_) | 7.156 | **0.190** | **69.2** | 89.9 | **96.7** |
| Ours-RGB | **6.266** | 0.208 | 59.1 | **90.0** | 96.2 |

| RGBd-#samples | rms | rel | delta1 | delta2 | delta3 |
|-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
| [Cadena et al](https://pdfs.semanticscholar.org/18d5/f0747a23706a344f1d15b032ea22795324fa.pdf) (_RSS 2016_)-650 | 7.14 | 0.179 | 70.9 | 88.8 | 95.6 |
| Ours-50 | 4.884 | 0.109 | 87.1 | 95.2 | 97.9 |
| [Liao et al](https://arxiv.org/abs/1611.02174) (_ICRA 2017_)-225 | 4.50 | 0.113 | 87.4 | 96.0 | 98.4 |
| Ours-100 | 4.303 | 0.095 | 90.0 | 96.3 | 98.3 |
| Ours-200 | 3.851 | 0.083 | 91.9 | 97.0 | 98.6 |
| Ours-500 | **3.378** | **0.073** | **93.5** | **97.6** | **98.9** |

<img src="http://www.mit.edu/~fcma/images/ICRA18/acc_vs_samples_kitti.png" alt="photo not available" width="50%" height="50%">

Note: our networks are trained on the KITTI odometry dataset, using only sparse labels from laser measurements.
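
For reference, the columns above follow the standard depth-estimation metrics: `rms` is root-mean-square error in meters, `rel` is mean absolute relative error, and `delta1`/`delta2`/`delta3` are the percentages of pixels whose ratio max(pred/gt, gt/pred) falls below 1.25, 1.25^2, and 1.25^3. A minimal sketch of computing them (not the repo's own evaluation code) could look like this:

```python
import torch

def depth_metrics(pred, target):
    """Illustrative metric computation over valid (target > 0) pixels; not the repo's evaluator."""
    valid = target > 0
    pred, target = pred[valid], target[valid]
    rms = torch.sqrt(torch.mean((pred - target) ** 2))
    rel = torch.mean(torch.abs(pred - target) / target)
    ratio = torch.max(pred / target, target / pred)
    deltas = [(ratio < 1.25 ** i).float().mean() * 100 for i in (1, 2, 3)]
    return rms.item(), rel.item(), [d.item() for d in deltas]

# Example call with random tensors (values are meaningless; this only shows the usage):
pred, gt = torch.rand(1, 480, 640) + 0.5, torch.rand(1, 480, 640) + 0.5
print(depth_metrics(pred, gt))
```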

## Citation
If you use our code or method in your work, please cite:

	@article{Ma2017SparseToDense,
		title={Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image},
		author={Ma, Fangchang and Karaman, Sertac},
		journal={arXiv preprint arXiv:1709.07492},
		year={2017}
	}

Please direct any questions to [Fangchang Ma](http://www.mit.edu/~fcma) at [email protected].

criteria.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

class MaskedMSELoss(nn.Module):
    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, pred, target):
        # Mean squared error over pixels with valid (positive) ground-truth depth only.
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = (target > 0).detach()
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = (diff ** 2).mean()
        return self.loss

class MaskedL1Loss(nn.Module):
    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, target):
        # Mean absolute error over pixels with valid (positive) ground-truth depth only.
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = (target > 0).detach()
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = diff.abs().mean()
        return self.loss
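
A quick, hypothetical usage sketch of these losses: build random prediction and target tensors, zero out part of the target to mimic missing depth, and compute the masked losses (only pixels with positive target values contribute). The tensor shapes here are made up for the example.

```python
import torch

from criteria import MaskedMSELoss, MaskedL1Loss  # assumes criteria.py is importable

pred = torch.rand(4, 1, 228, 304)      # hypothetical network output (N, C, H, W)
target = torch.rand(4, 1, 228, 304)
target[target < 0.3] = 0               # simulate missing depth; zeros are masked out

mse = MaskedMSELoss()(pred, target)
l1 = MaskedL1Loss()(pred, target)
print(mse.item(), l1.item())
```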
