Skip to content

Commit 41a123b

Browse files
committed
add serialize option to join op
Signed-off-by: Teo Koon Peng <[email protected]>
1 parent 9a231e7 commit 41a123b

File tree

3 files changed

+119
-11
lines changed

3 files changed

+119
-11
lines changed

diagram.schema.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@
209209
"next": {
210210
"$ref": "#/definitions/NextOperation"
211211
},
212+
"serialize": {
213+
"description": "Whether to serialize incoming outputs before performing the join. This allows for joining outputs of different types at the cost of serialization overhead. If this is true, the resulting output will be a [`serde_json::Value`].",
214+
"type": [
215+
"boolean",
216+
"null"
217+
]
218+
},
212219
"type": {
213220
"type": "string",
214221
"enum": [

src/diagram/join.rs

Lines changed: 103 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,19 @@ use serde::{Deserialize, Serialize};
55
use smallvec::SmallVec;
66
use tracing::debug;
77

8-
use crate::{Builder, IterBufferable};
8+
use crate::{Builder, IterBufferable, Output};
99

1010
use super::{DiagramError, DynOutput, NextOperation, NodeRegistry, SerializeMessage};
1111

1212
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
1313
#[serde(rename_all = "snake_case")]
1414
pub struct JoinOp {
1515
pub(super) next: NextOperation,
16+
17+
/// Whether to serialize incoming outputs before performing the join. This allows for joining
18+
/// outputs of different types at the cost of serialization overhead. If this is true, the
19+
/// resulting output will be a [`serde_json::Value`].
20+
pub(super) serialize: Option<bool>,
1621
}
1722

1823
pub(super) fn register_join_impl<T, Serializer>(registry: &mut NodeRegistry)
@@ -34,6 +39,42 @@ where
3439
// register_serialize::<Vec<T>, Serializer>(registry);
3540
}
3641

42+
/// Serialize the outputs before joining them, and convert the resulting joined output into a
43+
/// [`serde_json::Value`].
44+
pub(super) fn serialize_and_join(
45+
builder: &mut Builder,
46+
registry: &NodeRegistry,
47+
outputs: Vec<DynOutput>,
48+
) -> Result<Output<serde_json::Value>, DiagramError> {
49+
debug!("serialize and join outputs {:?}", outputs);
50+
51+
if outputs.is_empty() {
52+
// do not allow empty joins
53+
return Err(DiagramError::EmptyJoin);
54+
}
55+
56+
let outputs = outputs
57+
.into_iter()
58+
.map(|o| {
59+
let serialize_impl = registry
60+
.serialize_impls
61+
.get(&o.type_id)
62+
.ok_or(DiagramError::NotSerializable)?;
63+
let serialized_output = serialize_impl(builder, o)?;
64+
Ok(serialized_output)
65+
})
66+
.collect::<Result<Vec<_>, DiagramError>>()?;
67+
68+
// we need to convert the joined output to [`serde_json::Value`] in order for it to be
69+
// serializable.
70+
let joined_output = outputs.join_vec::<4>(builder).output();
71+
let json_output = joined_output
72+
.chain(builder)
73+
.map_block(|o| serde_json::to_value(o).unwrap())
74+
.output();
75+
Ok(json_output)
76+
}
77+
3778
fn join_impl<T>(builder: &mut Builder, outputs: Vec<DynOutput>) -> Result<DynOutput, DiagramError>
3879
where
3980
T: Send + Sync + 'static,
@@ -51,12 +92,6 @@ where
5192
let outputs = outputs
5293
.into_iter()
5394
.map(|o| {
54-
// joins is only supported for outputs of the same type. This is because joins of
55-
// different types produces a tuple and we cannot output a tuple as we don't
56-
// know the number and order of join inputs at compile time.
57-
// A workaround is to serialize them all the `serde_json::Value` or convert them to `Box<dyn Any>`.
58-
// But the problem with `Box<dyn Any>` is that we can't convert it back to the original type,
59-
// so nodes need to take a request of `JoinOutput<Box<dyn Any>>`.
6095
if o.type_id != first_type {
6196
Err(DiagramError::TypeMismatch)
6297
} else {
@@ -215,4 +250,65 @@ mod tests {
215250
let err = fixture.spawn_io_workflow(&diagram).unwrap_err();
216251
assert!(matches!(err, DiagramError::EmptyJoin));
217252
}
253+
254+
#[test]
255+
fn test_serialize_and_join() {
256+
let mut fixture = DiagramTestFixture::new();
257+
258+
fn num_output(_: serde_json::Value) -> i64 {
259+
1
260+
}
261+
262+
fixture.registry.register_node_builder(
263+
"num_output".to_string(),
264+
"num_output".to_string(),
265+
|builder, _config: ()| builder.create_map_block(num_output),
266+
);
267+
268+
fn string_output(_: serde_json::Value) -> String {
269+
"hello".to_string()
270+
}
271+
272+
fixture.registry.register_node_builder(
273+
"string_output".to_string(),
274+
"string_output".to_string(),
275+
|builder, _config: ()| builder.create_map_block(string_output),
276+
);
277+
278+
let diagram = Diagram::from_json(json!({
279+
"version": "0.1.0",
280+
"start": "fork_clone",
281+
"ops": {
282+
"fork_clone": {
283+
"type": "fork_clone",
284+
"next": ["op1", "op2"]
285+
},
286+
"op1": {
287+
"type": "node",
288+
"builder": "num_output",
289+
"next": "join",
290+
},
291+
"op2": {
292+
"type": "node",
293+
"builder": "string_output",
294+
"next": "join",
295+
},
296+
"join": {
297+
"type": "join",
298+
"next": { "builtin": "terminate" },
299+
"serialize": true,
300+
},
301+
}
302+
}))
303+
.unwrap();
304+
305+
let result = fixture
306+
.spawn_and_run(&diagram, serde_json::Value::Null)
307+
.unwrap();
308+
assert_eq!(result.as_array().unwrap().len(), 2);
309+
// order is not guaranteed, so we need to test for both possibilities
310+
assert!(result[0] == 1 || result[0] == "hello");
311+
assert!(result[1] == 1 || result[1] == "hello");
312+
assert!(result[0] != result[1]);
313+
}
218314
}

src/diagram/workflow_builder.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{any::TypeId, collections::HashMap};
22

33
use tracing::{debug, warn};
44

5-
use crate::{Builder, InputSlot, Output, StreamPack};
5+
use crate::{diagram::join::serialize_and_join, Builder, InputSlot, Output, StreamPack};
66

77
use super::{
88
fork_clone::DynForkClone, impls::DefaultImpl, split_chain, transform::transform_output,
@@ -248,7 +248,7 @@ fn connect_vertex<'a>(
248248
match target.op {
249249
// join needs all incoming edges to be connected at once so it is done at the vertex level
250250
// instead of per edge level.
251-
DiagramOperation::Join(_) => {
251+
DiagramOperation::Join(join_op) => {
252252
if target.in_edges.is_empty() {
253253
return Err(DiagramError::EmptyJoin);
254254
}
@@ -264,8 +264,13 @@ fn connect_vertex<'a>(
264264
})
265265
.collect();
266266

267-
let join_impl = &registry.join_impls[&outputs[0].type_id];
268-
let joined_output = join_impl(builder, outputs)?;
267+
let joined_output = if join_op.serialize.unwrap_or(false) {
268+
serialize_and_join(builder, registry, outputs)?.into()
269+
} else {
270+
let join_impl = &registry.join_impls[&outputs[0].type_id];
271+
join_impl(builder, outputs)?
272+
};
273+
269274
let out_edge = edges.get_mut(&target.out_edges[0]).unwrap();
270275
out_edge.state = EdgeState::Ready {
271276
output: joined_output,

0 commit comments

Comments
 (0)