Skip to content
This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Commit 1f007de

Browse files
authored
Introduce a minimal reactive streams framework, inspired somewhat by Swift's Combine (#3000)
The main goal is to make it easier to reason about the various portions of our deep-learning training pipelines. For example, in Object Detection, we have a data iterator that produces raw training examples, an image augmenter that performs data augmentation, and a model backend that consumes a stream of augmented training examples and produces progress updates and models. As we integrate this framework, it should be easier to maintain unit tests for these evolving pipelines and to manipulate threading and concurrency. For example, if done right, this can replace our implementations of "double buffering" currently copied across each model toolkit.
1 parent c8b7be6 commit 1f007de

11 files changed

+1689
-5
lines changed

src/ml/neural_net/combine.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/* Copyright © 2020 Apple Inc. All rights reserved.
2+
*
3+
* Use of this source code is governed by a BSD-3-clause license that can
4+
* be found in the LICENSE.txt file or at
5+
* https://opensource.org/licenses/BSD-3-Clause
6+
*/
7+
8+
#ifndef ML_NEURAL_NET_COMBINE_HPP_
9+
#define ML_NEURAL_NET_COMBINE_HPP_
10+
11+
/**
12+
* \file combine.hpp
13+
*
14+
* Defines a reactive-streams library inspired by the Swift Combine
15+
* framework. Its intent is to simplify reasoning about and testing of NN
16+
* model-training pipelines.
17+
*/
18+
19+
#include <ml/neural_net/combine_base.hpp>
20+
#include <ml/neural_net/combine_futures_subscriber.hpp>
21+
#include <ml/neural_net/combine_iterator.hpp>
22+
#include <ml/neural_net/combine_map.hpp>
23+
24+
#endif // ML_NEURAL_NET_COMBINE_HPP_

src/ml/neural_net/combine_base.hpp

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
/* Copyright © 2020 Apple Inc. All rights reserved.
2+
*
3+
* Use of this source code is governed by a BSD-3-clause license that can
4+
* be found in the LICENSE.txt file or at
5+
* https://opensource.org/licenses/BSD-3-Clause
6+
*/
7+
8+
#ifndef ML_NEURAL_NET_COMBINE_BASE_HPP_
9+
#define ML_NEURAL_NET_COMBINE_BASE_HPP_
10+
11+
/**
12+
* \file combine_base.hpp
13+
*
14+
* Defines the core data types for a reactive-streams library inspired by the
15+
* Swift Combine framework. Client code should generally import combine.hpp.
16+
*/
17+
18+
#include <exception>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>
20+
21+
namespace turi {
22+
namespace neural_net {
23+
24+
// Forward declarations for types defined by other headers included by
25+
// combine.hpp.
26+
27+
// Instantiated by Publisher::Map(Callable) to wrap an arbitrary callable.
template <class T, class Callable> class CallableTransform;

// Returned by Publisher::AsFutures(), constructed from a FuturesSubscriber.
template <class T> class FuturesStream;

// Created and subscribed by Publisher::AsFutures().
template <class T> class FuturesSubscriber;

// Returned by Publisher::Map(...), constructed from the upstream Publisher
// and a transform.
template <class T, class U> class MapPublisher;

// Declared here for use by the other combine headers.
template <class T, class U> class Transform;
41+
42+
/**
43+
* Simple type expressing how many values a Subscriber is ready to receive
44+
* from its Publisher.
45+
*/
46+
class Demand {
 public:
  /** Returns a Demand that places no bound on the number of values. */
  static Demand Unlimited() { return Demand(-1); }

  /** Returns a Demand for zero values. */
  static Demand None() { return Demand(0); }

  /** Any negative value is interpreted as "unlimited". */
  explicit Demand(int max) : max_(max) {}

  /** Returns true if this Demand is unbounded (stored as a negative max). */
  bool IsUnlimited() const { return max_ < 0; }

  /** Returns true if this Demand is for exactly zero values. */
  bool IsNone() const { return max_ == 0; }

  /** Returns a negative number to indicate "unlimited." */
  int max() const { return max_; }

