Skip to content
This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Commit 7dfa7d8

Browse files
authored
Integrate new Combine-like framework into OD training (#3006)
Also begins encapsulating the darknet-yolo-specific bits. However, the inference path and CoreML export still need to be refactored
1 parent 1f007de commit 7dfa7d8

File tree

12 files changed

+1173
-477
lines changed

12 files changed

+1173
-477
lines changed

src/toolkits/object_detection/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ make_library(unity_object_detection OBJECT
44
SOURCES
55
class_registrations.cpp
66
object_detector.cpp
7+
od_darknet_yolo_model.cpp
78
od_data_iterator.cpp
8-
od_yolo.cpp
99
od_evaluation.cpp
10+
od_model.cpp
1011
od_serialization.cpp
12+
od_yolo.cpp
1113
REQUIRES
1214
image_io
1315
random

src/toolkits/object_detection/object_detector.cpp

Lines changed: 141 additions & 318 deletions
Large diffs are not rendered by default.

src/toolkits/object_detection/object_detector.hpp

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@
1010
#include <functional>
1111
#include <map>
1212
#include <memory>
13+
#include <queue>
1314

14-
#include <core/logging/table_printer/table_printer.hpp>
15-
#include <model_server/lib/extensions/ml_model.hpp>
1615
#include <core/data/sframe/gl_sframe.hpp>
17-
#include <toolkits/coreml_export/mlmodel_wrapper.hpp>
16+
#include <core/logging/table_printer/table_printer.hpp>
1817
#include <ml/neural_net/compute_context.hpp>
1918
#include <ml/neural_net/image_augmentation.hpp>
2019
#include <ml/neural_net/model_backend.hpp>
2120
#include <ml/neural_net/model_spec.hpp>
21+
#include <model_server/lib/extensions/ml_model.hpp>
22+
#include <toolkits/coreml_export/mlmodel_wrapper.hpp>
2223
#include <toolkits/object_detection/od_data_iterator.hpp>
24+
#include <toolkits/object_detection/od_model.hpp>
2325

2426
namespace turi {
2527
namespace object_detection {
@@ -45,7 +47,7 @@ class EXPORT object_detector: public ml_model_base {
4547
std::map<std::string, flexible_type> opts);
4648
variant_type predict(variant_type data,
4749
std::map<std::string, flexible_type> opts);
48-
std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
50+
virtual std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
4951
std::string filename, std::string short_description,
5052
std::map<std::string, flexible_type> additional_user_defined,
5153
std::map<std::string, flexible_type> opts);
@@ -155,23 +157,19 @@ class EXPORT object_detector: public ml_model_base {
155157
END_CLASS_MEMBER_REGISTRATION
156158

157159
protected:
158-
// Constructor allowing tests to set the initial state of this class and to
159-
// inject dependencies.
160-
object_detector(
161-
const std::map<std::string, variant_type>& initial_state,
162-
std::unique_ptr<neural_net::model_spec> nn_spec,
163-
std::unique_ptr<neural_net::compute_context> training_compute_context,
164-
std::unique_ptr<data_iterator> training_data_iterator,
165-
std::unique_ptr<neural_net::image_augmenter> training_data_augmenter,
166-
std::unique_ptr<neural_net::model_backend> training_model)
167-
: nn_spec_(std::move(nn_spec)),
168-
training_compute_context_(std::move(training_compute_context)),
169-
training_data_iterator_(std::move(training_data_iterator)),
170-
training_data_augmenter_(std::move(training_data_augmenter)),
171-
training_model_(std::move(training_model)) {
172-
add_or_update_state(initial_state);
160+
// Constructor allowing tests to set the initial state of this class.
161+
object_detector(std::map<std::string, variant_type> initial_state,
162+
neural_net::float_array_map initial_weights) {
163+
load(std::move(initial_state), std::move(initial_weights));
173164
}
174165

166+
// Resets the internal state. Used by deserialization code and unit tests.
167+
void load(std::map<std::string, variant_type> state,
168+
neural_net::float_array_map weights);
169+
170+
// Synchronously loads weights from the backend if necessary.
171+
Checkpoint* read_checkpoint() const;
172+
175173
// Override points allowing subclasses to inject dependencies
176174

177175
// Factory for data_iterator
@@ -186,12 +184,19 @@ class EXPORT object_detector: public ml_model_base {
186184
virtual
187185
std::unique_ptr<neural_net::compute_context> create_compute_context() const;
188186

189-
// Returns the initial neural network to train (represented by its CoreML
190-
// spec), given the path to a mlmodel file containing the pretrained weights.
191-
virtual std::unique_ptr<neural_net::model_spec> init_model(
192-
const std::string& pretrained_mlmodel_path, size_t num_classes) const;
187+
// Factories for Model
188+
virtual std::unique_ptr<Model> create_model(
189+
const Checkpoint& checkpoint,
190+
std::unique_ptr<neural_net::compute_context> context) const;
191+
virtual std::unique_ptr<Model> create_model(
192+
const Config& config, const std::string& pretrained_model_path,
193+
int random_seed,
194+
std::unique_ptr<neural_net::compute_context> context) const;
193195

194-
void init_training_backend();
196+
// Establishes training pipelines from the backend.
197+
void connect_training_backend(std::unique_ptr<Model> backend,
198+
std::unique_ptr<data_iterator> iterator,
199+
int batch_size);
195200

196201
virtual std::vector<neural_net::image_annotation> convert_yolo_to_annotations(
197202
const neural_net::float_array& yolo_map,
@@ -219,12 +224,9 @@ class EXPORT object_detector: public ml_model_base {
219224
}
220225

221226
private:
227+
neural_net::float_array_map strip_fwd(
228+
const neural_net::float_array_map& params) const;
222229

223-
neural_net::float_array_map get_model_params() const;
224-
225-
neural_net::shared_float_array prepare_label_batch(
226-
std::vector<std::vector<neural_net::image_annotation>> annotations_batch)
227-
const;
228230
flex_int get_max_iterations() const;
229231
flex_int get_training_iterations() const;
230232
flex_int get_num_classes() const;
@@ -236,35 +238,32 @@ class EXPORT object_detector: public ml_model_base {
236238
const std::string& column_name);
237239

238240
// Sets certain user options heuristically (from the data).
239-
void infer_derived_options();
241+
void infer_derived_options(neural_net::compute_context* context,
242+
data_iterator* iterator);
240243

241244
// Waits until the number of pending patches is at most `max_pending`.
242245
void wait_for_training_batches(size_t max_pending = 0);
243246

244-
// Ensures that the local copy of the model weights are in sync with the
245-
// training backend.
246-
void synchronize_model(neural_net::model_spec* nn_spec) const;
247-
248247
// Computes and records training/validation metrics.
249248
void update_model_metrics(gl_sframe data, gl_sframe validation_data);
250249

251-
// Primary representation for the trained model.
252-
std::unique_ptr<neural_net::model_spec> nn_spec_;
250+
// Primary representation for the trained model. Can be null if the model has
251+
// been updated since the last checkpoint.
252+
mutable std::unique_ptr<Checkpoint> checkpoint_;
253253

254254
// Primary dependencies for training. These should be nonnull while training
255255
// is in progress.
256256
gl_sframe training_data_; // TODO: Avoid storing gl_sframe AND data_iterator.
257257
gl_sframe validation_data_;
258-
std::unique_ptr<neural_net::compute_context> training_compute_context_;
259-
std::unique_ptr<data_iterator> training_data_iterator_;
260-
std::unique_ptr<neural_net::image_augmenter> training_data_augmenter_;
261-
std::unique_ptr<neural_net::model_backend> training_model_;
258+
std::shared_ptr<neural_net::FuturesStream<TrainingOutputBatch>>
259+
training_futures_;
260+
std::shared_ptr<neural_net::FuturesStream<Checkpoint>> checkpoint_futures_;
262261

263262
// Nonnull while training is in progress, if progress printing is enabled.
264263
std::unique_ptr<table_printer> training_table_printer_;
265264

266-
// Map from iteration index to the loss future.
267-
std::map<size_t, neural_net::shared_float_array> pending_training_batches_;
265+
std::queue<std::future<std::unique_ptr<TrainingOutputBatch>>>
266+
pending_training_batches_;
268267

269268
struct inference_batch : neural_net::image_augmenter::result {
270269
std::vector<std::pair<float, float>> image_dimensions_batch;

0 commit comments

Comments
 (0)