Skip to content

Commit b160d35

Browse files
authored
Fix logic to build optimized graph (#423)
1 parent 3953872 commit b160d35

File tree

3 files changed

+165
-112
lines changed

3 files changed

+165
-112
lines changed

pipeline/util/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
<artifactId>datacommons-import-util</artifactId>
2727
<version>${revision}</version>
2828
</dependency>
29+
<dependency>
30+
<groupId>org.apache.beam</groupId>
31+
<artifactId>beam-sdks-java-extensions-protobuf</artifactId>
32+
<version>${beam.version}</version>
33+
</dependency>
2934
<dependency>
3035
<groupId>org.apache.beam</groupId>
3136
<artifactId>beam-sdks-java-core</artifactId>

pipeline/util/src/main/java/org/datacommons/pipeline/util/PipelineUtils.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import java.nio.charset.StandardCharsets;
1111
import java.util.ArrayList;
1212
import java.util.Base64;
13+
import java.util.Comparator;
1314
import java.util.HashMap;
1415
import java.util.HashSet;
1516
import java.util.List;
1617
import java.util.Map;
1718
import java.util.Set;
19+
import java.util.stream.Collectors;
20+
import java.util.stream.StreamSupport;
1821
import java.util.zip.GZIPOutputStream;
1922
import org.apache.beam.sdk.Pipeline;
2023
import org.apache.beam.sdk.io.TFRecordIO;
@@ -183,18 +186,21 @@ public void processElement(
183186
element,
184187
OutputReceiver<McfOptimizedGraph> receiver) {
185188
McfStatVarObsSeries.Builder svObsSeries = McfStatVarObsSeries.newBuilder();
186-
187189
svObsSeries.setKey(element.getKey());
188-
for (McfStatVarObsSeries.StatVarObs svo : element.getValue()) {
190+
191+
Iterable<McfStatVarObsSeries.StatVarObs> observations =
192+
StreamSupport.stream(element.getValue().spliterator(), false)
193+
.sorted(
194+
Comparator.comparing(McfStatVarObsSeries.StatVarObs::getDate))
195+
.collect(Collectors.toList());
196+
for (McfStatVarObsSeries.StatVarObs svo : observations) {
189197
svObsSeries.addSvObsList(svo);
190198
}
191199
McfOptimizedGraph.Builder res =
192200
McfOptimizedGraph.newBuilder().setSvObsSeries(svObsSeries);
193201
receiver.output(res.build());
194-
LOGGER.info(res.build().toString());
195202
}
196203
}));
197-
198204
return svObs;
199205
}
200206

@@ -295,7 +301,6 @@ public void processElement(
295301
McfGraph.Builder graphBuilder = McfGraph.newBuilder();
296302
graphBuilder.putNodes(element.getKey(), element.getValue());
297303
receiver.output(graphBuilder.build());
298-
System.out.println(graphBuilder.build().toString());
299304
}
300305
}));
301306
return combinedGraph;

pipeline/util/src/test/java/org/datacommons/pipeline/util/PipelineUtilsTest.java

Lines changed: 150 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
package org.datacommons.pipeline.util;
22