  /**
   * Additively combines another Demand value into this one.
   *
   * If either operand is unlimited, the result is unlimited (stored as -1).
   * If the sum of two finite demands is too large to be represented by an
   * int, the result also becomes unlimited instead of triggering the undefined
   * behavior of signed integer overflow.
   *
   * \return A reference to this (updated) Demand, to allow chaining.
   */
  Demand& Add(Demand other) {
    if (IsUnlimited() || other.IsUnlimited()) {
      max_ = -1;
    } else if (other.max_ > std::numeric_limits<int>::max() - max_) {
      // Both operands are non-negative in this branch, so the subtraction
      // above cannot itself overflow. The true sum exceeds what an int can
      // hold, so treat the combined demand as unbounded.
      max_ = -1;
    } else {
      max_ += other.max_;
    }
    return *this;
  }

  /**
   * Decrease this demand by one if the current max is positive and finite.
   * Demands that are unlimited or already zero are left unchanged.
   *
   * \return A reference to this (possibly updated) Demand.
   */
  Demand& Decrement() {
    if (max_ > 0) {
      --max_;
    }
    return *this;
  }

 private:
  // Number of values demanded; any negative value means "unlimited".
  int max_ = 0;
};
81+
82+
/**
83+
* Interface for objects that Publishers send to Subscribers to allow the
84+
* Subscribers to (potentially asynchronously) control the flow of values that
85+
* the Subscriber receives from the Publisher.
86+
*/
87+
class Subscription {
88+
public:
89+
virtual ~Subscription() = default;
90+
91+
/**
92+
* Requests the Publisher to stop sending anything to the Subscriber.
93+
*
94+
* After receiving Cancel() from a Subscriber, a Publisher should thereafter
95+
* ignore all future messages from that Subscriber, including future calls to
96+
* Cancel.
97+
*
98+
* Publishers must support Subscribers calling Cancel() from inside
99+
* Subscriber::Receive(...).
100+
*/
101+
virtual void Cancel() = 0;
102+
103+
/**
104+
* Requests the Publisher to send the indicated number of values to the
105+
* Subscriber.
106+
*
107+
* Publishers must support Subscribers calling Request(Demand) from inside
108+
* Subscriber::Receive(Subscription), but Subscribers should avoid calling
109+
* Request(Demand) inside Subscriber::Receive(Input). Instead, they should
110+
* send additional Demand via the return value of Subscriber::Receive(Input)
111+
* (to help prevent infinite recursion).
112+
*/
113+
virtual void Request(Demand demand) = 0;
114+
};
115+
116+
/**
117+
* Type representing a message from a Publisher to a Subscriber indicating that
118+
* the Subscriber will no longer receive any further messages.
119+
*/
120+
class Completion {
 public:
  /** Creates a Completion reporting that the stream ended successfully. */
  static Completion Finished() { return Completion{std::exception_ptr()}; }

  /**
   * Creates a Completion reporting that the stream ended in failure, with the
   * given exception describing what went wrong. Note that passing a null
   * exception pointer yields a value for which IsFinished() is true.
   */
  static Completion Failure(std::exception_ptr e) { return Completion{e}; }

  /** True when no exception is stored, i.e. the stream ended successfully. */
  bool IsFinished() const { return !failure_; }

  /** The stored exception for a failure; a null pointer otherwise. */
  std::exception_ptr failure() const { return failure_; }

 private:
  // Instances are only created through Finished() and Failure().
  explicit Completion(std::exception_ptr cause) : failure_{cause} {}

