@@ -14,21 +14,11 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
1414 int [ ] clusterIds = new int [ points . Length ] ;
1515
1616 // Track the centroids of each cluster and its member count
17- // TODO: stackalloc is great here, but pooling should be thresholded
18- // just in case
1917 Span < Vector3 > centroids = stackalloc Vector3 [ k ] ;
2018 counts = new int [ k ] ;
21-
19+
2220 // Split the points into arbitrary clusters
23- // NOTE: Can this be rearranged to converge faster?
24- #if NET6_0_OR_GREATER
25- var offset = Random . Shared . Next ( k ) ; // Mathematically true random sampling
26- #else
27- var rand = new Random ( ) ;
28- var offset = rand . Next ( k ) ;
29- #endif
30- for ( int i = 0 ; i < clusterIds . Length ; i ++ )
31- clusterIds [ i ] = ( i + offset ) % k ;
21+ Split ( k , clusterIds ) ;
3222
3323 bool converged = false ;
3424 while ( ! converged )
@@ -37,83 +27,16 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
3727 // to false when adjust the clusters
3828 converged = true ;
3929
40- // KMeans Loop Step 1:
41- // Calculate/Recalculate the centroids of each cluster
42-
43- // Clear centroids and counts before recalculation
44- for ( int i = 0 ; i < centroids . Length ; i ++ )
45- {
46- centroids [ i ] = Vector3 . Zero ;
47- counts [ i ] = 0 ;
48- }
49-
50- // Accumulate step in centroid calculation
51- for ( int i = 0 ; i < clusterIds . Length ; i ++ )
52- {
53- int id = clusterIds [ i ] ;
54- centroids [ id ] += points [ i ] ;
55- counts [ id ] ++ ;
56- }
30+ // Calculate/Recalculate centroids
31+ CalculateCentroidsAndPrune ( ref centroids , ref counts , points , clusterIds ) ;
5732
58- // Prune empty clusters
59- // All empty clusters are swapped to the end of the span
60- // then a slice is taken with only the remaining populated clusters
61- int pivot = counts . Length ;
62- for ( int i = 0 ; i < pivot ; )
63- {
64- // Increment and continue if populated
65- if ( counts [ i ] != 0 )
66- {
67- i ++ ;
68- continue ;
69- }
70-
71- // The item is not populated. Swap to end and move pivot
72- // NOTE: This is a oneway swap. We're discarding the 0 anyways.
73- pivot -- ;
74- counts [ i ] = counts [ pivot ] ;
75- }
76-
77- #if ! WINDOWS_UWP
78- counts = counts [ ..pivot ] ;
79- centroids = centroids [ ..pivot ] ;
80- #elif WINDOWS_UWP
81- Array . Resize ( ref counts , pivot ) ;
82- centroids = centroids . Slice ( 0 , pivot ) ;
83- #endif
84-
85- // Division step in centroid calculation
86- for ( int i = 0 ; i < centroids . Length ; i ++ )
87- centroids [ i ] /= counts [ i ] ;
88-
89- // KMeans Loop Step 2:
9033 // Move each point's clusterId to the nearest cluster centroid
9134 for ( int i = 0 ; i < points . Length ; i ++ )
9235 {
93- Vector3 point = points [ i ] ;
94- var oldId = clusterIds [ i ] ;
95-
96- // Track the nearest centroid's distance and the index of that centroid
97- float nearestDistance = float . PositiveInfinity ;
98- int nearestIndex = - 1 ;
99-
100- for ( int j = 0 ; j < centroids . Length ; j ++ )
101- {
102- // Compare the point to the jth centroid
103- float distance = Vector3 . DistanceSquared ( point , centroids [ j ] ) ;
104-
105- // Skip the cluster if further than the nearest seen cluster
106- if ( nearestDistance < distance )
107- continue ;
108-
109- // This is the nearest cluster
110- // Update the distance and index
111- nearestDistance = distance ;
112- nearestIndex = j ;
113- }
36+ var nearestIndex = FindNearestClusterIndex ( points [ i ] , centroids ) ;
11437
11538 // The nearest cluster hasn't changed. Do nothing
116- if ( oldId == nearestIndex )
39+ if ( clusterIds [ i ] == nearestIndex )
11740 continue ;
11841
11942 // Update the cluster id and note that we have not converged
@@ -125,6 +48,105 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
12548 return centroids . ToArray ( ) ;
12649 }
12750
51+ /// <summary>
52+ /// Assigns arbitrary clusterIds for each point
53+ /// </summary>
54+ private static void Split ( int k , int [ ] clusterIds )
55+ {
56+ // Mathematically true random sampling
57+ #if NET6_0_OR_GREATER
58+ var offset = Random . Shared . Next ( k ) ;
59+ #else
60+ var rand = new Random ( ) ;
61+ var offset = rand . Next ( k ) ;
62+ #endif
63+
64+ // Assign each clusters id
65+ for ( int i = 0 ; i < clusterIds . Length ; i ++ )
66+ clusterIds [ i ] = ( i + offset ) % k ;
67+ }
68+
69+ /// <summary>
70+ /// Calculates the centroid of each cluster, and prunes empty clusters.
71+ /// </summary>
72+ private static void CalculateCentroidsAndPrune ( ref Span < Vector3 > centroids , ref int [ ] counts , Span < Vector3 > points , int [ ] clusterIds )
73+ {
74+ // Clear centroids and counts before recalculation
75+ for ( int i = 0 ; i < centroids . Length ; i ++ )
76+ {
77+ centroids [ i ] = Vector3 . Zero ;
78+ counts [ i ] = 0 ;
79+ }
80+
81+ // Accumulate step in centroid calculation
82+ for ( int i = 0 ; i < clusterIds . Length ; i ++ )
83+ {
84+ int id = clusterIds [ i ] ;
85+ centroids [ id ] += points [ i ] ;
86+ counts [ id ] ++ ;
87+ }
88+
89+ // Prune empty clusters
90+ // All empty clusters are swapped to the end of the span
91+ // then a slice is taken with only the remaining populated clusters
92+ int pivot = counts . Length ;
93+ for ( int i = 0 ; i < pivot ; )
94+ {
95+ // Increment and continue if populated
96+ if ( counts [ i ] != 0 )
97+ {
98+ i ++ ;
99+ continue ;
100+ }
101+
102+ // The item is not populated. Swap to end and move pivot
103+ // NOTE: This is a one-way "swap". We're discarding the 0s anyways.
104+ pivot -- ;
105+ centroids [ i ] = centroids [ pivot ] ;
106+ counts [ i ] = counts [ pivot ] ;
107+ }
108+
109+ // Perform slice
110+ #if ! WINDOWS_UWP
111+ counts = counts [ ..pivot ] ;
112+ centroids = centroids [ ..pivot ] ;
113+ #elif WINDOWS_UWP
114+ Array . Resize ( ref counts , pivot ) ;
115+ centroids = centroids . Slice ( 0 , pivot ) ;
116+ #endif
117+
118+ // Division step in centroid calculation
119+ for ( int i = 0 ; i < centroids . Length ; i ++ )
120+ centroids [ i ] /= counts [ i ] ;
121+ }
122+
123+ /// <summary>
124+ /// Finds the index of the centroid nearest the point
125+ /// </summary>
126+ private static int FindNearestClusterIndex ( Vector3 point , Span < Vector3 > centroids )
127+ {
128+ // Track the nearest centroid's distance and the index of that centroid
129+ float nearestDistance = float . PositiveInfinity ;
130+ int nearestIndex = - 1 ;
131+
132+ for ( int j = 0 ; j < centroids . Length ; j ++ )
133+ {
134+ // Compare the point to the jth centroid
135+ float distance = Vector3 . DistanceSquared ( point , centroids [ j ] ) ;
136+
137+ // Skip the cluster if further than the nearest seen cluster
138+ if ( nearestDistance < distance )
139+ continue ;
140+
141+ // This is the nearest cluster
142+ // Update the distance and index
143+ nearestDistance = distance ;
144+ nearestIndex = j ;
145+ }
146+
147+ return nearestIndex ;
148+ }
149+
128150 private static float FindColorfulness ( Vector3 color )
129151 {
130152 var rg = color . X - color . Y ;
0 commit comments