This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Commit f30f7ae

Introduce style_transfer::ModelTrainer abstraction (#3082)
1 parent 8d53fbb commit f30f7ae

11 files changed: +959 −191 lines

src/ml/neural_net/combine_iterator.hpp

Lines changed: 26 additions & 0 deletions
@@ -51,6 +51,32 @@ class Iterator : public std::enable_shared_from_this<Iterator<T>> {
   }
 };
 
+/**
+ * Templated implementation of Iterator that wraps an arbitrary callable type.
+ */
+template <typename Callable>
+class CallableIterator
+    : public Iterator<typename std::result_of<Callable()>::type> {
+ public:
+  using Output = typename std::result_of<Callable()>::type;
+
+  CallableIterator(Callable impl) : impl_(std::move(impl)) {}
+
+  bool HasNext() const override { return true; }
+
+  Output Next() override { return impl_(); }
+
+ private:
+  Callable impl_;
+};
+
+template <typename Callable>
+std::shared_ptr<IteratorPublisher<typename std::result_of<Callable()>::type>>
+CreatePublisherFromCallable(Callable impl) {
+  return std::make_shared<CallableIterator<Callable>>(std::move(impl))
+      ->AsPublisher();
+}
+
 /**
  * Concrete Publisher that wraps an Iterator.
  *
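For orientation, here is a minimal usage sketch of the new CreatePublisherFromCallable helper. It is not part of this commit; the Example function and counter lambda are hypothetical, and it assumes the turi::neural_net namespace and the include path used elsewhere in this tree.

#include <ml/neural_net/combine_iterator.hpp>

// Hypothetical sketch: wrap a stateful lambda as a publisher.
// CallableIterator::HasNext() always returns true, so the stream is
// unbounded and each downstream demand invokes the lambda once via Next().
void Example() {
  int i = 0;
  auto counter = [i]() mutable { return i++; };
  auto publisher =
      turi::neural_net::CreatePublisherFromCallable(std::move(counter));
  // publisher is a std::shared_ptr<IteratorPublisher<int>> and composes like
  // any other publisher, e.g. publisher->Map(...).
  (void)publisher;
}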

src/ml/neural_net/model_spec.cpp

Lines changed: 2 additions & 0 deletions
@@ -506,6 +506,8 @@ model_spec::model_spec(const std::string& mlmodel_path)
   impl_->Swap(mlmodel.mutable_neuralnetwork());
 }
 
+model_spec::model_spec(model_spec&&) = default;
+model_spec& model_spec::operator=(model_spec&&) = default;
 model_spec::~model_spec() = default;
 
 std::unique_ptr<NeuralNetwork> model_spec::move_coreml_spec() && {

src/ml/neural_net/model_spec.hpp

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ class model_spec {
   // Declared here and defined in the .cpp file just to prevent the implicit
   // default destructor from attempting (and failing) to instantiate
   // std::unique_ptr<NeuralNetwork>::~unique_ptr()
+  model_spec(model_spec&&);
+  model_spec& operator=(model_spec&&);
   ~model_spec();
 
   /**
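The comment above refers to the usual pimpl constraint: special members that destroy or move a std::unique_ptr to an incomplete type must be defined in a translation unit where that type is complete. A minimal sketch of the pattern with hypothetical Widget/Impl names (not from this repository):

// widget.hpp -- Impl is only forward-declared here.
#include <memory>

class Impl;

class Widget {
 public:
  Widget(Widget&&);             // declared here...
  Widget& operator=(Widget&&);  // ...so the compiler does not generate them
  ~Widget();                    // where Impl is still incomplete.
 private:
  std::unique_ptr<Impl> impl_;  // incomplete type is fine for the declaration
};

// widget.cpp -- Impl is complete here, so the defaulted definitions can
// instantiate std::unique_ptr<Impl>::~unique_ptr():
//   class Impl { /* ... */ };
//   Widget::Widget(Widget&&) = default;
//   Widget& Widget::operator=(Widget&&) = default;
//   Widget::~Widget() = default;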

src/toolkits/style_transfer/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@ project(Turi)
 
 make_library(unity_style_transfer OBJECT
   SOURCES
+    st_model_trainer.cpp
+    st_resnet16_model_trainer.cpp
     style_transfer.cpp
     style_transfer_data_iterator.cpp
     style_transfer_model_definition.cpp
src/toolkits/style_transfer/st_model_trainer.cpp

Lines changed: 203 additions & 0 deletions

@@ -0,0 +1,203 @@
+/* Copyright © 2020 Apple Inc. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-3-clause license that can
+ * be found in the LICENSE.txt file or at
+ * https://opensource.org/licenses/BSD-3-Clause
+ */
+
+#include <toolkits/style_transfer/st_model_trainer.hpp>
+
+#include <algorithm>
+
+#include <model_server/lib/image_util.hpp>
+#include <toolkits/style_transfer/style_transfer.hpp>
+
+namespace turi {
+namespace style_transfer {
+
+using neural_net::compute_context;
+using neural_net::float_array_map;
+using neural_net::model_backend;
+using neural_net::model_spec;
+using neural_net::Publisher;
+using neural_net::shared_float_array;
+
+DataBatch DataIterator::Next() {
+  DataBatch batch;
+  batch.iteration_id = ++last_iteration_id_;
+  batch.examples = impl_->next_batch(batch_size_);
+  return batch;
+}
+
+InferenceDataIterator::InferenceDataIterator(
+    std::shared_ptr<DataIterator> base_iterator, std::vector<int> style_indices)
+    : base_iterator_(std::move(base_iterator)),
+      style_indices_(std::move(style_indices)),
+      next_style_(style_indices_.end()) {}
+
+bool InferenceDataIterator::HasNext() const {
+  return next_style_ != style_indices_.end() || base_iterator_->HasNext();
+}
+
+DataBatch InferenceDataIterator::Next() {
+  // If we're done emitting all the styles for the current underlying batch,
+  // fetch the next batch from the underlying data iterator.
+  if (next_style_ == style_indices_.end() && base_iterator_->HasNext()) {
+    current_batch_ = base_iterator_->Next();
+    next_style_ = style_indices_.begin();
+  }
+
+  // Write the next style index into all the images in the current batch.
+  if (next_style_ != style_indices_.end()) {
+    for (st_example& example : current_batch_.examples) {
+      example.style_index = *next_style_;
+    }
+    ++next_style_;
+  }
+
+  return current_batch_;
+}
+
+TrainingProgress ProgressUpdater::Invoke(EncodedBatch batch) {
+  auto reduce = [](const shared_float_array& array) {
+    float loss = std::accumulate(array.data(), array.data() + array.size(), 0.f,
+                                 std::plus<float>());
+    loss /= array.size();
+    return loss;
+  };
+
+  // Compute the loss for this batch.
+  float batch_loss = reduce(batch.encoded_data.at("loss"));
+
+  // Update our rolling average (smoothed) loss.
+  if (smoothed_loss_) {
+    *smoothed_loss_ *= 0.9f;
+    *smoothed_loss_ += 0.1f * batch_loss;
+  } else {
+    // Initialize smoothed loss to the first loss value.
+    smoothed_loss_.reset(new float(batch_loss));
+  }
+
+  // Write smoothed loss into the result.
+  TrainingProgress progress;
+  progress.iteration_id = batch.iteration_id;
+  progress.smoothed_loss = *smoothed_loss_;
+
+  // Write optional loss components into the result.
+  auto style_loss_it = batch.encoded_data.find("style_loss");
+  if (style_loss_it != batch.encoded_data.end()) {
+    progress.style_loss = reduce(style_loss_it->second);
+  }
+  auto content_loss_it = batch.encoded_data.find("content_loss");
+  if (content_loss_it != batch.encoded_data.end()) {
+    progress.content_loss = reduce(content_loss_it->second);
+  }
+
+  return progress;
+}
+
+// static
+float_array_map Checkpoint::ExtractWeights(
+    std::unique_ptr<model_spec> nn_spec) {
+  float_array_map result = nn_spec->export_params_view();
+  for (auto& name_and_weights : result) {
+    // The original values will not be valid once the nn_spec is deconstructed.
+    // TODO: Ideally this would not require copying. But we should move away
+    // from using the protocol buffer as our primary representation anyway.
+    name_and_weights.second = shared_float_array::copy(name_and_weights.second);
+  }
+  return result;
+}
+
+std::shared_ptr<Publisher<TrainingProgress>>
+ModelTrainer::AsTrainingBatchPublisher(
+    std::unique_ptr<data_iterator> training_data,
+    const std::string& vgg_mlmodel_path, int offset,
+    std::unique_ptr<float> initial_training_loss, compute_context* context) {
+  auto iterator = std::make_shared<DataIterator>(std::move(training_data),
+                                                 config_.batch_size, offset);
+
+  int height = config_.training_image_height;
+  int width = config_.training_image_width;
+  auto encode = [height, width](DataBatch batch) {
+    return EncodeTrainingBatch(std::move(batch), width, height);
+  };
+
+  std::shared_ptr<model_backend> backend =
+      CreateTrainingBackend(vgg_mlmodel_path, context);
+  auto train = [backend](EncodedBatch batch) {
+    EncodedBatch result;
+    result.iteration_id = batch.iteration_id;
+    result.encoded_data = backend->train(batch.encoded_data);
+    return result;
+  };
+
+  auto update_progress =
+      std::make_shared<ProgressUpdater>(std::move(initial_training_loss));
+
+  return iterator->AsPublisher()->Map(encode)->Map(train)->Map(update_progress);
+}
+
+std::shared_ptr<Publisher<DataBatch>> ModelTrainer::AsInferenceBatchPublisher(
+    std::unique_ptr<data_iterator> test_data, std::vector<int> style_indices,
+    compute_context* context) {
+  auto base_iterator = std::make_shared<DataIterator>(
+      std::move(test_data), /* batch_size */ 1, /* offset */ 0);
+  auto iterator = std::make_shared<InferenceDataIterator>(
+      base_iterator, std::move(style_indices));
+
+  std::shared_ptr<model_backend> backend = CreateInferenceBackend(context);
+  auto predict = [backend](EncodedInferenceBatch batch) {
+    EncodedInferenceBatch result;
+    result.iteration_id = batch.iteration_id;
+    result.encoded_data = backend->predict(batch.encoded_data);
+    result.style_index = batch.style_index;
+    return result;
+  };
+
+  return iterator->AsPublisher()
+      ->Map(EncodeInferenceBatch)
+      ->Map(predict)
+      ->Map(DecodeInferenceBatch);
+}
+
+EncodedBatch EncodeTrainingBatch(DataBatch batch, int width, int height) {
+  EncodedBatch result;
+  result.iteration_id = batch.iteration_id;
+
+  result.encoded_data = prepare_batch(batch.examples, width, height,
+                                      /* train */ true);
+
+  return result;
+}
+
+EncodedInferenceBatch EncodeInferenceBatch(DataBatch batch) {
+  EncodedInferenceBatch result;
+  result.iteration_id = batch.iteration_id;
+  result.encoded_data = prepare_predict(batch.examples.front());
+  result.style_index = static_cast<int>(batch.examples.front().style_index);
+  return result;
+}
+
+DataBatch DecodeInferenceBatch(EncodedInferenceBatch batch) {
+  DataBatch result;
+  result.iteration_id = batch.iteration_id;
+
+  shared_float_array output = batch.encoded_data.at("output");
+  std::vector<std::pair<flex_int, flex_image>> processed_batch =
+      process_output(output, batch.style_index);
+  result.examples.resize(processed_batch.size());
+  std::transform(processed_batch.begin(), processed_batch.end(),
+                 result.examples.begin(),
+                 [](const std::pair<flex_int, flex_image>& style_and_image) {
+                   st_example example;
+                   example.style_index = style_and_image.first;
+                   example.style_image = style_and_image.second;
+                   return example;
+                 });
+
+  return result;
+}
+
+}  // namespace style_transfer
+}  // namespace turi
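Two notes on the pipeline above, with a small illustration. AsTrainingBatchPublisher composes the training loop as iterator->AsPublisher()->Map(encode)->Map(train)->Map(update_progress), so each demanded value flows batch → encoded batch → backend result → TrainingProgress, and AsInferenceBatchPublisher likewise wraps each test image once per requested style index before prediction. Inside ProgressUpdater::Invoke, the reported loss is an exponential moving average (seeded by the first batch, or by initial_training_loss when resuming): smoothed = 0.9 * smoothed + 0.1 * batch_loss. A standalone sketch of that update rule with hypothetical loss values, not part of the commit:

#include <iostream>
#include <memory>

// Standalone sketch of the ProgressUpdater smoothing rule: the first batch
// seeds the average, every later batch nudges it by 10%.
int main() {
  std::unique_ptr<float> smoothed;
  for (float batch_loss : {10.f, 8.f, 6.f}) {
    if (smoothed) {
      *smoothed = 0.9f * *smoothed + 0.1f * batch_loss;
    } else {
      smoothed.reset(new float(batch_loss));
    }
    std::cout << *smoothed << "\n";  // prints 10, 9.8, 9.42
  }
  return 0;
}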
