|
| 1 | +/* Copyright © 2020 Apple Inc. All rights reserved. |
| 2 | + * |
| 3 | + * Use of this source code is governed by a BSD-3-clause license that can |
| 4 | + * be found in the LICENSE.txt file or at |
| 5 | + * https://opensource.org/licenses/BSD-3-Clause |
| 6 | + */ |
| 7 | + |
#include <toolkits/style_transfer/st_model_trainer.hpp>

#include <algorithm>
#include <functional>
#include <numeric>

#include <model_server/lib/image_util.hpp>
#include <toolkits/style_transfer/style_transfer.hpp>
| 14 | + |
| 15 | +namespace turi { |
| 16 | +namespace style_transfer { |
| 17 | + |
| 18 | +using neural_net::compute_context; |
| 19 | +using neural_net::float_array_map; |
| 20 | +using neural_net::model_backend; |
| 21 | +using neural_net::model_spec; |
| 22 | +using neural_net::Publisher; |
| 23 | +using neural_net::shared_float_array; |
| 24 | + |
| 25 | +DataBatch DataIterator::Next() { |
| 26 | + DataBatch batch; |
| 27 | + batch.iteration_id = ++last_iteration_id_; |
| 28 | + batch.examples = impl_->next_batch(batch_size_); |
| 29 | + return batch; |
| 30 | +} |
| 31 | + |
// Wraps a DataIterator so that each underlying batch can be re-emitted once
// per requested style index (see Next()). next_style_ is initialized to
// end() so that the first call to Next() fetches a batch from base_iterator_
// before emitting any styles.
InferenceDataIterator::InferenceDataIterator(
    std::shared_ptr<DataIterator> base_iterator, std::vector<int> style_indices)
    : base_iterator_(std::move(base_iterator)),
      style_indices_(std::move(style_indices)),
      next_style_(style_indices_.end()) {}
| 37 | + |
| 38 | +bool InferenceDataIterator::HasNext() const { |
| 39 | + return next_style_ != style_indices_.end() || base_iterator_->HasNext(); |
| 40 | +} |
| 41 | + |
| 42 | +DataBatch InferenceDataIterator::Next() { |
| 43 | + // If we're done emitting all the styles for the current underlying batch, |
| 44 | + // fetch the next batch from the underlying data iterator. |
| 45 | + if (next_style_ == style_indices_.end() && base_iterator_->HasNext()) { |
| 46 | + current_batch_ = base_iterator_->Next(); |
| 47 | + next_style_ = style_indices_.begin(); |
| 48 | + } |
| 49 | + |
| 50 | + // Write the next style index into all the images in the current batch. |
| 51 | + if (next_style_ != style_indices_.end()) { |
| 52 | + for (st_example& example : current_batch_.examples) { |
| 53 | + example.style_index = *next_style_; |
| 54 | + } |
| 55 | + ++next_style_; |
| 56 | + } |
| 57 | + |
| 58 | + return current_batch_; |
| 59 | +} |
| 60 | + |
| 61 | +TrainingProgress ProgressUpdater::Invoke(EncodedBatch batch) { |
| 62 | + auto reduce = [](const shared_float_array& array) { |
| 63 | + float loss = std::accumulate(array.data(), array.data() + array.size(), 0.f, |
| 64 | + std::plus<float>()); |
| 65 | + loss /= array.size(); |
| 66 | + return loss; |
| 67 | + }; |
| 68 | + |
| 69 | + // Compute the loss for this batch. |
| 70 | + float batch_loss = reduce(batch.encoded_data.at("loss")); |
| 71 | + |
| 72 | + // Update our rolling average (smoothed) loss. |
| 73 | + if (smoothed_loss_) { |
| 74 | + *smoothed_loss_ *= 0.9f; |
| 75 | + *smoothed_loss_ += 0.1f * batch_loss; |
| 76 | + } else { |
| 77 | + // Initialize smoothed loss to the first loss value. |
| 78 | + smoothed_loss_.reset(new float(batch_loss)); |
| 79 | + } |
| 80 | + |
| 81 | + // Write smoothed loss into the result. |
| 82 | + TrainingProgress progress; |
| 83 | + progress.iteration_id = batch.iteration_id; |
| 84 | + progress.smoothed_loss = *smoothed_loss_; |
| 85 | + |
| 86 | + // Write optional loss components into the result. |
| 87 | + auto style_loss_it = batch.encoded_data.find("style_loss"); |
| 88 | + if (style_loss_it != batch.encoded_data.end()) { |
| 89 | + progress.style_loss = reduce(style_loss_it->second); |
| 90 | + } |
| 91 | + auto content_loss_it = batch.encoded_data.find("content_loss"); |
| 92 | + if (content_loss_it != batch.encoded_data.end()) { |
| 93 | + progress.content_loss = reduce(content_loss_it->second); |
| 94 | + } |
| 95 | + |
| 96 | + return progress; |
| 97 | +} |
| 98 | + |
| 99 | +// static |
| 100 | +float_array_map Checkpoint::ExtractWeights( |
| 101 | + std::unique_ptr<model_spec> nn_spec) { |
| 102 | + float_array_map result = nn_spec->export_params_view(); |
| 103 | + for (auto& name_and_weights : result) { |
| 104 | + // The original values will not be valid once the nn_spec is deconstructed. |
| 105 | + // TODO: Ideally this would not require copying. But we should move away |
| 106 | + // from using the protocol buffer as our primary representation anyway. |
| 107 | + name_and_weights.second = shared_float_array::copy(name_and_weights.second); |
| 108 | + } |
| 109 | + return result; |
| 110 | +} |
| 111 | + |
| 112 | +std::shared_ptr<Publisher<TrainingProgress>> |
| 113 | +ModelTrainer::AsTrainingBatchPublisher( |
| 114 | + std::unique_ptr<data_iterator> training_data, |
| 115 | + const std::string& vgg_mlmodel_path, int offset, |
| 116 | + std::unique_ptr<float> initial_training_loss, compute_context* context) { |
| 117 | + auto iterator = std::make_shared<DataIterator>(std::move(training_data), |
| 118 | + config_.batch_size, offset); |
| 119 | + |
| 120 | + int height = config_.training_image_height; |
| 121 | + int width = config_.training_image_width; |
| 122 | + auto encode = [height, width](DataBatch batch) { |
| 123 | + return EncodeTrainingBatch(std::move(batch), width, height); |
| 124 | + }; |
| 125 | + |
| 126 | + std::shared_ptr<model_backend> backend = |
| 127 | + CreateTrainingBackend(vgg_mlmodel_path, context); |
| 128 | + auto train = [backend](EncodedBatch batch) { |
| 129 | + EncodedBatch result; |
| 130 | + result.iteration_id = batch.iteration_id; |
| 131 | + result.encoded_data = backend->train(batch.encoded_data); |
| 132 | + return result; |
| 133 | + }; |
| 134 | + |
| 135 | + auto update_progress = |
| 136 | + std::make_shared<ProgressUpdater>(std::move(initial_training_loss)); |
| 137 | + |
| 138 | + return iterator->AsPublisher()->Map(encode)->Map(train)->Map(update_progress); |
| 139 | +} |
| 140 | + |
| 141 | +std::shared_ptr<Publisher<DataBatch>> ModelTrainer::AsInferenceBatchPublisher( |
| 142 | + std::unique_ptr<data_iterator> test_data, std::vector<int> style_indices, |
| 143 | + compute_context* context) { |
| 144 | + auto base_iterator = std::make_shared<DataIterator>( |
| 145 | + std::move(test_data), /* batch_size */ 1, /* offset */ 0); |
| 146 | + auto iterator = std::make_shared<InferenceDataIterator>( |
| 147 | + base_iterator, std::move(style_indices)); |
| 148 | + |
| 149 | + std::shared_ptr<model_backend> backend = CreateInferenceBackend(context); |
| 150 | + auto predict = [backend](EncodedInferenceBatch batch) { |
| 151 | + EncodedInferenceBatch result; |
| 152 | + result.iteration_id = batch.iteration_id; |
| 153 | + result.encoded_data = backend->predict(batch.encoded_data); |
| 154 | + result.style_index = batch.style_index; |
| 155 | + return result; |
| 156 | + }; |
| 157 | + |
| 158 | + return iterator->AsPublisher() |
| 159 | + ->Map(EncodeInferenceBatch) |
| 160 | + ->Map(predict) |
| 161 | + ->Map(DecodeInferenceBatch); |
| 162 | +} |
| 163 | + |
| 164 | +EncodedBatch EncodeTrainingBatch(DataBatch batch, int width, int height) { |
| 165 | + EncodedBatch result; |
| 166 | + result.iteration_id = batch.iteration_id; |
| 167 | + |
| 168 | + result.encoded_data = prepare_batch(batch.examples, width, height, |
| 169 | + /* train */ true); |
| 170 | + |
| 171 | + return result; |
| 172 | +} |
| 173 | + |
| 174 | +EncodedInferenceBatch EncodeInferenceBatch(DataBatch batch) { |
| 175 | + EncodedInferenceBatch result; |
| 176 | + result.iteration_id = batch.iteration_id; |
| 177 | + result.encoded_data = prepare_predict(batch.examples.front()); |
| 178 | + result.style_index = static_cast<int>(batch.examples.front().style_index); |
| 179 | + return result; |
| 180 | +} |
| 181 | + |
| 182 | +DataBatch DecodeInferenceBatch(EncodedInferenceBatch batch) { |
| 183 | + DataBatch result; |
| 184 | + result.iteration_id = batch.iteration_id; |
| 185 | + |
| 186 | + shared_float_array output = batch.encoded_data.at("output"); |
| 187 | + std::vector<std::pair<flex_int, flex_image>> processed_batch = |
| 188 | + process_output(output, batch.style_index); |
| 189 | + result.examples.resize(processed_batch.size()); |
| 190 | + std::transform(processed_batch.begin(), processed_batch.end(), |
| 191 | + result.examples.begin(), |
| 192 | + [](const std::pair<flex_int, flex_image>& style_and_image) { |
| 193 | + st_example example; |
| 194 | + example.style_index = style_and_image.first; |
| 195 | + example.style_image = style_and_image.second; |
| 196 | + return example; |
| 197 | + }); |
| 198 | + |
| 199 | + return result; |
| 200 | +} |
| 201 | + |
| 202 | +} // namespace style_transfer |
| 203 | +} // namespace turi |
0 commit comments