  // Null for successful completion, otherwise the cause of the failure.
  std::exception_ptr failure_;
};
141+
142+
/**
143+
* Interface for objects that consume values from a Publisher.
144+
*
145+
* Unless otherwise specified by the concrete implementation, external
146+
* synchronization must be used to avoid concurrent calls to the Subscriber
147+
* interface from different threads.
148+
*/
149+
template <typename T>
150+
class Subscriber {
151+
public:
152+
/** The type of the values that this Subscriber consumes. */
153+
using Input = T;
154+
155+
virtual ~Subscriber() = default;
156+
157+
/**
158+
* The first signal that a Subscriber receives from a Publisher, passing the
159+
* Subscription that the Subscriber can use to control the flow of values.
160+
*
161+
* A Subscriber may only have one Publisher. If it somehow receives more than
162+
* one Subscription, it should call Subscription::Cancel() on any instances
163+
* received after the first.
164+
*
165+
* A Subscriber is explictly allowed to demand values synchronously from
166+
* within its implementation of this method.
167+
*/
168+
virtual void Receive(std::shared_ptr<Subscription> subscription) = 0;
169+
170+
/**
171+
* Transmits a value from the Publisher to this Subscriber.
172+
*
173+
* A Subcriber should never receive more calls to this method than the total
174+
* Demand it has requested from its publisher. Subscribers should only demand
175+
* more elements from within this method via its return value.
176+
*/
177+
virtual Demand Receive(Input element) = 0;
178+
179+
/**
180+
* Signals completion of the stream of values from the Publisher.
181+
*
182+
* A Subscriber should not receive any further signals of any kind after
183+
* receiving a Completion.
184+
*/
185+
virtual void Receive(Completion completion) = 0;
186+
};
187+
188+
/**
189+
* Interface for objects that produce values on demand from their Subscribers.
190+
*
191+
* Unless otherwise specified by the concrete implementation, external
192+
* synchronization must be used to avoid concurrent calls on multiple threads to
193+
* a Publisher, including via the Subscriptions that it passes to its
194+
* Subscribers.
195+
*
196+
* Each concrete implementation defines whether it is unicast or multicast:
197+
* whether multiple Subscribers observe the same values or not. (An
198+
* implementation might only support one Subscriber, by passing an immediate
199+
* Completion to each Subscriber after the first.)
200+
*
201+
* Note: instances of this class are intended to be stored using shared_ptr.
202+
* Many of the operators rely on generating strong references to the instance
203+
* being augmented.
204+
*/
205+
template <typename T>
206+
class Publisher : public std::enable_shared_from_this<Publisher<T>> {
207+
public:
208+
/** The type of values that this Publisher produces. */
209+
using Output = T;
210+
211+
virtual ~Publisher() = default;
212+
213+
/**
214+
* Establishes a connection between this Publisher and the given Subcriber.
215+
*
216+
* The Publisher must eventually call Subscriber::Receive(Subscription) on the
217+
* given Subscriber (and may do so synchronously). The Publisher must then
218+
* conform to the protocol established by the Subscription.
219+
*/
220+
virtual void Receive(std::shared_ptr<Subscriber<Output>> subscriber) = 0;
221+
222+
// Convenienience methods, supporting the chaining together of operations.
223+
// Many of these rely on the forward declarations above. Client code should
224+
// include combine.hpp to ensure these are defined before they are used.
225+
226+
void Subscribe(std::shared_ptr<Subscriber<Output>> subscriber) {
227+
Receive(std::move(subscriber));
228+
}
229+
230+
std::shared_ptr<FuturesStream<Output>> AsFutures() {
231+
auto subscriber = std::make_shared<FuturesSubscriber<Output>>();
232+
Subscribe(subscriber);
233+
return std::make_shared<FuturesStream<Output>>(std::move(subscriber));
234+
}
235+
236+
template <typename TransformType>
237+
std::shared_ptr<Publisher<typename TransformType::Output>> Map(
238+
std::shared_ptr<TransformType> transform) {
239+
using TransformInput = typename TransformType::Input;
240+
using TransformOutput = typename TransformType::Output;
241+
return std::make_shared<MapPublisher<TransformInput, TransformOutput>>(
242+
this->shared_from_this(), std::move(transform));
243+
}
244+
245+
template <typename Callable>
246+
std::shared_ptr<Publisher<typename std::result_of<Callable(Output)>::type>>
247+
Map(Callable fn) {
248+
using TransformType = CallableTransform<Output, Callable>;
249+
auto transform = std::make_shared<TransformType>(std::move(fn));
250+
return Map(std::move(transform));
251+
}
252+
};
253+
254+
} // namespace neural_net
255+
} // namespace turi
256+
257+
#endif // ML_NEURAL_NET_COMBINE_BASE_HPP_

0 commit comments

Comments
 (0)