1
1
// Copyright (c) The future-queue Contributors
2
2
// SPDX-License-Identifier: MIT OR Apache-2.0
3
3
4
- use future_queue:: StreamExt as _;
5
- use futures:: { stream, StreamExt as _} ;
4
+ use future_queue:: { traits :: WeightedFuture , FutureQueue , StreamExt as _} ;
5
+ use futures:: { future :: BoxFuture , stream, Future , FutureExt , Stream , StreamExt as _} ;
6
6
use proptest:: prelude:: * ;
7
7
use proptest_derive:: Arbitrary ;
8
- use std:: time:: Duration ;
8
+ use std:: { pin :: Pin , time:: Duration } ;
9
9
use tokio_stream:: wrappers:: UnboundedReceiverStream ;
10
10
11
11
#[ derive( Clone , Debug , Arbitrary ) ]
12
- struct TestState {
12
+ struct TestState < G : GroupSpec > {
13
13
#[ proptest( strategy = "1usize..64" ) ]
14
14
max_weight : usize ,
15
15
#[ proptest( strategy = "prop::collection::vec(TestFutureDesc::arbitrary(), 0..512usize)" ) ]
16
- future_descriptions : Vec < TestFutureDesc > ,
16
+ future_descriptions : Vec < TestFutureDesc < G > > ,
17
17
}
18
18
19
19
#[ derive( Copy , Clone , Debug , Arbitrary ) ]
20
- struct TestFutureDesc {
20
+ struct TestFutureDesc < G : GroupSpec > {
21
21
#[ proptest( strategy = "duration_strategy()" ) ]
22
22
start_delay : Duration ,
23
23
#[ proptest( strategy = "duration_strategy()" ) ]
24
24
delay : Duration ,
25
25
#[ proptest( strategy = "0usize..8" ) ]
26
26
weight : usize ,
27
+ #[ allow( dead_code) ]
28
+ group : G ,
27
29
}
28
30
29
31
fn duration_strategy ( ) -> BoxedStrategy < Duration > {
30
32
// Allow for a delay between 0ms and 1000ms uniformly at random.
31
33
( 0u64 ..1000 ) . prop_map ( Duration :: from_millis) . boxed ( )
32
34
}
33
35
36
+ trait GroupSpec : Arbitrary + Send + Copy + ' static {
37
+ type Item : Send ;
38
+ type CheckState : Default ;
39
+
40
+ fn create_stream < St > ( stream : St , state : & TestState < Self > ) -> BoxedWeightedStream < ( ) >
41
+ where
42
+ St : Stream < Item = Self :: Item > + Send + ' static ;
43
+
44
+ fn create_stream_item (
45
+ desc : & TestFutureDesc < Self > ,
46
+ future : impl Future < Output = ( ) > + Send + ' static ,
47
+ ) -> Self :: Item ;
48
+
49
+ fn check_started (
50
+ check_state : & mut Self :: CheckState ,
51
+ desc : & TestFutureDesc < Self > ,
52
+ state : & TestState < Self > ,
53
+ ) ;
54
+
55
+ fn check_finished (
56
+ check_state : & mut Self :: CheckState ,
57
+ desc : & TestFutureDesc < Self > ,
58
+ state : & TestState < Self > ,
59
+ ) ;
60
+ }
61
+
62
+ trait WeightedStream : Stream {
63
+ fn current_weight ( & self ) -> usize ;
64
+ }
65
+
66
+ impl < St , Fut > WeightedStream for FutureQueue < St >
67
+ where
68
+ St : Stream < Item = Fut > ,
69
+ Fut : WeightedFuture ,
70
+ {
71
+ fn current_weight ( & self ) -> usize {
72
+ self . current_weight ( )
73
+ }
74
+ }
75
+
76
+ type BoxedWeightedStream < Item > = Pin < Box < dyn WeightedStream < Item = Item > + Send > > ;
77
+
78
+ impl GroupSpec for ( ) {
79
+ type Item = ( usize , BoxFuture < ' static , ( ) > ) ;
80
+ type CheckState = NonGroupedCheckState ;
81
+
82
+ fn create_stream < St > ( stream : St , state : & TestState < Self > ) -> BoxedWeightedStream < ( ) >
83
+ where
84
+ St : Stream < Item = Self :: Item > + Send + ' static ,
85
+ {
86
+ Box :: pin ( stream. future_queue ( state. max_weight ) )
87
+ }
88
+
89
+ fn create_stream_item (
90
+ desc : & TestFutureDesc < Self > ,
91
+ future : impl Future < Output = ( ) > + Send + ' static ,
92
+ ) -> Self :: Item {
93
+ ( desc. weight , future. boxed ( ) )
94
+ }
95
+
96
+ fn check_started (
97
+ check_state : & mut Self :: CheckState ,
98
+ desc : & TestFutureDesc < Self > ,
99
+ state : & TestState < Self > ,
100
+ ) {
101
+ // Check that current_weight doesn't go over the limit.
102
+ assert ! (
103
+ check_state. current_weight < state. max_weight,
104
+ "current weight {} exceeds max weight {}" ,
105
+ check_state. current_weight,
106
+ state. max_weight,
107
+ ) ;
108
+ check_state. current_weight += desc. weight ;
109
+ }
110
+
111
+ fn check_finished (
112
+ check_state : & mut Self :: CheckState ,
113
+ desc : & TestFutureDesc < Self > ,
114
+ _state : & TestState < Self > ,
115
+ ) {
116
+ check_state. current_weight -= desc. weight ;
117
+ }
118
+ }
119
+
120
+ #[ derive( Debug , Default ) ]
121
+ struct NonGroupedCheckState {
122
+ current_weight : usize ,
123
+ }
124
+
34
125
#[ test]
35
126
fn test_examples ( ) {
36
127
let state = TestState {
@@ -39,25 +130,26 @@ fn test_examples() {
39
130
start_delay: Duration :: ZERO ,
40
131
delay: Duration :: ZERO ,
41
132
weight: 0 ,
133
+ group: ( ) ,
42
134
} ] ,
43
135
} ;
44
136
test_future_queue_impl ( state) ;
45
137
}
46
138
47
139
proptest ! {
48
140
#[ test]
49
- fn proptest_future_queue( state: TestState ) {
141
+ fn proptest_future_queue( state: TestState < ( ) > ) {
50
142
test_future_queue_impl( state)
51
143
}
52
144
}
53
145
54
146
#[ derive( Clone , Copy , Debug ) ]
55
- enum FutureEvent {
56
- Started ( usize , TestFutureDesc ) ,
57
- Finished ( usize , TestFutureDesc ) ,
147
+ enum FutureEvent < G : GroupSpec > {
148
+ Started ( usize , TestFutureDesc < G > ) ,
149
+ Finished ( usize , TestFutureDesc < G > ) ,
58
150
}
59
151
60
- fn test_future_queue_impl ( state : TestState ) {
152
+ fn test_future_queue_impl < G : GroupSpec > ( state : TestState < G > ) {
61
153
let runtime = tokio:: runtime:: Builder :: new_current_thread ( )
62
154
. enable_time ( )
63
155
. start_paused ( true )
@@ -88,7 +180,7 @@ fn test_future_queue_impl(state: TestState) {
88
180
. expect ( "receiver held open by loop" ) ;
89
181
} ;
90
182
// Errors should never occur here.
91
- if let Err ( err) = future_sender. send ( ( desc. weight , delay_fut) ) {
183
+ if let Err ( err) = future_sender. send ( G :: create_stream_item ( & desc, delay_fut) ) {
92
184
panic ! ( "future_receiver held open by loop: {}" , err) ;
93
185
}
94
186
}
@@ -102,11 +194,11 @@ fn test_future_queue_impl(state: TestState) {
102
194
103
195
let mut completed_map = vec ! [ false ; state. future_descriptions. len( ) ] ;
104
196
let mut last_started_id: Option < usize > = None ;
105
- let mut current_weight = 0 ;
197
+ let mut check_state = G :: CheckState :: default ( ) ;
106
198
107
199
runtime. block_on ( async move {
108
200
// Record values that have been completed in this map.
109
- let mut stream = stream . future_queue ( state. max_weight ) ;
201
+ let mut stream = G :: create_stream ( stream , & state) ;
110
202
let mut receiver_done = false ;
111
203
loop {
112
204
tokio:: select! {
@@ -122,19 +214,12 @@ fn test_future_queue_impl(state: TestState) {
122
214
assert_eq!( expected_id, id, "expected future id to start != actual id that started" ) ;
123
215
last_started_id = Some ( id) ;
124
216
125
- // Check that the current weight doesn't go over the limit.
126
- assert!(
127
- current_weight < state. max_weight,
128
- "current weight {} exceeds max weight {}" ,
129
- current_weight,
130
- state. max_weight,
131
- ) ;
132
- current_weight += desc. weight;
217
+ G :: check_started( & mut check_state, & desc, & state) ;
133
218
}
134
219
Some ( FutureEvent :: Finished ( id, desc) ) => {
135
220
// Record that this value was completed.
136
221
completed_map[ id] = true ;
137
- current_weight -= desc. weight ;
222
+ G :: check_finished ( & mut check_state , & desc, & state ) ;
138
223
}
139
224
None => {
140
225
// All futures finished -- going to check for completion in stream.next() below.
0 commit comments