|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +package org.apache.geaflow.dsl.udf.graph; |
| 21 | + |
| 22 | +import com.google.common.collect.Lists; |
| 23 | +import com.google.common.collect.Sets; |
| 24 | +import java.util.Iterator; |
| 25 | +import java.util.List; |
| 26 | +import java.util.Objects; |
| 27 | +import java.util.Optional; |
| 28 | +import java.util.Set; |
| 29 | +import org.apache.geaflow.common.type.primitive.DoubleType; |
| 30 | +import org.apache.geaflow.common.type.primitive.IntegerType; |
| 31 | +import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; |
| 32 | +import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; |
| 33 | +import org.apache.geaflow.dsl.common.data.Row; |
| 34 | +import org.apache.geaflow.dsl.common.data.RowEdge; |
| 35 | +import org.apache.geaflow.dsl.common.data.RowVertex; |
| 36 | +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; |
| 37 | +import org.apache.geaflow.dsl.common.function.Description; |
| 38 | +import org.apache.geaflow.dsl.common.types.GraphSchema; |
| 39 | +import org.apache.geaflow.dsl.common.types.StructType; |
| 40 | +import org.apache.geaflow.dsl.common.types.TableField; |
| 41 | +import org.apache.geaflow.model.graph.edge.EdgeDirection; |
| 42 | + |
| 43 | +/** |
| 44 | + * ClusterCoefficient Algorithm Implementation. |
| 45 | + * |
| 46 | + * <p>The clustering coefficient of a node measures how close its neighbors are to being |
| 47 | + * a complete graph (clique). It is calculated as the ratio of the number of edges between |
| 48 | + * neighbors to the maximum possible number of edges between them. |
| 49 | + * |
| 50 | + * <p>Formula: C(v) = 2 * T(v) / (k(v) * (k(v) - 1)) |
| 51 | + * where: |
| 52 | + * - T(v) is the number of triangles through node v |
| 53 | + * - k(v) is the degree of node v |
| 54 | + * |
| 55 | + * <p>The algorithm consists of 3 iteration phases: |
| 56 | + * 1. First iteration: Each node sends its neighbor list to all neighbors |
| 57 | + * 2. Second iteration: Each node receives neighbor lists and calculates connections |
| 58 | + * 3. Third iteration: Output final clustering coefficient results |
| 59 | + * |
| 60 | + * <p>Supports parameters: |
| 61 | + * - vertexType (optional): Filter nodes by vertex type |
| 62 | + * - minDegree (optional): Minimum degree threshold (default: 2) |
| 63 | + */ |
| 64 | +@Description(name = "cluster_coefficient", description = "built-in udga for Cluster Coefficient.") |
| 65 | +public class ClusterCoefficient implements AlgorithmUserFunction<Object, ObjectRow> { |
| 66 | + |
| 67 | + private AlgorithmRuntimeContext<Object, ObjectRow> context; |
| 68 | + |
| 69 | + private static final int MAX_ITERATION = 3; |
| 70 | + |
| 71 | + // Parameters |
| 72 | + private String vertexType = null; |
| 73 | + private int minDegree = 2; |
| 74 | + |
| 75 | + // Exclude set for nodes that don't match the vertex type filter |
| 76 | + private final Set<Object> excludeSet = Sets.newHashSet(); |
| 77 | + |
| 78 | + @Override |
| 79 | + public void init(AlgorithmRuntimeContext<Object, ObjectRow> context, Object[] params) { |
| 80 | + this.context = context; |
| 81 | + |
| 82 | + // Validate parameter count |
| 83 | + if (params.length > 2) { |
| 84 | + throw new IllegalArgumentException( |
| 85 | + "Maximum parameter limit exceeded. Expected: [vertexType], [minDegree]"); |
| 86 | + } |
| 87 | + |
| 88 | + // Parse parameters based on type |
| 89 | + // If first param is String, it's vertexType; if it's Integer/Long, it's minDegree |
| 90 | + if (params.length >= 1 && params[0] != null) { |
| 91 | + if (params[0] instanceof String) { |
| 92 | + // First param is vertexType |
| 93 | + vertexType = (String) params[0]; |
| 94 | + |
| 95 | + // Second param (if exists) is minDegree |
| 96 | + if (params.length >= 2 && params[1] != null) { |
| 97 | + if (!(params[1] instanceof Integer || params[1] instanceof Long)) { |
| 98 | + throw new IllegalArgumentException( |
| 99 | + "Minimum degree parameter should be integer."); |
| 100 | + } |
| 101 | + minDegree = params[1] instanceof Integer |
| 102 | + ? (Integer) params[1] |
| 103 | + : ((Long) params[1]).intValue(); |
| 104 | + } |
| 105 | + } else if (params[0] instanceof Integer || params[0] instanceof Long) { |
| 106 | + // First param is minDegree (no vertexType filter) |
| 107 | + vertexType = null; |
| 108 | + minDegree = params[0] instanceof Integer |
| 109 | + ? (Integer) params[0] |
| 110 | + : ((Long) params[0]).intValue(); |
| 111 | + } else { |
| 112 | + throw new IllegalArgumentException( |
| 113 | + "Parameter should be either string (vertexType) or integer (minDegree)."); |
| 114 | + } |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + @Override |
| 119 | + public void process(RowVertex vertex, Optional<Row> updatedValues, Iterator<ObjectRow> messages) { |
| 120 | + updatedValues.ifPresent(vertex::setValue); |
| 121 | + |
| 122 | + Object vertexId = vertex.getId(); |
| 123 | + long currentIteration = context.getCurrentIterationId(); |
| 124 | + |
| 125 | + if (currentIteration == 1L) { |
| 126 | + // First iteration: Check vertex type filter and send neighbor lists |
| 127 | + if (Objects.nonNull(vertexType) && !vertexType.equals(vertex.getLabel())) { |
| 128 | + excludeSet.add(vertexId); |
| 129 | + // Send heartbeat to keep vertex alive |
| 130 | + context.sendMessage(vertexId, ObjectRow.create(-1)); |
| 131 | + return; |
| 132 | + } |
| 133 | + |
| 134 | + // Load all neighbors (both directions for undirected graph) |
| 135 | + List<RowEdge> edges = context.loadEdges(EdgeDirection.BOTH); |
| 136 | + |
| 137 | + // Get unique neighbor IDs |
| 138 | + Set<Object> neighborSet = Sets.newHashSet(); |
| 139 | + for (RowEdge edge : edges) { |
| 140 | + Object neighborId = edge.getTargetId(); |
| 141 | + if (!excludeSet.contains(neighborId)) { |
| 142 | + neighborSet.add(neighborId); |
| 143 | + } |
| 144 | + } |
| 145 | + |
| 146 | + int degree = neighborSet.size(); |
| 147 | + |
| 148 | + // For nodes with degree < minDegree, clustering coefficient is 0 |
| 149 | + if (degree < minDegree) { |
| 150 | + // Store degree and triangle count = 0 |
| 151 | + context.updateVertexValue(ObjectRow.create(degree, 0)); |
| 152 | + context.sendMessage(vertexId, ObjectRow.create(-1)); |
| 153 | + return; |
| 154 | + } |
| 155 | + |
| 156 | + // Build neighbor list message: [degree, neighbor1, neighbor2, ...] |
| 157 | + List<Object> neighborInfo = Lists.newArrayList(); |
| 158 | + neighborInfo.add(degree); |
| 159 | + neighborInfo.addAll(neighborSet); |
| 160 | + |
| 161 | + ObjectRow neighborListMsg = ObjectRow.create(neighborInfo.toArray()); |
| 162 | + |
| 163 | + // Send neighbor list to all neighbors |
| 164 | + for (Object neighborId : neighborSet) { |
| 165 | + context.sendMessage(neighborId, neighborListMsg); |
| 166 | + } |
| 167 | + |
| 168 | + // Store neighbor list in vertex value for next iteration |
| 169 | + context.updateVertexValue(neighborListMsg); |
| 170 | + |
| 171 | + // Send heartbeat to self |
| 172 | + context.sendMessage(vertexId, ObjectRow.create(-1)); |
| 173 | + |
| 174 | + } else if (currentIteration == 2L) { |
| 175 | + // Second iteration: Calculate connections between neighbors |
| 176 | + if (excludeSet.contains(vertexId)) { |
| 177 | + context.sendMessage(vertexId, ObjectRow.create(-1)); |
| 178 | + return; |
| 179 | + } |
| 180 | + |
| 181 | + Row vertexValue = vertex.getValue(); |
| 182 | + if (vertexValue == null) { |
| 183 | + context.sendMessage(vertexId, ObjectRow.create(-1)); |
| 184 | + return; |
| 185 | + } |
| 186 | + |
| 187 | + int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE); |
| 188 | + |
| 189 | + // For nodes with degree < minDegree, skip calculation |
| 190 | + if (degree < minDegree) { |
| 191 | + context.sendMessage(vertexId, ObjectRow.create(-1)); |
| 192 | + return; |
| 193 | + } |
| 194 | + |
| 195 | + // Get this vertex's neighbor set |
| 196 | + Set<Object> myNeighbors = row2Set(vertexValue); |
| 197 | + |
| 198 | + // Count triangles by checking common neighbors |
| 199 | + int triangleCount = 0; |
| 200 | + while (messages.hasNext()) { |
| 201 | + ObjectRow msg = messages.next(); |
| 202 | + |
| 203 | + // Skip heartbeat messages |
| 204 | + int msgDegree = (int) msg.getField(0, IntegerType.INSTANCE); |
| 205 | + if (msgDegree < 0) { |
| 206 | + continue; |
| 207 | + } |
| 208 | + |
| 209 | + // Get neighbor's neighbor set |
| 210 | + Set<Object> neighborNeighbors = row2Set(msg); |
| 211 | + |
| 212 | + // Count common neighbors (forming triangles) |
| 213 | + neighborNeighbors.retainAll(myNeighbors); |
| 214 | + triangleCount += neighborNeighbors.size(); |
| 215 | + } |
| 216 | + |
| 217 | + // Store degree and triangle count for final calculation |
| 218 | + context.updateVertexValue(ObjectRow.create(degree, triangleCount)); |
| 219 | + context.sendMessage(vertexId, ObjectRow.create(-1)); |
| 220 | + |
| 221 | + } else if (currentIteration == 3L) { |
| 222 | + // Third iteration: Calculate and output clustering coefficient |
| 223 | + if (excludeSet.contains(vertexId)) { |
| 224 | + return; |
| 225 | + } |
| 226 | + |
| 227 | + Row vertexValue = vertex.getValue(); |
| 228 | + if (vertexValue == null) { |
| 229 | + return; |
| 230 | + } |
| 231 | + |
| 232 | + int degree = (int) vertexValue.getField(0, IntegerType.INSTANCE); |
| 233 | + int triangleCount = (int) vertexValue.getField(1, IntegerType.INSTANCE); |
| 234 | + |
| 235 | + // Calculate clustering coefficient |
| 236 | + double coefficient; |
| 237 | + if (degree < minDegree) { |
| 238 | + coefficient = 0.0; |
| 239 | + } else { |
| 240 | + // C(v) = 2 * T(v) / (k(v) * (k(v) - 1)) |
| 241 | + // Note: triangleCount is already counting edges, so we divide by 2 |
| 242 | + double actualTriangles = triangleCount / 2.0; |
| 243 | + double maxPossibleEdges = degree * (degree - 1.0); |
| 244 | + coefficient = maxPossibleEdges > 0 |
| 245 | + ? (2.0 * actualTriangles) / maxPossibleEdges |
| 246 | + : 0.0; |
| 247 | + } |
| 248 | + |
| 249 | + context.take(ObjectRow.create(vertexId, coefficient)); |
| 250 | + } |
| 251 | + } |
| 252 | + |
| 253 | + @Override |
| 254 | + public void finish(RowVertex graphVertex, Optional<Row> updatedValues) { |
| 255 | + // No action needed in finish |
| 256 | + } |
| 257 | + |
| 258 | + @Override |
| 259 | + public StructType getOutputType(GraphSchema graphSchema) { |
| 260 | + return new StructType( |
| 261 | + new TableField("vid", graphSchema.getIdType(), false), |
| 262 | + new TableField("coefficient", DoubleType.INSTANCE, false) |
| 263 | + ); |
| 264 | + } |
| 265 | + |
| 266 | + /** |
| 267 | + * Convert Row to Set of neighbor IDs. |
| 268 | + * Row format: [degree, neighbor1, neighbor2, ...] |
| 269 | + */ |
| 270 | + private Set<Object> row2Set(Row row) { |
| 271 | + int degree = (int) row.getField(0, IntegerType.INSTANCE); |
| 272 | + Set<Object> neighborSet = Sets.newHashSet(); |
| 273 | + for (int i = 1; i <= degree; i++) { |
| 274 | + Object neighborId = row.getField(i, context.getGraphSchema().getIdType()); |
| 275 | + if (!excludeSet.contains(neighborId)) { |
| 276 | + neighborSet.add(neighborId); |
| 277 | + } |
| 278 | + } |
| 279 | + return neighborSet; |
| 280 | + } |
| 281 | +} |
0 commit comments