diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java index 185c4bfac..7c3d72fb9 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java @@ -38,6 +38,7 @@ import org.apache.geaflow.dsl.udf.graph.ClosenessCentrality; import org.apache.geaflow.dsl.udf.graph.ClusterCoefficient; import org.apache.geaflow.dsl.udf.graph.CommonNeighbors; +import org.apache.geaflow.dsl.udf.graph.ConnectedComponents; import org.apache.geaflow.dsl.udf.graph.IncKHopAlgorithm; import org.apache.geaflow.dsl.udf.graph.IncMinimumSpanningTree; import org.apache.geaflow.dsl.udf.graph.IncWeakConnectedComponents; @@ -45,6 +46,7 @@ import org.apache.geaflow.dsl.udf.graph.JaccardSimilarity; import org.apache.geaflow.dsl.udf.graph.KCore; import org.apache.geaflow.dsl.udf.graph.KHop; +import org.apache.geaflow.dsl.udf.graph.LabelPropagation; import org.apache.geaflow.dsl.udf.graph.PageRank; import org.apache.geaflow.dsl.udf.graph.SingleSourceShortestPath; import org.apache.geaflow.dsl.udf.graph.TriangleCount; @@ -223,6 +225,8 @@ public class BuildInSqlFunctionTable extends ListSqlOperatorTable { .add(GeaFlowFunction.of(CommonNeighbors.class)) .add(GeaFlowFunction.of(JaccardSimilarity.class)) .add(GeaFlowFunction.of(IncKHopAlgorithm.class)) + .add(GeaFlowFunction.of(LabelPropagation.class)) + .add(GeaFlowFunction.of(ConnectedComponents.class)) .build(); public BuildInSqlFunctionTable(GQLJavaTypeFactory typeFactory) { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ConnectedComponents.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ConnectedComponents.java new file mode 100644 index 000000000..e58bce86a --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ConnectedComponents.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; +import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.function.Description; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.dsl.common.types.StructType; +import org.apache.geaflow.dsl.common.types.TableField; +import org.apache.geaflow.model.graph.edge.EdgeDirection; + +/** + * Connected Components (CC) algorithm for finding connected components in undirected graphs. + * + *

This algorithm identifies all connected components in a graph by propagating + * the minimum vertex ID throughout each connected component. Each vertex starts with + * its own ID as the component ID and iteratively adopts the minimum component ID + * from its neighbors until convergence.

+ * + *

The algorithm treats the graph as undirected by considering edges in both directions.

+ * + *

Parameters:

+ * + * + *

Example usage:

+ *
+ * CALL cc(20, 'component') YIELD (id, component)
+ * RETURN id, component
+ * ORDER BY id;
+ * 
+ */ +@Description(name = "cc", description = "built-in udga for connected components") +public class ConnectedComponents implements AlgorithmUserFunction { + + private AlgorithmRuntimeContext context; + private String outputFieldName = "component"; + private int iterations = 20; + + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + + if (parameters.length > 2) { + throw new IllegalArgumentException( + "Only support zero or two arguments, usage: cc([iterations, [outputFieldName]])"); + } + + if (parameters.length > 0) { + this.iterations = Integer.parseInt(String.valueOf(parameters[0])); + } + + if (parameters.length > 1) { + this.outputFieldName = String.valueOf(parameters[1]); + } + } + + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); + + if (context.getCurrentIterationId() == 1L) { + // First iteration: initialize component ID with vertex ID + String initValue = String.valueOf(vertex.getId()); + sendMessageToNeighbors(edges, initValue); + context.sendMessage(vertex.getId(), initValue); + context.updateVertexValue(ObjectRow.create(initValue)); + } else if (context.getCurrentIterationId() < iterations) { + // Subsequent iterations: find minimum component ID + String minComponent = messages.next(); + while (messages.hasNext()) { + String next = messages.next(); + if (next.compareTo(minComponent) < 0) { + minComponent = next; + } + } + + // Propagate the minimum component ID to all neighbors + sendMessageToNeighbors(edges, minComponent); + context.sendMessage(vertex.getId(), minComponent); + context.updateVertexValue(ObjectRow.create(minComponent)); + } + } + + @Override + public void finish(RowVertex vertex, Optional updatedValues) { + updatedValues.ifPresent(vertex::setValue); + String component = (String) vertex.getValue().getField(0, context.getGraphSchema().getIdType()); + context.take(ObjectRow.create(vertex.getId(), component)); + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField(outputFieldName, graphSchema.getIdType(), false) + ); + } + + private void sendMessageToNeighbors(List edges, String message) { + for (RowEdge edge : edges) { + context.sendMessage(edge.getTargetId(), message); + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/LabelPropagation.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/LabelPropagation.java new file mode 100644 index 000000000..7db4366d7 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/LabelPropagation.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; +import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.function.Description; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.dsl.common.types.StructType; +import org.apache.geaflow.dsl.common.types.TableField; +import org.apache.geaflow.model.graph.edge.EdgeDirection; + +/** + * Label Propagation Algorithm (LPA) for community detection. + * + *

This algorithm assigns labels to vertices based on the most frequent label + * among their neighbors. It uses an iterative approach where vertices adopt the + * label that appears most frequently among their neighbors. In case of ties, + * the smallest label value is selected.

+ * + *

Performance Optimization: This implementation uses change detection to minimize + * communication overhead. Vertices only propagate their label to neighbors when it changes, + * significantly reducing message volume in later iterations when the algorithm stabilizes. + * This optimization makes the algorithm efficient for large-scale graphs.

+ * + *

Parameters:

+ *
    + *
  • iterations (optional): Maximum number of iterations (default: 100)
  • + *
  • outputFieldName (optional): Name of the output field (default: "label")
  • + *
+ * + *

Example usage:

+ *
+ * CALL lpa(100, 'label') YIELD (id, label)
+ * RETURN id, label
+ * ORDER BY id;
+ * 
+ */ +@Description(name = "lpa", description = "built-in udga for label propagation algorithm") +public class LabelPropagation implements AlgorithmUserFunction { + + private AlgorithmRuntimeContext context; + private String outputFieldName = "label"; + private int iterations = 100; + + @Override + public void init(AlgorithmRuntimeContext context, Object[] parameters) { + this.context = context; + + if (parameters.length > 2) { + throw new IllegalArgumentException( + "Only support zero or two arguments, usage: lpa([iterations, [outputFieldName]])"); + } + + if (parameters.length > 0) { + this.iterations = Integer.parseInt(String.valueOf(parameters[0])); + } + + if (parameters.length > 1) { + this.outputFieldName = String.valueOf(parameters[1]); + } + } + + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + updatedValues.ifPresent(vertex::setValue); + List edges = new ArrayList<>(context.loadEdges(EdgeDirection.BOTH)); + + if (context.getCurrentIterationId() == 1L) { + // First iteration: initialize label with vertex ID + String initValue = String.valueOf(vertex.getId()); + sendMessageToNeighbors(edges, initValue); + context.sendMessage(vertex.getId(), initValue); + context.updateVertexValue(ObjectRow.create(initValue)); + } else if (context.getCurrentIterationId() < iterations) { + // Subsequent iterations: adopt most frequent label from neighbors + + // Collect and count neighbor labels + Map labelCount = new HashMap<>(); + while (messages.hasNext()) { + String label = messages.next(); + labelCount.merge(label, 1L, Long::sum); + } + + if (!labelCount.isEmpty()) { + // Find the most frequent label (smallest in case of tie) + String currentLabel = (String) vertex.getValue().getField(0, + context.getGraphSchema().getIdType()); + String newLabel = findMostFrequentLabel(labelCount, currentLabel); + + // Update and propagate if label changed + if (!newLabel.equals(currentLabel)) { + sendMessageToNeighbors(edges, newLabel); + context.sendMessage(vertex.getId(), newLabel); + context.updateVertexValue(ObjectRow.create(newLabel)); + } + } + } + } + + @Override + public void finish(RowVertex vertex, Optional updatedValues) { + updatedValues.ifPresent(vertex::setValue); + String label = (String) vertex.getValue().getField(0, context.getGraphSchema().getIdType()); + context.take(ObjectRow.create(vertex.getId(), label)); + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField(outputFieldName, graphSchema.getIdType(), false) + ); + } + + /** + * Finds the most frequent label from the label count map. + * In case of ties, returns the smallest label value. + * + * @param labelCount Map of labels to their frequencies + * @param currentLabel Current label of the vertex + * @return Most frequent label (smallest in case of tie) + */ + private String findMostFrequentLabel(Map labelCount, String currentLabel) { + if (labelCount.isEmpty()) { + return currentLabel; + } + + // Find maximum frequency + long maxCount = labelCount.values().stream() + .max(Long::compareTo) + .orElse(0L); + + // Find label with maximum frequency (smallest if tie) + String bestLabel = currentLabel; + for (Map.Entry entry : labelCount.entrySet()) { + if (entry.getValue() == maxCount) { + if (bestLabel == null || entry.getKey().compareTo(bestLabel) < 0) { + bestLabel = entry.getKey(); + } + } + } + + return bestLabel; + } + + private void sendMessageToNeighbors(List edges, String message) { + for (RowEdge edge : edges) { + context.sendMessage(edge.getTargetId(), message); + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java index 0ad33935f..cd9fbd20a 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java @@ -315,6 +315,24 @@ public void testEdgeIterator() throws Exception { .checkSinkResult(); } + @Test + public void testAlgorithmLabelPropagation() throws Exception { + QueryTester + .build() + .withQueryPath("/query/gql_algorithm_lpa.sql") + .execute() + .checkSinkResult(); + } + + @Test + public void testAlgorithmConnectedComponents() throws Exception { + QueryTester + .build() + .withQueryPath("/query/gql_algorithm_cc.sql") + .execute() + .checkSinkResult(); + } + private void clearGraph() throws IOException { File file = new File(TEST_GRAPH_PATH); if (file.exists()) { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/algo_test_edges.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/algo_test_edges.txt new file mode 100644 index 000000000..4edd46cc2 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/algo_test_edges.txt @@ -0,0 +1,10 @@ +1,2,1.0 +1,3,1.0 +2,3,1.0 +2,4,1.0 +3,4,1.0 +4,5,1.0 +5,6,1.0 +5,7,1.0 +6,7,1.0 +9,10,1.0 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/algo_test_vertex.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/algo_test_vertex.txt new file mode 100644 index 000000000..0aa56d50f --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/algo_test_vertex.txt @@ -0,0 +1,10 @@ +1,vertex_1 +2,vertex_2 +3,vertex_3 +4,vertex_4 +5,vertex_5 +6,vertex_6 +7,vertex_7 +8,vertex_8 +9,vertex_9 +10,vertex_10 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_cc.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_cc.txt new file mode 100644 index 000000000..39db7688f --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_cc.txt @@ -0,0 +1,10 @@ +1,1 +2,1 +3,1 +4,1 +5,1 +6,1 +7,1 +8,8 +9,10 +10,10 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_lpa.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_lpa.txt new file mode 100644 index 000000000..39db7688f --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_lpa.txt @@ -0,0 +1,10 @@ +1,1 +2,1 +3,1 +4,1 +5,1 +6,1 +7,1 +8,8 +9,10 +10,10 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_cc.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_cc.sql new file mode 100644 index 000000000..e4fe4e223 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_cc.sql @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +set geaflow.dsl.window.size = -1; +set geaflow.dsl.ignore.exception = true; + +CREATE GRAPH IF NOT EXISTS test_cc_graph ( + Vertex person ( + id varchar ID, + name varchar + ), + Edge knows ( + srcId varchar SOURCE ID, + targetId varchar DESTINATION ID, + weight double + ) +) WITH ( + storeType='rocksdb', + shardCount = 1 +); + +CREATE TABLE IF NOT EXISTS tbl_source ( + id varchar, + name varchar +) WITH ( + type='file', + geaflow.dsl.file.path = 'resource:///data/algo_test_vertex.txt' +); + +CREATE TABLE IF NOT EXISTS tbl_edge_source ( + srcId varchar, + targetId varchar, + weight double +) WITH ( + type='file', + geaflow.dsl.file.path = 'resource:///data/algo_test_edges.txt' +); + +CREATE TABLE IF NOT EXISTS tbl_result ( + vid varchar, + component varchar +) WITH ( + type='file', + geaflow.dsl.file.path = '${target}' +); + +USE GRAPH test_cc_graph; + +INSERT INTO test_cc_graph.person(id, name) +SELECT id, name FROM tbl_source; + +INSERT INTO test_cc_graph.knows(srcId, targetId, weight) +SELECT srcId, targetId, weight FROM tbl_edge_source; + +INSERT INTO tbl_result(vid, component) +CALL cc(20, 'component') YIELD (id, component) +RETURN id, component +ORDER BY id +; diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_lpa.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_lpa.sql new file mode 100644 index 000000000..9bd306541 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_lpa.sql @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +set geaflow.dsl.window.size = -1; +set geaflow.dsl.ignore.exception = true; + +CREATE GRAPH IF NOT EXISTS test_lpa_graph ( + Vertex person ( + id varchar ID, + name varchar + ), + Edge knows ( + srcId varchar SOURCE ID, + targetId varchar DESTINATION ID, + weight double + ) +) WITH ( + storeType='rocksdb', + shardCount = 1 +); + +CREATE TABLE IF NOT EXISTS tbl_source ( + id varchar, + name varchar +) WITH ( + type='file', + geaflow.dsl.file.path = 'resource:///data/algo_test_vertex.txt' +); + +CREATE TABLE IF NOT EXISTS tbl_edge_source ( + srcId varchar, + targetId varchar, + weight double +) WITH ( + type='file', + geaflow.dsl.file.path = 'resource:///data/algo_test_edges.txt' +); + +CREATE TABLE IF NOT EXISTS tbl_result ( + vid varchar, + label varchar +) WITH ( + type='file', + geaflow.dsl.file.path = '${target}' +); + +USE GRAPH test_lpa_graph; + +INSERT INTO test_lpa_graph.person(id, name) +SELECT id, name FROM tbl_source; + +INSERT INTO test_lpa_graph.knows(srcId, targetId, weight) +SELECT srcId, targetId, weight FROM tbl_edge_source; + +INSERT INTO tbl_result(vid, label) +CALL lpa(100, 'label') YIELD (id, label) +RETURN id, label +ORDER BY id +;