Skip to content

Commit 064ec4e

Browse files
committed
stabilize the order of outputs passed to join_vec
Signed-off-by: Teo Koon Peng <[email protected]>
1 parent a1e6f69 commit 064ec4e

File tree

3 files changed

+105
-43
lines changed

3 files changed

+105
-43
lines changed

src/diagram.rs

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ pub type OperationId = String;
4040
)]
4141
#[serde(untagged, rename_all = "snake_case")]
4242
pub enum NextOperation {
43-
Target(String),
43+
Target(OperationId),
4444
Builtin { builtin: BuiltinTarget },
4545
}
4646

4747
impl Display for NextOperation {
4848
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
4949
match self {
50-
Self::Target(builder_id) => f.write_str(builder_id),
50+
Self::Target(operation_id) => f.write_str(operation_id),
5151
Self::Builtin { builtin } => write!(f, "builtin:{}", builtin),
5252
}
5353
}
@@ -77,6 +77,49 @@ pub enum BuiltinTarget {
7777
Dispose,
7878
}
7979

80+
#[derive(
81+
Debug, Clone, Serialize, Deserialize, JsonSchema, Hash, PartialEq, Eq, PartialOrd, Ord,
82+
)]
83+
#[serde(untagged, rename_all = "snake_case")]
84+
pub enum SourceOperation {
85+
Source(OperationId),
86+
Builtin { builtin: BuiltinSource },
87+
}
88+
89+
impl From<OperationId> for SourceOperation {
90+
fn from(value: OperationId) -> Self {
91+
SourceOperation::Source(value)
92+
}
93+
}
94+
95+
impl Display for SourceOperation {
96+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97+
match self {
98+
Self::Source(operation_id) => f.write_str(operation_id),
99+
Self::Builtin { builtin } => write!(f, "builtin:{}", builtin),
100+
}
101+
}
102+
}
103+
104+
#[derive(
105+
Debug,
106+
Clone,
107+
Serialize,
108+
Deserialize,
109+
JsonSchema,
110+
Hash,
111+
PartialEq,
112+
Eq,
113+
PartialOrd,
114+
Ord,
115+
strum::Display,
116+
)]
117+
#[serde(rename_all = "snake_case")]
118+
#[strum(serialize_all = "snake_case")]
119+
pub enum BuiltinSource {
120+
Start,
121+
}
122+
80123
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
81124
#[serde(rename_all = "snake_case")]
82125
pub struct TerminateOp {}

src/diagram/join.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@ use tracing::debug;
77

88
use crate::{Builder, IterBufferable, Output};
99

10-
use super::{DiagramError, DynOutput, NextOperation, NodeRegistry, SerializeMessage};
10+
use super::{
11+
DiagramError, DynOutput, NextOperation, NodeRegistry, SerializeMessage, SourceOperation,
12+
};
1113

1214
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
1315
#[serde(rename_all = "snake_case")]
1416
pub struct JoinOp {
1517
pub(super) next: NextOperation,
1618

19+
/// Controls the order of the resulting join. Each item must be an operation id of one of the
20+
/// incoming outputs.
21+
pub(super) order: Vec<SourceOperation>,
22+
1723
/// Do not serialize before performing the join. If true, joins can only be done
1824
/// on outputs of the same type.
1925
pub(super) no_serialize: Option<bool>,
@@ -31,11 +37,6 @@ where
3137
registry
3238
.join_impls
3339
.insert(TypeId::of::<T>(), Box::new(join_impl::<T>));
34-
35-
// FIXME(koonpeng): join_vec results in a SmallVec<[T; N]>, we can't serialize it because
36-
// it doesn't implement JsonSchema, and we can't impl it because of orphan rule. We would need
37-
// to create our own trait that covers `JsonSchema` and `Serialize`.
38-
// register_serialize::<Vec<T>, Serializer>(registry);
3940
}
4041

