diff --git a/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.DBScan.cs b/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.DBScan.cs new file mode 100644 index 000000000..3ee14cb60 --- /dev/null +++ b/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.DBScan.cs @@ -0,0 +1,160 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Numerics; + +namespace CommunityToolkit.WinUI.Helpers; + +public partial class ColorPaletteSampler +{ + private ref struct DBScan + { + private const int Unclassified = -1; + + public static Vector3[] Cluster(Span points, float epsilon, int minPoints, ref float[] weights) + { + var centroids = new List(); + var newWeights = new List(); + + // Create context + var context = new DBScan(points, weights, epsilon, minPoints); + + // Attempt to create a cluster around each point, + // skipping that point if already classified + for (int i = 0; i < points.Length; i++) + { + // Already classified, skip + if (context.PointClusterIds[i] is not Unclassified) + continue; + + // Attempt to create cluster + if(context.CreateCluster(i, out var centroid, out var weight)) + { + centroids.Add(centroid); + newWeights.Add(weight); + } + } + + weights = newWeights.ToArray(); + return centroids.ToArray(); + } + + private bool CreateCluster(int originIndex, out Vector3 centroid, out float weight) + { + weight = 0; + centroid = Vector3.Zero; + var seeds = GetSeeds(originIndex, out bool isCore); + + // Not enough seeds to be a core point. + // Cannot create a cluster around it + if (!isCore) + { + return false; + } + + ExpandCluster(seeds, out centroid, out weight); + ClusterId++; + + return true; + } + + private void ExpandCluster(Queue seeds, out Vector3 centroid, out float weight) + { + weight = 0; + centroid = Vector3.Zero; + while(seeds.Count > 0) + { + var seedIndex = seeds.Dequeue(); + + // Skip duplicate seed entries + if (PointClusterIds[seedIndex] is not Unclassified) + continue; + + // Assign this seed's id to the cluster + PointClusterIds[seedIndex] = ClusterId; + var w = Weights[seedIndex]; + centroid += Points[seedIndex] * w; + weight += w; + + // Check if this seed is a core point + var grandSeeds = GetSeeds(seedIndex, out var seedIsCore); + if (!seedIsCore) + continue; + + // This seed is a core point. Enqueue all its seeds + foreach(var grandSeedIndex in grandSeeds) + if (PointClusterIds[grandSeedIndex] is Unclassified) + seeds.Enqueue(grandSeedIndex); + } + + centroid /= weight; + } + + private Queue GetSeeds(int originIndex, out bool isCore) + { + var origin = Points[originIndex]; + + // NOTE: Seeding could be done using a spatial data structure to improve traversal + // speeds. However currently DBSCAN is run after KMeans with a maximum of 8 points. + // There is no need. + + var seeds = new Queue(); + for (int i = 0; i < Points.Length; i++) + { + if (Vector3.DistanceSquared(origin, Points[i]) <= Epsilon2) + seeds.Enqueue(i); + } + + // Count includes self, so compare without checking equals + isCore = seeds.Count > MinPoints; + return seeds; + } + + private DBScan(Span points, Span weights, float epsilon, int minPoints) + { + Points = points; + Weights = weights; + Epsilon2 = epsilon * epsilon; + MinPoints = minPoints; + + ClusterId = 0; + PointClusterIds = new int[points.Length]; + for(int i = 0; i < points.Length; i++) + PointClusterIds[i] = Unclassified; + } + + /// + /// Gets the points being clustered. + /// + public Span Points { get; } + + /// + /// Gets the weights of the points. + /// + public Span Weights { get; } + + /// + /// Gets or sets the id of the currently evaluating cluster. + /// + public int ClusterId { get; set; } + + /// + /// Gets an array containing the id of the cluster each point belongs to. + /// + public int[] PointClusterIds { get; } + + /// + /// Gets epsilon squared. Where epsilon is the max distance to consider two points connected. + /// + /// + /// This is cached as epsilon squared to skip a sqrt operation when comparing distances to epsilon. + /// + public double Epsilon2 { get; } + + /// + /// Gets the minimum number of points required to make a core point. + /// + public int MinPoints { get; } + } +} diff --git a/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.Clustering.cs b/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.KMeans.cs similarity index 97% rename from components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.Clustering.cs rename to components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.KMeans.cs index e6a17bfe0..832ee0c4b 100644 --- a/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.Clustering.cs +++ b/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.KMeans.cs @@ -69,7 +69,7 @@ private static void Split(int k, int[] clusterIds) /// /// Calculates the centroid of each cluster, and prunes empty clusters. /// - private static void CalculateCentroidsAndPrune(ref Span centroids, ref int[] counts, Span points, int[] clusterIds) + internal static void CalculateCentroidsAndPrune(ref Span centroids, ref int[] counts, Span points, int[] clusterIds) { // Clear centroids and counts before recalculation for (int i = 0; i < centroids.Length; i++) diff --git a/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.cs b/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.cs index bfa4b68ab..c5bb74fff 100644 --- a/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.cs +++ b/components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.cs @@ -52,6 +52,7 @@ public async Task UpdatePaletteAsync() const int sampleCount = 4096; const int k = 8; + const float mergeDistance = 0.12f; // Retreive pixel samples from source var samples = await SampleSourcePixelColorsAsync(sampleCount); @@ -62,8 +63,11 @@ public async Task UpdatePaletteAsync() // Cluster samples in RGB floating-point color space // With Euclidean Squared distance function, then construct palette data. - var clusters = KMeansCluster(samples, k, out var sizes); - var colorData = clusters.Select((vectorColor, i) => new PaletteColor(vectorColor.ToColor(), (float)sizes[i] / samples.Length)); + // Merge KMeans results that are too similar, using DBScan + var kClusters = KMeansCluster(samples, k, out var counts); + var weights = counts.Select(x => (float)x / samples.Length).ToArray(); + var dbCluster = DBScan.Cluster(kClusters, mergeDistance, 0, ref weights); + var colorData = dbCluster.Select((vectorColor, i) => new PaletteColor(vectorColor.ToColor(), weights[i])); // Update palettes on the UI thread foreach (var palette in PaletteSelectors)