10
10
#include < functional>
11
11
#include < map>
12
12
#include < memory>
13
+ #include < queue>
13
14
14
- #include < core/logging/table_printer/table_printer.hpp>
15
- #include < model_server/lib/extensions/ml_model.hpp>
16
15
#include < core/data/sframe/gl_sframe.hpp>
17
- #include < toolkits/coreml_export/mlmodel_wrapper .hpp>
16
+ #include < core/logging/table_printer/table_printer .hpp>
18
17
#include < ml/neural_net/compute_context.hpp>
19
18
#include < ml/neural_net/image_augmentation.hpp>
20
19
#include < ml/neural_net/model_backend.hpp>
21
20
#include < ml/neural_net/model_spec.hpp>
21
+ #include < model_server/lib/extensions/ml_model.hpp>
22
+ #include < toolkits/coreml_export/mlmodel_wrapper.hpp>
22
23
#include < toolkits/object_detection/od_data_iterator.hpp>
24
+ #include < toolkits/object_detection/od_model.hpp>
23
25
24
26
namespace turi {
25
27
namespace object_detection {
@@ -45,7 +47,7 @@ class EXPORT object_detector: public ml_model_base {
45
47
std::map<std::string, flexible_type> opts);
46
48
variant_type predict (variant_type data,
47
49
std::map<std::string, flexible_type> opts);
48
- std::shared_ptr<coreml::MLModelWrapper> export_to_coreml (
50
+ virtual std::shared_ptr<coreml::MLModelWrapper> export_to_coreml (
49
51
std::string filename, std::string short_description,
50
52
std::map<std::string, flexible_type> additional_user_defined,
51
53
std::map<std::string, flexible_type> opts);
@@ -155,23 +157,19 @@ class EXPORT object_detector: public ml_model_base {
155
157
END_CLASS_MEMBER_REGISTRATION
156
158
157
159
protected:
158
- // Constructor allowing tests to set the initial state of this class and to
159
- // inject dependencies.
160
- object_detector (
161
- const std::map<std::string, variant_type>& initial_state,
162
- std::unique_ptr<neural_net::model_spec> nn_spec,
163
- std::unique_ptr<neural_net::compute_context> training_compute_context,
164
- std::unique_ptr<data_iterator> training_data_iterator,
165
- std::unique_ptr<neural_net::image_augmenter> training_data_augmenter,
166
- std::unique_ptr<neural_net::model_backend> training_model)
167
- : nn_spec_(std::move(nn_spec)),
168
- training_compute_context_ (std::move(training_compute_context)),
169
- training_data_iterator_(std::move(training_data_iterator)),
170
- training_data_augmenter_(std::move(training_data_augmenter)),
171
- training_model_(std::move(training_model)) {
172
- add_or_update_state (initial_state);
160
+ // Constructor allowing tests to set the initial state of this class.
161
+ object_detector (std::map<std::string, variant_type> initial_state,
162
+ neural_net::float_array_map initial_weights) {
163
+ load (std::move (initial_state), std::move (initial_weights));
173
164
}
174
165
166
+ // Resets the internal state. Used by deserialization code and unit tests.
167
+ void load (std::map<std::string, variant_type> state,
168
+ neural_net::float_array_map weights);
169
+
170
+ // Synchronously loads weights from the backend if necessary.
171
+ Checkpoint* read_checkpoint () const ;
172
+
175
173
// Override points allowing subclasses to inject dependencies
176
174
177
175
// Factory for data_iterator
@@ -186,12 +184,19 @@ class EXPORT object_detector: public ml_model_base {
186
184
virtual
187
185
std::unique_ptr<neural_net::compute_context> create_compute_context () const ;
188
186
189
- // Returns the initial neural network to train (represented by its CoreML
190
- // spec), given the path to a mlmodel file containing the pretrained weights.
191
- virtual std::unique_ptr<neural_net::model_spec> init_model (
192
- const std::string& pretrained_mlmodel_path, size_t num_classes) const ;
187
+ // Factories for Model
188
+ virtual std::unique_ptr<Model> create_model (
189
+ const Checkpoint& checkpoint,
190
+ std::unique_ptr<neural_net::compute_context> context) const ;
191
+ virtual std::unique_ptr<Model> create_model (
192
+ const Config& config, const std::string& pretrained_model_path,
193
+ int random_seed,
194
+ std::unique_ptr<neural_net::compute_context> context) const ;
193
195
194
- void init_training_backend ();
196
+ // Establishes training pipelines from the backend.
197
+ void connect_training_backend (std::unique_ptr<Model> backend,
198
+ std::unique_ptr<data_iterator> iterator,
199
+ int batch_size);
195
200
196
201
virtual std::vector<neural_net::image_annotation> convert_yolo_to_annotations (
197
202
const neural_net::float_array& yolo_map,
@@ -219,12 +224,9 @@ class EXPORT object_detector: public ml_model_base {
219
224
}
220
225
221
226
private:
227
+ neural_net::float_array_map strip_fwd (
228
+ const neural_net::float_array_map& params) const ;
222
229
223
- neural_net::float_array_map get_model_params () const ;
224
-
225
- neural_net::shared_float_array prepare_label_batch (
226
- std::vector<std::vector<neural_net::image_annotation>> annotations_batch)
227
- const ;
228
230
flex_int get_max_iterations () const ;
229
231
flex_int get_training_iterations () const ;
230
232
flex_int get_num_classes () const ;
@@ -236,35 +238,32 @@ class EXPORT object_detector: public ml_model_base {
236
238
const std::string& column_name);
237
239
238
240
// Sets certain user options heuristically (from the data).
239
- void infer_derived_options ();
241
+ void infer_derived_options (neural_net::compute_context* context,
242
+ data_iterator* iterator);
240
243
241
244
// Waits until the number of pending patches is at most `max_pending`.
242
245
void wait_for_training_batches (size_t max_pending = 0 );
243
246
244
- // Ensures that the local copy of the model weights are in sync with the
245
- // training backend.
246
- void synchronize_model (neural_net::model_spec* nn_spec) const ;
247
-
248
247
// Computes and records training/validation metrics.
249
248
void update_model_metrics (gl_sframe data, gl_sframe validation_data);
250
249
251
- // Primary representation for the trained model.
252
- std::unique_ptr<neural_net::model_spec> nn_spec_;
250
+ // Primary representation for the trained model. Can be null if the model has
251
+ // been updated since the last checkpoint.
252
+ mutable std::unique_ptr<Checkpoint> checkpoint_;
253
253
254
254
// Primary dependencies for training. These should be nonnull while training
255
255
// is in progress.
256
256
gl_sframe training_data_; // TODO: Avoid storing gl_sframe AND data_iterator.
257
257
gl_sframe validation_data_;
258
- std::unique_ptr<neural_net::compute_context> training_compute_context_;
259
- std::unique_ptr<data_iterator> training_data_iterator_;
260
- std::unique_ptr<neural_net::image_augmenter> training_data_augmenter_;
261
- std::unique_ptr<neural_net::model_backend> training_model_;
258
+ std::shared_ptr<neural_net::FuturesStream<TrainingOutputBatch>>
259
+ training_futures_;
260
+ std::shared_ptr<neural_net::FuturesStream<Checkpoint>> checkpoint_futures_;
262
261
263
262
// Nonnull while training is in progress, if progress printing is enabled.
264
263
std::unique_ptr<table_printer> training_table_printer_;
265
264
266
- // Map from iteration index to the loss future.
267
- std::map< size_t , neural_net::shared_float_array> pending_training_batches_;
265
+ std::queue<std:: future<std::unique_ptr<TrainingOutputBatch>>>
266
+ pending_training_batches_;
268
267
269
268
struct inference_batch : neural_net::image_augmenter::result {
270
269
std::vector<std::pair<float , float >> image_dimensions_batch;
0 commit comments