Skip to content

Commit 3f7b695

Browse files
feat: support cluster coefficient (#640)
* chore: support coefficient * refactor: enhance logic && add tests
1 parent 5c89285 commit 3f7b695

15 files changed

+6017
-0
lines changed

geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.apache.geaflow.dsl.schema.GeaFlowFunction;
3737
import org.apache.geaflow.dsl.udf.graph.AllSourceShortestPath;
3838
import org.apache.geaflow.dsl.udf.graph.ClosenessCentrality;
39+
import org.apache.geaflow.dsl.udf.graph.ClusterCoefficient;
3940
import org.apache.geaflow.dsl.udf.graph.CommonNeighbors;
4041
import org.apache.geaflow.dsl.udf.graph.IncKHopAlgorithm;
4142
import org.apache.geaflow.dsl.udf.graph.IncMinimumSpanningTree;
@@ -217,6 +218,7 @@ public class BuildInSqlFunctionTable extends ListSqlOperatorTable {
217218
.add(GeaFlowFunction.of(ClosenessCentrality.class))
218219
.add(GeaFlowFunction.of(WeakConnectedComponents.class))
219220
.add(GeaFlowFunction.of(TriangleCount.class))
221+
.add(GeaFlowFunction.of(ClusterCoefficient.class))
220222
.add(GeaFlowFunction.of(IncWeakConnectedComponents.class))
221223
.add(GeaFlowFunction.of(CommonNeighbors.class))
222224
.add(GeaFlowFunction.of(JaccardSimilarity.class))
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.geaflow.dsl.udf.graph;
21+
22+
import com.google.common.collect.Lists;
23+
import com.google.common.collect.Sets;
24+
import java.util.Iterator;
25+
import java.util.List;
26+
import java.util.Objects;
27+
import java.util.Optional;
28+
import java.util.Set;
29+
import org.apache.geaflow.common.type.primitive.DoubleType;
30+
import org.apache.geaflow.common.type.primitive.IntegerType;
31+
import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
32+
import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction;
33+
import org.apache.geaflow.dsl.common.data.Row;
34+
import org.apache.geaflow.dsl.common.data.RowEdge;
35+
import org.apache.geaflow.dsl.common.data.RowVertex;
36+
import org.apache.geaflow.dsl.common.data.impl.ObjectRow;
37+
import org.apache.geaflow.dsl.common.function.Description;
38+
import org.apache.geaflow.dsl.common.types.GraphSchema;
39+
import org.apache.geaflow.dsl.common.types.StructType;
40+
import org.apache.geaflow.dsl.common.types.TableField;
41+
import org.apache.geaflow.model.graph.edge.EdgeDirection;
42+
43+
/**
44+
* ClusterCoefficient Algorithm Implementation.
45+
*
46+
* <p>The clustering coefficient of a node measures how close its neighbors are to being
47+
* a complete graph (clique). It is calculated as the ratio of the number of edges between
48+
* neighbors to the maximum possible number of edges between them.
49+
*
50+
* <p>Formula: C(v) = 2 * T(v) / (k(v) * (k(v) - 1))
51+
* where:
52+
* - T(v) is the number of triangles through node v
53+
* - k(v) is the degree of node v
54+
*
55+
* <p>The algorithm consists of 3 iteration phases:
56+
* 1. First iteration: Each node sends its neighbor list to all neighbors
57+
* 2. Second iteration: Each node receives neighbor lists and calculates connections
58+
* 3. Third iteration: Output final clustering coefficient results
59+
*
60+
* <p>Supports parameters:
61+
* - vertexType (optional): Filter nodes by vertex type
62+
* - minDegree (optional): Minimum degree threshold (default: 2)
63+
*/
64+
@Description(name = "cluster_coefficient", description = "built-in udga for Cluster Coefficient.")
65+
public class ClusterCoefficient implements AlgorithmUserFunction<Object, ObjectRow> {
66+
67+
private AlgorithmRuntimeContext<Object, ObjectRow> context;
68+
69+
private static final int MAX_ITERATION = 3;
70+
71+
// Parameters
72+
private String vertexType = null;
73+
private int minDegree = 2;
74+
75+
// Exclude set for nodes that don't match the vertex type filter
76+
private final Set<Object> excludeSet = Sets.newHashSet();
77+
78+
@Override
79+
public void init(AlgorithmRuntimeContext<Object, ObjectRow> context, Object[] params) {
80+
this.context = context;
81+
82+
// Validate parameter count
83+
if (params.length > 2) {
84+
throw new IllegalArgumentException(
85+
"Maximum parameter limit exceeded. Expected: [vertexType], [minDegree]");
86+
}
87+
88+
// Parse parameters based on type
89+
// If first param is String, it's vertexType; if it's Integer/Long, it's minDegree
90+
if (params.length >= 1 && params[0] != null) {
91+
if (params[0] instanceof String) {
92+
// First param is vertexType
93+
vertexType = (String) params[0];
94+
95+
// Second param (if exists) is minDegree
96+
if (params.length >= 2 && params[1] != null) {
97+
if (!(params[1] instanceof Integer || params[1] instanceof Long)) {
98+
throw new IllegalArgumentException(
99+
"Minimum degree parameter should be integer.");
100+
}
101+
minDegree = params[1] instanceof Integer
102+
? (Integer) params[1]
103+
: ((Long) params[1]).intValue();
104+
}
105+
} else if (params[0] instanceof Integer || params[0] instanceof Long) {
106+
// First param is minDegree (no vertexType filter)
107+
vertexType = null;
108+
minDegree = params[0] instanceof Integer
109+
? (Integer) params[0]
110+
: ((Long) params[0]).intValue();
111+
} else {
112+
throw new IllegalArgumentException(
113+
"Parameter should be either string (vertexType) or integer (minDegree).");
114+
}
115+
}
116+
}
117+
118+
@Override
119+
public void process(RowVertex vertex, Optional<Row> updatedValues, Iterator<ObjectRow> messages) {
120+
updatedValues.ifPresent(vertex::setValue);
121+
122+
Object vertexId = vertex.getId();
123+
long currentIteration = context.getCurrentIterationId();
124+
125+
if (currentIteration == 1L) {
126+
// First iteration: Check vertex type filter and send neighbor lists
127+
if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) {
128+
excludeSet.add(vertexId);
129+
// Send heartbeat to keep vertex alive
130+
context.sendMessage(vertexId, ObjectRow.create(-1));
131+
return;
132+
}
133+
134+
// Load all neighbors (both directions for undirected graph)
135+
List<RowEdge> edges = context.loadEdges(EdgeDirection.BOTH);
136+
137+
// Get unique neighbor IDs
138+
Set<Object> neighborSet = Sets.newHashSet();
139+
for (RowEdge edge : edges) {
140+
Object neighborId = edge.getTargetId();
141+
if (!excludeSet.contains(neighborId)) {
142+
neighborSet.add(neighborId);
143+
}
144+
}
145+
146+
int degree = neighborSet.size();
147+
148+
// For nodes with degree < minDegree, clustering coefficient is 0
149+
if (degree < minDegree) {
150+
// Store degree and triangle count = 0
151+
context.updateVertexValue(ObjectRow.create(degree, 0));
152+
context.sendMessage(vertexId, ObjectRow.create(-1));
153+
return;
154+
}
155+
156+
// Build neighbor list message: [degree, neighbor1, neighbor2, ...]
157+
List<Object> neighborInfo = Lists.newArrayList();
158+
neighborInfo.add(degree);
159+
neighborInfo.addAll(neighborSet);
160+
161+
ObjectRow neighborListMsg = ObjectRow.create(neighborInfo.toArray());
162+
163+
// Send neighbor list to all neighbors
164+
for (Object neighborId : neighborSet) {
165+
context.sendMessage(neighborId, neighborListMsg);
166+
}
167+
168+
// Store neighbor list in vertex value for next iteration
169+
context.updateVertexValue(neighborListMsg);
170+
171+
// Send heartbeat to self
172+
context.sendMessage(vertexId, ObjectRow.create(-1));
173+
174+
} else if (currentIteration == 2L) {
175+
// Second iteration: Calculate connections between neighbors
176+
if (excludeSet.contains(vertexId)) {
177+
context.sendMessage(vertexId, ObjectRow.create(-1));
178+
return;
179+
}
180+
181+
Row vertexValue = vertex.getValue();
182+
if (vertexValue == null) {
183+
context.sendMessage(vertexId, ObjectRow.create(-1));
184+
return;
185+
}
186+
187+
int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE);
188+
189+
// For nodes with degree < minDegree, skip calculation
190+
if (degree < minDegree) {
191+
context.sendMessage(vertexId, ObjectRow.create(-1));
192+
return;
193+
}
194+
195+
// Get this vertex's neighbor set
196+
Set<Object> myNeighbors = row2Set(vertexValue);
197+
198+
// Count triangles by checking common neighbors
199+
int triangleCount = 0;
200+
while (messages.hasNext()) {
201+
ObjectRow msg = messages.next();
202+
203+
// Skip heartbeat messages
204+
int msgDegree = (int) msg.getField(0, IntegerType.INSTANCE);
205+
if (msgDegree < 0) {
206+
continue;
207+
}
208+
209+
// Get neighbor's neighbor set
210+
Set<Object> neighborNeighbors = row2Set(msg);
211+
212+
// Count common neighbors (forming triangles)
213+
neighborNeighbors.retainAll(myNeighbors);
214+
triangleCount += neighborNeighbors.size();
215+
}
216+
217+
// Store degree and triangle count for final calculation
218+
context.updateVertexValue(ObjectRow.create(degree, triangleCount));
219+
context.sendMessage(vertexId, ObjectRow.create(-1));
220+
221+
} else if (currentIteration == 3L) {
222+
// Third iteration: Calculate and output clustering coefficient
223+
if (excludeSet.contains(vertexId)) {
224+
return;
225+
}
226+
227+
Row vertexValue = vertex.getValue();
228+
if (vertexValue == null) {
229+
return;
230+
}
231+
232+
int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE);
233+
int triangleCount = (int) vertexValue.getField(1, IntegerType.INSTANCE);
234+
235+
// Calculate clustering coefficient
236+
double coefficient;
237+
if (degree < minDegree) {
238+
coefficient = 0.0;
239+
} else {
240+
// C(v) = 2 * T(v) / (k(v) * (k(v) - 1))
241+
// Note: triangleCount is already counting edges, so we divide by 2
242+
double actualTriangles = triangleCount / 2.0;
243+
double maxPossibleEdges = degree * (degree - 1.0);
244+
coefficient = maxPossibleEdges > 0
245+
? (2.0 * actualTriangles) / maxPossibleEdges
246+
: 0.0;
247+
}
248+
249+
context.take(ObjectRow.create(vertexId, coefficient));
250+
}
251+
}
252+
253+
@Override
254+
public void finish(RowVertex graphVertex, Optional<Row> updatedValues) {
255+
// No action needed in finish
256+
}
257+
258+
@Override
259+
public StructType getOutputType(GraphSchema graphSchema) {
260+
return new StructType(
261+
new TableField("vid", graphSchema.getIdType(), false),
262+
new TableField("coefficient", DoubleType.INSTANCE, false)
263+
);
264+
}
265+
266+
/**
267+
* Convert Row to Set of neighbor IDs.
268+
* Row format: [degree, neighbor1, neighbor2, ...]
269+
*/
270+
private Set<Object> row2Set(Row row) {
271+
int degree = (int) row.getField(0, IntegerType.INSTANCE);
272+
Set<Object> neighborSet = Sets.newHashSet();
273+
for (int i = 1; i <= degree; i++) {
274+
Object neighborId = row.getField(i, context.getGraphSchema().getIdType());
275+
if (!excludeSet.contains(neighborId)) {
276+
neighborSet.add(neighborId);
277+
}
278+
}
279+
return neighborSet;
280+
}
281+
}

geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,44 @@ public void testAlgorithmTriangleCount() throws Exception {
143143
.checkSinkResult();
144144
}
145145

146+
@Test
147+
public void testAlgorithmClusterCoefficient() throws Exception {
148+
QueryTester
149+
.build()
150+
.withGraphDefine("/query/modern_graph.sql")
151+
.withQueryPath("/query/gql_algorithm_cluster_coefficient.sql")
152+
.execute()
153+
.checkSinkResult();
154+
}
155+
156+
@Test
157+
public void testAlgorithmClusterCoefficientWithParams() throws Exception {
158+
QueryTester
159+
.build()
160+
.withGraphDefine("/query/modern_graph.sql")
161+
.withQueryPath("/query/gql_algorithm_cluster_coefficient_with_params.sql")
162+
.execute()
163+
.checkSinkResult();
164+
}
165+
166+
@Test
167+
public void testAlgorithmClusterCoefficientMedium() throws Exception {
168+
QueryTester
169+
.build()
170+
.withQueryPath("/query/gql_algorithm_cluster_coefficient_medium.sql")
171+
.execute()
172+
.checkSinkResult();
173+
}
174+
175+
@Test
176+
public void testAlgorithmClusterCoefficientLarge() throws Exception {
177+
QueryTester
178+
.build()
179+
.withQueryPath("/query/gql_algorithm_cluster_coefficient_large.sql")
180+
.execute()
181+
.checkSinkResult();
182+
}
183+
146184
@Test
147185
public void testIncGraphAlgorithm_001() throws Exception {
148186
QueryTester

0 commit comments

Comments
 (0)