
Commit 94350e5

lucasb-eyer and Pandoro committed
Initial commit of re-implemented training code.
Co-authored-by: Alexander Hermans <[email protected]>
1 parent f47f9ef commit 94350e5

23 files changed: 38,282 additions & 8 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
__pycache__
*.pyc

README.md

Lines changed: 134 additions & 8 deletions
@@ -2,19 +2,35 @@
Code for reproducing the results of our "In Defense of the Triplet Loss for Person Re-Identification" paper.

-Both main authors are currently in an internship.
-We will publish the full training code after our internships, which is at the end of September 2017.
-(By "Watching" this project on GitHub, you will receive e-mails about updates to this repo.)
-Meanwhile, we provide the pre-trained weights for the TriNet model, as well as some rudimentary example code for using it to compute embeddings; see below.
+We provide the following things:
+- The exact pre-trained weights for the TriNet model as used in the paper, including some rudimentary example code for using it to compute embeddings.
+  See the section [Pretrained models](#pretrained-models).
+- A clean re-implementation of the training code that can be used for training your own models/data.
+  See the section [Training your own models](#training-your-own-models).
+- A script for evaluation which computes the CMC and mAP of embeddings stored in an HDF5 ("new .mat") file.
+  See the section [Evaluating embeddings](#evaluating-embeddings).

-# Pretrained Models
+If you use any of the provided code, please cite:
+```
+@article{HermansBeyer2017Arxiv,
+  title = {{In Defense of the Triplet Loss for Person Re-Identification}},
+  author = {Hermans*, Alexander and Beyer*, Lucas and Leibe, Bastian},
+  journal = {arXiv preprint arXiv:1703.07737},
+  year = {2017}
+}
+```
+
+# Pretrained models

-This is a first, simple release. A better, more generic script will follow in a few months, but this should be enough to get started trying out our models!
+We provide the exact TriNet model used in the paper, which was implemented in
+[Theano](http://deeplearning.net/software/theano/install.html)
+and
+[Lasagne](http://lasagne.readthedocs.io/en/latest/user/installation.html).

As a first step, download either of these pre-trained models:
- [TriNet trained on MARS](https://omnomnom.vision.rwth-aachen.de/data/trinet-mars.npz) (md5sum: `72fafa2ee9aa3765f038d06e8dd8ef4b`)
- [TriNet trained on Market1501](https://omnomnom.vision.rwth-aachen.de/data/trinet-market1501.npz) (md5sum: `5353f95d1489536129ec14638aded3c7`)
-- (LuNet models will follow.)

Next, create a file (`files.txt`) which contains the full path to the image files you want to embed, one filename per line, like so:
@@ -23,7 +39,7 @@ Next, create a file (`files.txt`) which contains the full path to the image file
/path/to/file2.jpg
```

-Finally, run the `trinet_embed.py` script, passing both the above file and the weights file you want to use, like so:
+Finally, run the `trinet_embed.py` script, passing both the above file and the pretrained model file you want to use, like so:

```
python trinet_embed.py files.txt /path/to/trinet-mars.npz
@@ -47,3 +63,113 @@ You can now do meaningful work by comparing these embeddings using the Euclidean
A couple of notes:
- The script depends on [Theano](http://deeplearning.net/software/theano/install.html), [Lasagne](http://lasagne.readthedocs.io/en/latest/user/installation.html) and [OpenCV Python](http://opencv.org/) (`pip install opencv-python`) being correctly installed.
- The input files should be crops of a full person standing upright, and they will be resized to `288x144` before being passed to the network.
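The note above about comparing embeddings with the Euclidean distance can be made concrete with a minimal sketch. The array name, its loading step, and the file name below are illustrative assumptions, not part of the provided script:

```python
import numpy as np

# emb: (N, D) array with one embedding per row (illustrative placeholder).
emb = np.load('embeddings.npy')  # assumed file name; adjust to however you stored them

# Pairwise Euclidean distances; a small distance suggests the same person.
dists = np.sqrt(np.sum((emb[:, None, :] - emb[None, :, :]) ** 2, axis=-1))

# For a query image at index q, rank all other images by ascending distance.
q = 0
ranking = np.argsort(dists[q])
print(ranking[1:11])  # ten closest matches, skipping the query itself
```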

# Training your own models

If you want more flexibility, we now provide code for training your own models.
This is not the code that was used in the paper (which became an unusable mess),
but rather a clean re-implementation of it in [TensorFlow](https://www.tensorflow.org/),
achieving about the same performance.

- **This repository requires at least version 1.4 of TensorFlow.**
- **The TensorFlow code is Python 3 only and won't work in Python 2!**
## Defining a dataset

A dataset consists of two things:

1. An `image_root` folder which contains all images, possibly in sub-folders.
2. A dataset `.csv` file describing the dataset.

To create a dataset, you simply create a new `.csv` file for it of the following form:

```
identity,relative_path/to/image.jpg
```

The `identity`, often called `PID` (`P`erson `ID`entity), corresponds to the "class name"; it can be an arbitrary string, but it must be the same for all images belonging to the same identity.

The `relative_path/to/image.jpg` is relative to the aforementioned `image_root`.
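For illustration, a dataset file for a hypothetical two-identity dataset could look like this (the identities and file names are made up; only the `identity,relative_path` structure matters):

```
0001,person_0001/cam1_0001.jpg
0001,person_0001/cam2_0003.jpg
0002,person_0002/cam1_0007.jpg
```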
## Training

Given the dataset file and the `image_root`, you can already train a model.
The minimal way of training a model is to just call `train.py` in the following way:

```
python train.py \
    --train_set data/market1501_train.csv \
    --image_root /absolute/image/root \
    --experiment_root ~/experiments/my_experiment
```

This will start training with all default parameters.
We recommend writing a script file similar to `market1501_train.sh` in which you define all of the parameters.
It is **highly recommended** that you tune hyperparameters such as `net_input_{height,width}`, `learning_rate`,
`decay_start_iteration`, and many more.
See the top of `train.py` for a list of all parameters.

As a convenience, we store all the parameters that were used for a run in `experiment_root/args.json`.
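For illustration, a fuller invocation in the spirit of such a script, tuning some of the hyperparameters named above, could look like this (the flag values are made-up examples, not the settings used in the paper):

```
python train.py \
    --train_set data/market1501_train.csv \
    --image_root /absolute/image/root \
    --experiment_root ~/experiments/my_experiment \
    --net_input_height 256 --net_input_width 128 \
    --learning_rate 3e-4 \
    --decay_start_iteration 15000
```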
### Pre-trained initialization

If you want to initialize the model using pre-trained weights, as was done for TriNet,
you need to specify the location of the checkpoint file through `--initial_checkpoint`.

For most common models, you can download the [checkpoints provided by Google here](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models).
For example, that is where we get our ResNet50 pre-trained weights from,
and it is what you should pass as the second parameter to `market1501_train.sh`.
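For example, assuming you have downloaded and unpacked the slim ResNet-50 checkpoint, the training call could point to it like this (the checkpoint path is illustrative):

```
python train.py \
    --train_set data/market1501_train.csv \
    --image_root /absolute/image/root \
    --experiment_root ~/experiments/my_experiment \
    --initial_checkpoint /path/to/resnet_v1_50.ckpt
```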
## Interrupting and resuming training

Since training can take quite a while, interrupting and resuming training is important.
You can interrupt training at any time by hitting `Ctrl+C` or by sending `SIGINT (2)` or `SIGTERM (15)`
to the training process; it will finish the current batch, store the model and optimizer state,
and then terminate cleanly.
Because of the `args.json` file, you can later resume that run simply by running:

```
python train.py --experiment_root ~/experiments/my_experiment --resume
```

The last checkpoint is determined automatically by TensorFlow using the contents of the `checkpoint` file.
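For reference, that `checkpoint` file is a small text file maintained by TensorFlow's saver; it looks roughly like the following, where the snapshot names are illustrative and depend on how the training code names its checkpoints:

```
model_checkpoint_path: "checkpoint-25000"
all_model_checkpoint_paths: "checkpoint-20000"
all_model_checkpoint_paths: "checkpoint-25000"
```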
## Performance issues

For some reason, current TensorFlow is known to have inconsistent performance and can sometimes become very slow.
The only workaround we currently know of is to install Google's performance tools and preload tcmalloc:

```
env LD_PRELOAD=/usr/lib/libtcmalloc_minimal.so.4 python train.py ...
```

This fixes the issue for us most of the time, but not always.
If you know more, please open an issue and let us know!
## Out of memory

The setup as described in the paper requires a high-end GPU with a lot of memory.
If you don't have that, you can still train a model, but you should either use a smaller network
or adjust the batch size, which in turn also changes the learning difficulty and might affect results.

The two arguments for playing with the batch size are `--batch_p`, which controls the number of distinct
persons in a batch, and `--batch_k`, which controls the number of pictures per person.
We usually lower `batch_p` first.
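For example, a reduced-memory run could lower `batch_p` (and, if that is still not enough, `batch_k`); the values below are illustrative, and the defaults are listed at the top of `train.py`:

```
python train.py \
    --train_set data/market1501_train.csv \
    --image_root /absolute/image/root \
    --experiment_root ~/experiments/my_experiment \
    --batch_p 8 \
    --batch_k 4
```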
## Custom network architecture

TODO: Documentation. It's also pretty straightforward.

### The core network

### The network head

## Computing embeddings

TODO: Will be added later.

# Evaluating embeddings

TODO: Will be added later.

common.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
""" A bunch of general utilities shared by train/embed/eval """

from argparse import ArgumentTypeError
import os

import numpy as np
import tensorflow as tf

# Commandline argument parsing
###

def check_directory(arg, access=os.W_OK, access_str="writeable"):
    """ Check for directory-type argument validity.

    Checks whether the given `arg` commandline argument is either a readable
    existing directory, or a creatable/writeable directory.

    Args:
        arg (string): The commandline argument to check.
        access (constant): What access rights to the directory are requested.
        access_str (string): Used for the error message.

    Returns:
        The string passed in `arg` if the checks succeed.

    Raises:
        ArgumentTypeError if the checks fail.
    """
    path_head = arg
    while path_head:
        if os.path.exists(path_head):
            if os.access(path_head, access):
                # Seems legit, but it still doesn't guarantee a valid path.
                # We'll just go with it for now though.
                return arg
            else:
                raise ArgumentTypeError(
                    'The provided string `{0}` is not a valid {1} path '
                    'since {2} is an existing folder without {1} access.'
                    ''.format(arg, access_str, path_head))
        path_head, _ = os.path.split(path_head)

    # No part of the provided string exists and can be written on.
    raise ArgumentTypeError('The provided string `{}` is not a valid {}'
                            ' path.'.format(arg, access_str))


def writeable_directory(arg):
    """ To be used as a type for `ArgumentParser.add_argument`. """
    return check_directory(arg, os.W_OK, "writeable")


def readable_directory(arg):
    """ To be used as a type for `ArgumentParser.add_argument`. """
    return check_directory(arg, os.R_OK, "readable")


def number_greater_x(arg, type_, x):
    """ Parse `arg` as `type_` and verify it is strictly greater than `x`. """
    try:
        value = type_(arg)
    except ValueError:
        raise ArgumentTypeError('The argument "{}" is not an {}.'.format(
            arg, type_.__name__))

    if value > x:
        return value
    else:
        raise ArgumentTypeError('Found {} where a {} greater than {} was '
                                'required.'.format(arg, type_.__name__, x))


def positive_int(arg):
    return number_greater_x(arg, int, 0)


def nonnegative_int(arg):
    return number_greater_x(arg, int, -1)


def positive_float(arg):
    return number_greater_x(arg, float, 0)


def float_or_string(arg):
    """ Tries to convert the string to float, otherwise returns the string. """
    try:
        return float(arg)
    except (ValueError, TypeError):
        return arg


# Dataset handling
###


def load_dataset(csv_file, image_root, fail_on_missing=True):
    """ Loads a dataset .csv file, returning PIDs and FIDs.

    PIDs are the "person IDs", i.e. class names/labels.
    FIDs are the "file IDs", which are individual relative filenames.

    Args:
        csv_file (string, file-like object): The csv data file to load.
        image_root (string): The path to which the image files, as stored in
            the csv file, are relative. Used for verification purposes.
        fail_on_missing (bool): If one or more files from the dataset are not
            present in the `image_root`, either raise an IOError (if True) or
            remove them from the returned dataset (if False).

    Returns:
        (pids, fids) a tuple of numpy string arrays corresponding to the PIDs,
        i.e. the identities/classes/labels, and the FIDs, i.e. the filenames.

    Raises:
        IOError if any one file is missing and `fail_on_missing` is True.
    """
    dataset = np.genfromtxt(csv_file, delimiter=',', dtype='|U')
    pids, fids = dataset.T

    # Check if all files exist.
    missing = np.full(len(fids), False, dtype=bool)
    for i, fid in enumerate(fids):
        missing[i] = not os.path.isfile(os.path.join(image_root, fid))

    missing_count = np.sum(missing)
    if missing_count > 0:
        if fail_on_missing:
            raise IOError('Using the `{}` file and `{}` as an image root, {} of'
                          ' {} images are missing.'.format(
                              csv_file, image_root, missing_count, len(fids)))
        else:
            print('[Warning] removing {} missing file(s) from the'
                  ' dataset.'.format(missing_count))
            # We simply remove the missing files.
            fids = fids[np.logical_not(missing)]
            pids = pids[np.logical_not(missing)]

    return pids, fids


def fid_to_image(fid, pid, image_root, image_size):
    """ Loads and resizes an image given by FID. Passes through the PID. """
    # Since there is no symbolic path.join, we just add a '/' to be sure.
    image_encoded = tf.read_file(tf.reduce_join([image_root, '/', fid]))

    # tf.image.decode_image doesn't set the shape, not even the dimensionality,
    # because it potentially loads animated .gif files. Instead, we use either
    # decode_jpeg or decode_png, each of which can decode both.
    # Sounds ridiculous, but is true:
    # https://github.com/tensorflow/tensorflow/issues/9356#issuecomment-309144064
    image_decoded = tf.image.decode_jpeg(image_encoded, channels=3)
    image_resized = tf.image.resize_images(image_decoded, image_size)

    return image_resized, fid, pid
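To illustrate how these helpers fit together, here is a small, hypothetical usage sketch (not part of this commit, and not necessarily how `train.py` does it) that wires the argument-parsing types and the dataset functions into a `tf.data` input pipeline; the flag names and the `288x144` size are taken from the README above, everything else is illustrative:

```python
from argparse import ArgumentParser

import tensorflow as tf

import common

parser = ArgumentParser()
parser.add_argument('--train_set', required=True)
parser.add_argument('--image_root', required=True, type=common.readable_directory)
parser.add_argument('--experiment_root', required=True, type=common.writeable_directory)
parser.add_argument('--batch_k', default=4, type=common.positive_int)
args = parser.parse_args()

# Load PIDs and FIDs, dropping entries whose image file is missing.
pids, fids = common.load_dataset(args.train_set, args.image_root, fail_on_missing=False)

# Build a simple input pipeline: (fid, pid) -> (resized image, fid, pid).
dataset = tf.data.Dataset.from_tensor_slices((fids, pids))
dataset = dataset.map(lambda fid, pid: common.fid_to_image(
    fid, pid, image_root=tf.constant(args.image_root), image_size=(288, 144)))
images, fids_t, pids_t = dataset.batch(args.batch_k).make_one_shot_iterator().get_next()
```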
