diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java index 266195467..d9451de55 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/ClusterCoefficient.java @@ -21,11 +21,15 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; import org.apache.geaflow.common.type.primitive.DoubleType; import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; @@ -60,18 +64,19 @@ *

Supports parameters: * - vertexType (optional): Filter nodes by vertex type * - minDegree (optional): Minimum degree threshold (default: 2) + * - samplingThreshold + sampleSize (optional): Minimum sampling threshold and sampling Size(default: 0) */ @Description(name = "cluster_coefficient", description = "built-in udga for Cluster Coefficient.") public class ClusterCoefficient implements AlgorithmUserFunction { private AlgorithmRuntimeContext context; - private static final int MAX_ITERATION = 3; - // Parameters private String vertexType = null; private int minDegree = 2; - + private int samplingThreshold = 0; + private int sampleSize = 0; + // Exclude set for nodes that don't match the vertex type filter private final Set excludeSet = Sets.newHashSet(); @@ -80,39 +85,20 @@ public void init(AlgorithmRuntimeContext context, Object[] pa this.context = context; // Validate parameter count - if (params.length > 2) { + if (params.length > 4) { throw new IllegalArgumentException( - "Maximum parameter limit exceeded. Expected: [vertexType], [minDegree]"); + "Maximum parameter limit exceeded. Expected: [vertexType], [minDegree], [samplingThreshold], [sampleSize]"); } - - // Parse parameters based on type - // If first param is String, it's vertexType; if it's Integer/Long, it's minDegree - if (params.length >= 1 && params[0] != null) { - if (params[0] instanceof String) { - // First param is vertexType - vertexType = (String) params[0]; - - // Second param (if exists) is minDegree - if (params.length >= 2 && params[1] != null) { - if (!(params[1] instanceof Integer || params[1] instanceof Long)) { - throw new IllegalArgumentException( - "Minimum degree parameter should be integer."); - } - minDegree = params[1] instanceof Integer - ? (Integer) params[1] - : ((Long) params[1]).intValue(); - } - } else if (params[0] instanceof Integer || params[0] instanceof Long) { - // First param is minDegree (no vertexType filter) - vertexType = null; - minDegree = params[0] instanceof Integer - ? (Integer) params[0] - : ((Long) params[0]).intValue(); - } else { - throw new IllegalArgumentException( - "Parameter should be either string (vertexType) or integer (minDegree)."); + + // Validate parameter not null + for (Object param : params) { + if (param == null) { + throw new IllegalArgumentException("Parameter should not be null."); } } + + // Parse parameters + parseParameters(params); } @Override @@ -152,7 +138,12 @@ public void process(RowVertex vertex, Optional updatedValues, Iterator samplingThreshold && sampleSize > 0) { + neighborSet = sampleNeighbors(neighborSet); + } + // Build neighbor list message: [degree, neighbor1, neighbor2, ...] List neighborInfo = Lists.newArrayList(); neighborInfo.add(degree); @@ -278,4 +269,105 @@ private Set row2Set(Row row) { } return neighborSet; } + + /** + * sample some neighbors. + * @param neighbors origin neighbors + * @return sampled neighbors + */ + private Set sampleNeighbors(Set neighbors) { + // Strategy selection threshold: + // If only a very small portion needs to be sampled (e.g., less than 5%), use the index randomization method. + // Avoid copying the entire huge list + if (sampleSize < neighbors.size() * 0.05) { + return pickRandomIndices(neighbors, neighbors.size()); + } + + // Otherwise, use partial shuffling. + return partialShuffle(neighbors, neighbors.size()); + } + + /** + * Auxiliary method: Sampling by random index (to save memory). + * @param neighbors origin neighbors + * @param totalSize origin neighbor's size + * @return sampled neighbors + */ + private Set pickRandomIndices(Set neighbors, int totalSize) { + // Use List to index the element + List neighborList = new ArrayList<>(neighbors); + Set result = new HashSet<>(sampleSize); + // Use Set to ensure that indices are not repeated. + Set selectedIndices = new HashSet<>(sampleSize); + ThreadLocalRandom rnd = ThreadLocalRandom.current(); + + while (selectedIndices.size() < sampleSize) { + int idx = rnd.nextInt(totalSize); + // If add returns true, it means it's a new index. + if (selectedIndices.add(idx)) { + result.add(neighborList.get(idx)); + } + } + return result; + } + + /** + * Auxiliary method: Partial shuffling. + * @param neighbors origin neighbors + * @param totalSize origin neighbor's size + * @return sampled neighbors + */ + private Set partialShuffle(Set neighbors, int totalSize) { + List copy = new ArrayList<>(neighbors); + ThreadLocalRandom rnd = ThreadLocalRandom.current(); + for (int i = 0; i < sampleSize; i++) { + Collections.swap(copy, i, i + rnd.nextInt(totalSize - i)); + } + return new HashSet<>(copy.subList(0, sampleSize)); + } + + /** + * Parse parameters from params. + * @param params params + */ + private void parseParameters(Object[] params) { + // If params.length == 1, params are [vertexType] or [minDegree] + if (params.length == 1) { + if (params[0] instanceof String) { + vertexType = (String) params[0]; + } else { + minDegree = (Integer) params[0]; + } + return; + } + // If params.length == 2, params is [vertexType, minDegree] or [samplingThreshold, sampleSize] + if (params.length == 2) { + if (params[0] instanceof String) { + vertexType = (String) params[0]; + minDegree = (Integer) params[1]; + } else { + samplingThreshold = (Integer) params[0]; + sampleSize = (Integer) params[1]; + } + return; + } + // If params.length == 3, params is [vertexType, samplingThreshold, sampleSize] or [minDegree, samplingThreshold, sampleSize] + if (params.length == 3) { + if (params[0] instanceof String) { + vertexType = (String) params[0]; + } else { + minDegree = (Integer) params[0]; + } + samplingThreshold = (Integer) params[1]; + sampleSize = (Integer) params[2]; + return; + } + // If params.length == 4, params is [vertexType, minDegree, samplingThreshold, sampleSize] + if (params.length == 4) { + vertexType = (String) params[0]; + minDegree = (Integer) params[1]; + samplingThreshold = (Integer) params[2]; + sampleSize = (Integer) params[3]; + } + } }