4142
/// Serialize the outputs before joining them, and convert the resulting joined output into a
@@ -151,9 +152,9 @@ mod tests {
151152
"ops": {
152153
"split": {
153154
"type": "split",
154-
"index": ["getSplitValue1", "getSplitValue2"]
155+
"index": ["get_split_value1", "get_split_value2"]
155156
},
156-
"getSplitValue1": {
157+
"get_split_value1": {
157158
"type": "node",
158159
"builder": "get_split_value",
159160
"next": "op1",
@@ -163,7 +164,7 @@ mod tests {
163164
"builder": "multiply3_uncloneable",
164165
"next": "join",
165166
},
166-
"getSplitValue2": {
167+
"get_split_value2": {
167168
"type": "node",
168169
"builder": "get_split_value",
169170
"next": "op2",
@@ -175,10 +176,11 @@ mod tests {
175176
},
176177
"join": {
177178
"type": "join",
178-
"next": "serializeJoinOutput",
179+
"order": ["op1", "op2"],
180+
"next": "serialize_join_output",
179181
"no_serialize": true,
180182
},
181-
"serializeJoinOutput": {
183+
"serialize_join_output": {
182184
"type": "node",
183185
"builder": "serialize_join_output",
184186
"next": { "builtin": "terminate" },
@@ -191,10 +193,8 @@ mod tests {
191193
.spawn_and_run(&diagram, serde_json::Value::from([1, 2]))
192194
.unwrap();
193195
assert_eq!(result.as_array().unwrap().len(), 2);
194-
// order is not guaranteed so need to test for both possibility
195-
assert!(result[0] == 3 || result[0] == 6);
196-
assert!(result[1] == 3 || result[1] == 6);
197-
assert!(result[0] != result[1]);
196+
assert_eq!(result[0], 3);
197+
assert_eq!(result[1], 6);
198198
}
199199

200200
#[test]
@@ -241,6 +241,7 @@ mod tests {
241241
},
242242
"join": {
243243
"type": "join",
244+
"order": [],
244245
"next": { "builtin": "terminate" },
245246
"no_serialize": true,
246247
},
@@ -296,6 +297,7 @@ mod tests {
296297
},
297298
"join": {
298299
"type": "join",
300+
"order": ["op1", "op2"],
299301
"next": { "builtin": "terminate" },
300302
},
301303
}
@@ -306,9 +308,7 @@ mod tests {
306308
.spawn_and_run(&diagram, serde_json::Value::Null)
307309
.unwrap();
308310
assert_eq!(result.as_array().unwrap().len(), 2);
309-
// order is not guaranteed so need to test for both possibility
310-
assert!(result[0] == 1 || result[0] == "hello");
311-
assert!(result[1] == 1 || result[1] == "hello");
312-
assert!(result[0] != result[1]);
311+
assert_eq!(result[0], 1);
312+
assert_eq!(result[1], "hello");
313313
}
314314
}

src/diagram/workflow_builder.rs

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{diagram::join::serialize_and_join, Builder, InputSlot, Output, Strea
77
use super::{
88
fork_clone::DynForkClone, impls::DefaultImpl, split_chain, transform::transform_output,
99
BuiltinTarget, Diagram, DiagramError, DiagramOperation, DiagramScope, DynInputSlot, DynOutput,
10-
NextOperation, NodeOp, NodeRegistry, OperationId, SplitOpParams,
10+
NextOperation, NodeOp, NodeRegistry, OperationId, SourceOperation, SplitOpParams,
1111
};
1212

1313
struct Vertex<'a> {
@@ -18,8 +18,7 @@ struct Vertex<'a> {
1818
}
1919

2020
struct Edge<'a> {
21-
/// The source of the edge, may be `None` if it comes from outside the diagram, e.g. the entry point of the diagram.
22-
source: Option<&'a OperationId>,
21+
source: SourceOperation,
2322
target: &'a NextOperation,
2423
state: EdgeState<'a>,
2524
}
@@ -79,7 +78,9 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>(
7978
edges.insert(
8079
edges.len(),
8180
Edge {
82-
source: None,
81+
source: SourceOperation::Builtin {
82+
builtin: super::BuiltinSource::Start,
83+
},
8384
target: &diagram.start,
8485
state: EdgeState::Ready {
8586
output: scope.input.into(),
@@ -95,10 +96,16 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>(
9596

9697
let mut terminate_edges: Vec<usize> = Vec::new();
9798

98-
let mut add_edge = |source: Option<&'a OperationId>,
99+
let mut add_edge = |source: SourceOperation,
99100
target: &'a NextOperation,
100101
state: EdgeState<'a>|
101102
-> Result<(), DiagramError> {
103+
let source_id = if let SourceOperation::Source(source) = &source {
104+
Some(source.clone())
105+
} else {
106+
None
107+
};
108+
102109
edges.insert(
103110
edges.len(),
104111
Edge {
@@ -109,10 +116,10 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>(
109116
);
110117
let new_edge_id = edges.len() - 1;
111118

112-
if let Some(source) = source {
119+
if let Some(source_id) = source_id {
113120
let source_vertex = vertices
114-
.get_mut(source)
115-
.ok_or_else(|| DiagramError::OperationNotFound(source.clone()))?;
121+
.get_mut(&source_id)
122+
.ok_or_else(|| DiagramError::OperationNotFound(source_id.clone()))?;
116123
source_vertex.out_edges.push(new_edge_id);
117124
}
118125

@@ -141,7 +148,7 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>(
141148
let n = reg.create_node(builder, node_op.config.clone())?;
142149
inputs.insert(op_id, n.input);
143150
add_edge(
144-
Some(op_id),
151+
op_id.clone().into(),
145152
&node_op.next,
146153
EdgeState::Ready {
147154
output: n.output.into(),
@@ -151,35 +158,39 @@ pub(super) fn create_workflow<'a, Streams: StreamPack>(
151158
}
152159
DiagramOperation::ForkClone(fork_clone_op) => {
153160
for next_op_id in fork_clone_op.next.iter() {
154-
add_edge(Some(op_id), next_op_id, EdgeState::Pending)?;
161+
add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?;
155162
}
156163
}
157164
DiagramOperation::Unzip(unzip_op) => {
158165
for next_op_id in unzip_op.next.iter() {
159-
add_edge(Some(op_id), next_op_id, EdgeState::Pending)?;
166+
add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?;
160167
}
161168
}
162169
DiagramOperation::ForkResult(fork_result_op) => {
163-
add_edge(Some(op_id), &fork_result_op.ok, EdgeState::Pending)?;
164-
add_edge(Some(op_id), &fork_result_op.err, EdgeState::Pending)?;
170+
add_edge(op_id.clone().into(), &fork_result_op.ok, EdgeState::Pending)?;
171+
add_edge(
172+
op_id.clone().into(),
173+
&fork_result_op.err,
174+
EdgeState::Pending,
175+
)?;
165176
}
166177
DiagramOperation::Split(split_op) => {
167178
let next_op_ids: Vec<&NextOperation> = match &split_op.params {
168179
SplitOpParams::Index(v) => v.iter().collect(),
169180
SplitOpParams::Key(v) => v.values().collect(),
170181
};
171182
for next_op_id in next_op_ids {
172-
add_edge(Some(op_id), next_op_id, EdgeState::Pending)?;
183+
add_edge(op_id.clone().into(), next_op_id, EdgeState::Pending)?;
173184
}
174185
if let Some(remaining) = &split_op.remaining {
175-
add_edge(Some(op_id), &remaining, EdgeState::Pending)?;
186+
add_edge(op_id.clone().into(), &remaining, EdgeState::Pending)?;
176187
}
177188
}
178189
DiagramOperation::Join(join_op) => {
179-
add_edge(Some(op_id), &join_op.next, EdgeState::Pending)?;
190+
add_edge(op_id.clone().into(), &join_op.next, EdgeState::Pending)?;
180191
}
181192
DiagramOperation::Transform(transform_op) => {
182-
add_edge(Some(op_id), &transform_op.next, EdgeState::Pending)?;
193+
add_edge(op_id.clone().into(), &transform_op.next, EdgeState::Pending)?;
183194
}
184195
DiagramOperation::Dispose => {}
185196
}
@@ -252,23 +263,31 @@ fn connect_vertex<'a>(
252263
if target.in_edges.is_empty() {
253264
return Err(DiagramError::EmptyJoin);
254265
}
255-
let outputs: Vec<DynOutput> = target
266+
let mut outputs: HashMap<SourceOperation, DynOutput> = target
256267
.in_edges
257268
.iter()
258269
.map(|e| {
259270
let edge = edges.remove(e).unwrap();
260271
match edge.state {
261-
EdgeState::Ready { output, origin: _ } => output,
272+
EdgeState::Ready { output, origin: _ } => (edge.source, output),
262273
_ => panic!("expected all incoming edges to be ready"),
263274
}
264275
})
265276
.collect();
266277

278+
let mut ordered_outputs: Vec<DynOutput> = Vec::with_capacity(target.in_edges.len());
279+
for source_id in join_op.order.iter() {
280+
let o = outputs
281+
.remove(source_id)
282+
.ok_or(DiagramError::OperationNotFound(source_id.to_string()))?;
283+
ordered_outputs.push(o);
284+
}
285+
267286
let joined_output = if join_op.no_serialize.unwrap_or(false) {
268-
let join_impl = &registry.join_impls[&outputs[0].type_id];
269-
join_impl(builder, outputs)?
287+
let join_impl = &registry.join_impls[&ordered_outputs[0].type_id];
288+
join_impl(builder, ordered_outputs)?
270289
} else {
271-
serialize_and_join(builder, registry, outputs)?.into()
290+
serialize_and_join(builder, registry, ordered_outputs)?.into()
272291
};
273292

274293
let out_edge = edges.get_mut(&target.out_edges[0]).unwrap();

0 commit comments

Comments
 (0)