3+
import java.util.Arrays;
4+
import java.util.Comparator;
35
import java.util.HashMap;
6+
import java.util.List;
47
import java.util.Map;
8+
import java.util.stream.Collectors;
9+
import org.apache.beam.sdk.PipelineResult;
10+
import org.apache.beam.sdk.extensions.protobuf.ProtoCoder;
511
import org.apache.beam.sdk.options.PipelineOptions;
612
import org.apache.beam.sdk.options.PipelineOptionsFactory;
713
import org.apache.beam.sdk.testing.PAssert;
@@ -12,7 +18,11 @@
1218
import org.datacommons.proto.Mcf.McfGraph;
1319
import org.datacommons.proto.Mcf.McfGraph.PropertyValues;
1420
import org.datacommons.proto.Mcf.McfGraph.TypedValue;
21+
import org.datacommons.proto.Mcf.McfGraph.Values;
22+
import org.datacommons.proto.Mcf.McfOptimizedGraph;
23+
import org.datacommons.proto.Mcf.McfStatVarObsSeries;
1524
import org.datacommons.proto.Mcf.ValueType;
25+
import org.junit.Assert;
1626
import org.junit.Rule;
1727
import org.junit.Test;
1828
import org.junit.runner.RunWith;
@@ -21,142 +31,176 @@
2131
@RunWith(JUnit4.class)
2232
public class PipelineUtilsTest {
2333

24-
PipelineOptions options = PipelineOptionsFactory.create();
25-
34+
static PipelineOptions options = PipelineOptionsFactory.create();
2635
@Rule public TestPipeline p = TestPipeline.fromOptions(options);
2736

37+
private McfGraph createStatVarObservationGraph(
38+
String obsId, String statVar, String location, String date, String value) {
39+
McfGraph.Builder graph = McfGraph.newBuilder();
40+
PropertyValues.Builder pv = PropertyValues.newBuilder();
41+
pv.putPvs(
42+
"typeOf",
43+
Values.newBuilder()
44+
.addTypedValues(TypedValue.newBuilder().setValue("StatVarObservation"))
45+
.build());
46+
pv.putPvs(
47+
"variableMeasured",
48+
Values.newBuilder().addTypedValues(TypedValue.newBuilder().setValue(statVar)).build());
49+
pv.putPvs(
50+
"observationAbout",
51+
Values.newBuilder().addTypedValues(TypedValue.newBuilder().setValue(location)).build());
52+
pv.putPvs(
53+
"observationDate",
54+
Values.newBuilder().addTypedValues(TypedValue.newBuilder().setValue(date)).build());
55+
pv.putPvs(
56+
"value",
57+
Values.newBuilder().addTypedValues(TypedValue.newBuilder().setValue(value)).build());
58+
pv.putPvs(
59+
"dcid",
60+
Values.newBuilder().addTypedValues(TypedValue.newBuilder().setValue(obsId)).build());
61+
graph.putNodes(obsId, pv.build());
62+
return graph.build();
63+
}
64+
65+
private McfStatVarObsSeries.StatVarObs createStatVarObs(String date, double value, String dcid) {
66+
McfStatVarObsSeries.StatVarObs.Builder svObs = McfStatVarObsSeries.StatVarObs.newBuilder();
67+
svObs.setDate(date);
68+
svObs.setNumber(value);
69+
svObs.setDcid(dcid);
70+
svObs.setPvs(PropertyValues.newBuilder().build());
71+
return svObs.build();
72+
}
73+
74+
private McfStatVarObsSeries createMcfStatVarObsSeries(
75+
String statVar, String location, List<McfStatVarObsSeries.StatVarObs> observations) {
76+
McfStatVarObsSeries.Key.Builder keyBuilder = McfStatVarObsSeries.Key.newBuilder();
77+
keyBuilder.setObservationAbout(location);
78+
keyBuilder.setVariableMeasured(statVar);
79+
80+
List<McfStatVarObsSeries.StatVarObs> sortedSvObs =
81+
observations.stream()
82+
.sorted(Comparator.comparing(McfStatVarObsSeries.StatVarObs::getDate))
83+
.collect(Collectors.toList());
84+
85+
McfStatVarObsSeries.Builder seriesBuilder = McfStatVarObsSeries.newBuilder();
86+
seriesBuilder.setKey(keyBuilder.build());
87+
seriesBuilder.addAllSvObsList(sortedSvObs);
88+
return seriesBuilder.build();
89+
}
90+
91+
@Test
92+
public void testBuildOptimizedMcfGraph() {
93+
options.setStableUniqueNames(PipelineOptions.CheckEnabled.OFF);
94+
p.getCoderRegistry()
95+
.registerCoderForClass(
96+
McfStatVarObsSeries.Key.class, ProtoCoder.of(McfStatVarObsSeries.Key.class));
97+
98+
PCollection<McfGraph> input =
99+
p.apply(
100+
Create.of(
101+
createStatVarObservationGraph(
102+
"obs1", "count_person", "country/USA", "2020", "32.0"),
103+
createStatVarObservationGraph(
104+
"obs2", "count_person", "country/USA", "2021", "33.0"),
105+
createStatVarObservationGraph(
106+
"obs4", "count_person", "country/India", "2022", "36.0")));
107+
108+
PCollection<McfOptimizedGraph> result = PipelineUtils.buildOptimizedMcfGraph(input);
109+
110+
McfOptimizedGraph expected1 =
111+
McfOptimizedGraph.newBuilder()
112+
.setSvObsSeries(
113+
createMcfStatVarObsSeries(
114+
"count_person",
115+
"country/USA",
116+
Arrays.asList(
117+
createStatVarObs("2020", 32.0, "obs1"),
118+
createStatVarObs("2021", 33.0, "obs2"))))
119+
.build();
120+
McfOptimizedGraph expected2 =
121+
McfOptimizedGraph.newBuilder()
122+
.setSvObsSeries(
123+
createMcfStatVarObsSeries(
124+
"count_person",
125+
"country/India",
126+
List.of(createStatVarObs("2022", 36.0, "obs4"))))
127+
.build();
128+
129+
PAssert.that(result).containsInAnyOrder(expected1, expected2);
130+
PipelineResult.State state = p.run().waitUntilFinish();
131+
Assert.assertEquals(PipelineResult.State.DONE, state);
132+
}
133+
28134
@Test
29135
public void testCombineGraphNodes() {
30136
// Input Graph 1
31137
McfGraph graph1 =
32-
McfGraph.newBuilder()
33-
.putNodes(
138+
createGraph(
139+
Map.of(
34140
"node1",
35-
PropertyValues.newBuilder()
36-
.putPvs(
37-
"propA",
38-
McfGraph.Values.newBuilder()
39-
.addTypedValues(
40-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("val1"))
41-
.build())
42-
.putPvs(
43-
"propB",
44-
McfGraph.Values.newBuilder()
45-
.addTypedValues(
46-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valB1"))
47-
.build())
48-
.build())
49-
.putNodes(
141+
Map.of(
142+
"propA", List.of("val1"),
143+
"propB", List.of("valB1")),
50144
"node2",
51-
PropertyValues.newBuilder()
52-
.putPvs(
53-
"propC",
54-
McfGraph.Values.newBuilder()
55-
.addTypedValues(
56-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valC1"))
57-
.build())
58-
.build())
59-
.build();
145+
Map.of("propC", List.of("valC1"))));
60146

61147
// Input Graph 2
62148
McfGraph graph2 =
63-
McfGraph.newBuilder()
64-
.putNodes(
149+
createGraph(
150+
Map.of(
65151
"node1",
66-
PropertyValues.newBuilder()
67-
.putPvs(
68-
"propA",
69-
McfGraph.Values.newBuilder()
70-
.addTypedValues(
71-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("val1"))
72-
.addTypedValues(
73-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("val2"))
74-
.build())
75-
.putPvs(
76-
"propD",
77-
McfGraph.Values.newBuilder()
78-
.addTypedValues(
79-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valD1"))
80-
.build())
81-
.build())
82-
.putNodes(
152+
Map.of(
153+
"propA", List.of("val1", "val2"),
154+
"propD", List.of("valD1")),
83155
"node3",
84-
PropertyValues.newBuilder()
85-
.putPvs(
86-
"propE",
87-
McfGraph.Values.newBuilder()
88-
.addTypedValues(
89-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valE1"))
90-
.build())
91-
.build())
92-
.build();
156+
Map.of("propE", List.of("valE1"))));
93157

94158
// Expected Combined Graph
95159
McfGraph expectedCombinedGraph =
96-
McfGraph.newBuilder()
97-
.putNodes(
160+
createGraph(
161+
Map.of(
98162
"node1",
99-
PropertyValues.newBuilder()
100-
.putPvs(
101-
"propA",
102-
McfGraph.Values.newBuilder()
103-
.addTypedValues(
104-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("val1"))
105-
.addTypedValues(
106-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("val2"))
107-
.build())
108-
.putPvs(
109-
"propB",
110-
McfGraph.Values.newBuilder()
111-
.addTypedValues(
112-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valB1"))
113-
.build())
114-
.putPvs(
115-
"propD",
116-
McfGraph.Values.newBuilder()
117-
.addTypedValues(
118-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valD1"))
119-
.build())
120-
.build())
121-
.putNodes(
163+
Map.of(
164+
"propA", List.of("val1", "val2"),
165+
"propB", List.of("valB1"),
166+
"propD", List.of("valD1")),
122167
"node2",
123-
PropertyValues.newBuilder()
124-
.putPvs(
125-
"propC",
126-
McfGraph.Values.newBuilder()
127-
.addTypedValues(
128-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valC1"))
129-
.build())
130-
.build())
131-
.putNodes(
168+
Map.of("propC", List.of("valC1")),
132169
"node3",
133-
PropertyValues.newBuilder()
134-
.putPvs(
135-
"propE",
136-
McfGraph.Values.newBuilder()
137-
.addTypedValues(
138-
TypedValue.newBuilder().setType(ValueType.TEXT).setValue("valE1"))
139-
.build())
140-
.build())
141-
.build();
170+
Map.of("propE", List.of("valE1"))));
142171

143172
PCollection<McfGraph> input = p.apply("CreateInput", Create.of(graph1, graph2));
144173
PCollection<McfGraph> output = PipelineUtils.combineGraphNodes(input);
145174

146-
// The combineGraphNodes method returns a PCollection where each element is an McfGraph
147-
// containing a single combined node. To compare against a single expected graph,
148-
// we need to merge these single-node graphs back into one.
149175
PCollection<McfGraph> mergedOutput =
150176
output.apply(
151177
"MergeOutputGraphs", Combine.globally(new MergeMcfGraphsCombineFn()).withoutDefaults());
152-
153178
PAssert.thatSingleton(mergedOutput).isEqualTo(expectedCombinedGraph);
179+
PipelineResult.State state = p.run().waitUntilFinish();
180+
Assert.assertEquals(PipelineResult.State.DONE, state);
181+
}
154182

155-
p.run().waitUntilFinish();
183+
private McfGraph createGraph(Map<String, Map<String, List<String>>> nodeData) {
184+
McfGraph.Builder graph = McfGraph.newBuilder();
185+
for (Map.Entry<String, Map<String, List<String>>> nodeEntry : nodeData.entrySet()) {
186+
String nodeName = nodeEntry.getKey();
187+
Map<String, List<String>> props = nodeEntry.getValue();
188+
PropertyValues.Builder pvs = PropertyValues.newBuilder();
189+
for (Map.Entry<String, List<String>> propEntry : props.entrySet()) {
190+
String propName = propEntry.getKey();
191+
List<String> values = propEntry.getValue();
192+
McfGraph.Values.Builder valuesBuilder = McfGraph.Values.newBuilder();
193+
for (String value : values) {
194+
valuesBuilder.addTypedValues(
195+
TypedValue.newBuilder().setType(ValueType.TEXT).setValue(value));
196+
}
197+
pvs.putPvs(propName, valuesBuilder.build());
198+
}
199+
graph.putNodes(nodeName, pvs.build());
200+
}
201+
return graph.build();
156202
}
157203

158-
// A CombineFn to merge multiple McfGraph objects (each containing a single node) into a single
159-
// McfGraph containing all nodes.
160204
static class MergeMcfGraphsCombineFn
161205
extends Combine.CombineFn<McfGraph, Map<String, PropertyValues>, McfGraph> {
162206
@Override
@@ -167,7 +211,6 @@ public Map<String, PropertyValues> createAccumulator() {
167211
@Override
168212
public Map<String, PropertyValues> addInput(
169213
Map<String, PropertyValues> accumulator, McfGraph input) {
170-
// Each input McfGraph is expected to contain exactly one node.
171214
accumulator.putAll(input.getNodesMap());
172215
return accumulator;
173216
}

0 commit comments

Comments
 (0)