Skip to content

Commit 719c08e

Browse files
IlyaOvodovvera121
authored andcommitted
Weight parameter in solver is used in caffe.exe
Loading weights is moved from caffe.exe to the solver class, so the new "weights" solver parameter is used not only from the command line but also when caffe is used as a library (including python). Corrected formatting; fixed line length; more formatting corrected.
1 parent 47f20ae commit 719c08e

File tree

4 files changed

+43
-16
lines changed

4 files changed

+43
-16
lines changed

src/caffe/proto/caffe.proto

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ message NetParameter {
214214
// Update the next available ID when you add a new SolverParameter field.
215215
//
216216
// SolverParameter next available ID: 44 (last added: plateau_winsize)
217+
217218
message SolverParameter {
218219
//////////////////////////////////////////////////////////////////////////////
219220
// Specifying the train and test networks
@@ -372,6 +373,16 @@ message SolverParameter {
372373
optional bool gan_solver = 51 [default = false];
373374
// Overlap compute and communication for data parallel training
374375
optional bool layer_wise_reduce = 45 [default = true];
376+
377+
// Path to caffemodel file(s) with pretrained weights to initialize finetuning.
378+
// The same as command line --weights parameter for caffe train command.
379+
// If command line --weights parameter is specified, it has higher priority
380+
// and overwrites this one(s).
381+
// If --snapshot command line parameter is specified, this one(s) are ignored.
382+
// If several model files are expected, they can be listed in one
383+
// weights parameter separated by ',' (like in a command string) or
384+
// in repeated weights parameters separately.
385+
repeated string weights = 46;
375386
}
376387

377388
// A message that stores the solver snapshots

src/caffe/solver.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <utility>
66
#include <vector>
77

8+
#include "boost/algorithm/string.hpp"
89
#include "caffe/solver.hpp"
910
#include "caffe/util/bbox_util.hpp"
1011
#include "caffe/util/format.hpp"
@@ -62,6 +63,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
6263
current_step_ = 0;
6364
}
6465

66+
// Load weights from the caffemodel(s) specified in "weights" solver parameter
67+
// into the train and test nets.
68+
template <typename Dtype>
69+
void LoadNetWeights(shared_ptr<Net<Dtype> > net,
70+
const std::string& model_list) {
71+
std::vector<std::string> model_names;
72+
boost::split(model_names, model_list, boost::is_any_of(","));
73+
for (int i = 0; i < model_names.size(); ++i) {
74+
boost::trim(model_names[i]);
75+
LOG(INFO) << "Finetuning from " << model_names[i];
76+
net->CopyTrainedLayersFrom(model_names[i]);
77+
}
78+
}
79+
6580
template <typename Dtype>
6681
void Solver<Dtype>::InitTrainNet() {
6782
const int num_train_nets = param_.has_net() + param_.has_net_param() +
@@ -101,6 +116,9 @@ void Solver<Dtype>::InitTrainNet() {
101116
net_state.MergeFrom(param_.train_state());
102117
net_param.mutable_state()->CopyFrom(net_state);
103118
net_.reset(new Net<Dtype>(net_param));
119+
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
120+
LoadNetWeights(net_, param_.weights(w_idx));
121+
}
104122
}
105123

106124
template <typename Dtype>
@@ -176,6 +194,9 @@ void Solver<Dtype>::InitTestNets() {
176194
<< "Creating test net (#" << i << ") specified by " << sources[i];
177195
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
178196
test_nets_[i]->set_debug_info(param_.debug_info());
197+
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
198+
LoadNetWeights(test_nets_[i], param_.weights(w_idx));
199+
}
179200
}
180201
}
181202

src/caffe/test/test_upgrade_proto.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
29522952
for (int i = 0; i < 6; ++i) {
29532953
const string& input_proto =
29542954
"net: 'examples/mnist/lenet_train_test.prototxt' "
2955+
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
2956+
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
29552957
"test_iter: 100 "
29562958
"test_interval: 500 "
29572959
"base_lr: 0.01 "
@@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
29682970
"solver_type: " + std::string(old_type_vec[i]) + " ";
29692971
const string& expected_output_proto =
29702972
"net: 'examples/mnist/lenet_train_test.prototxt' "
2973+
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
2974+
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
29712975
"test_iter: 100 "
29722976
"test_interval: 500 "
29732977
"base_lr: 0.01 "

tools/caffe.cpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,6 @@ int device_query() {
146146
}
147147
RegisterBrewFunction(device_query);
148148

149-
// Load the weights from the specified caffemodel(s) into the train and
150-
// test nets.
151-
void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
152-
std::vector<std::string> model_names;
153-
boost::split(model_names, model_list, boost::is_any_of(",") );
154-
for (int i = 0; i < model_names.size(); ++i) {
155-
LOG(INFO) << "Finetuning from " << model_names[i];
156-
solver->net()->CopyTrainedLayersFrom(model_names[i]);
157-
for (int j = 0; j < solver->test_nets().size(); ++j) {
158-
solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
159-
}
160-
}
161-
}
162-
163149
// Translate the signal effect the user specified on the command-line to the
164150
// corresponding enumeration.
165151
caffe::SolverAction::Enum GetRequestedAction(
@@ -234,6 +220,13 @@ int train() {
234220
GetRequestedAction(FLAGS_sigint_effect),
235221
GetRequestedAction(FLAGS_sighup_effect));
236222

223+
if (FLAGS_snapshot.size()) {
224+
solver_param.clear_weights();
225+
} else if (FLAGS_weights.size()) {
226+
solver_param.clear_weights();
227+
solver_param.add_weights(FLAGS_weights);
228+
}
229+
237230
shared_ptr<caffe::Solver<float> >
238231
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
239232

@@ -242,8 +235,6 @@ int train() {
242235
if (FLAGS_snapshot.size()) {
243236
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
244237
solver->Restore(FLAGS_snapshot.c_str());
245-
} else if (FLAGS_weights.size()) {
246-
CopyLayers(solver.get(), FLAGS_weights);
247238
}
248239

249240
LOG(INFO) << "Starting Optimization";

0 commit comments

Comments
 (0)