@@ -14,21 +14,11 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
1414        int [ ]  clusterIds  =  new  int [ points . Length ] ; 
1515
1616        // Track the centroids of each cluster and its member count 
17-         // TODO: stackalloc is great here, but pooling should be thresholded 
18-         // just in case 
1917        Span < Vector3 >  centroids  =  stackalloc  Vector3 [ k ] ; 
2018        counts  =  new  int [ k ] ; 
21- 
19+          
2220        // Split the points into arbitrary clusters 
23-         // NOTE: Can this be rearranged to converge faster? 
24-         #if NET6_0_OR_GREATER 
25-         var  offset  =  Random . Shared . Next ( k ) ;  // Mathematically true random sampling 
26-         #else
27-         var  rand  =  new  Random ( ) ; 
28-         var  offset  =  rand . Next ( k ) ; 
29-         #endif
30-         for  ( int  i  =  0 ;  i  <  clusterIds . Length ;  i ++ ) 
31-             clusterIds [ i ]  =  ( i  +  offset )  %  k ; 
21+         Split ( k ,  clusterIds ) ; 
3222
3323        bool  converged  =  false ; 
3424        while  ( ! converged ) 
@@ -37,83 +27,16 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
3727            // to false when adjust the clusters 
3828            converged  =  true ; 
3929
40-             // KMeans Loop Step 1: 
41-             // Calculate/Recalculate the centroids of each cluster 
42- 
43-             // Clear centroids and counts before recalculation 
44-             for ( int  i  =  0 ;  i  <  centroids . Length ;  i ++ ) 
45-             { 
46-                 centroids [ i ]  =  Vector3 . Zero ; 
47-                 counts [ i ]  =  0 ; 
48-             } 
49- 
50-             // Accumulate step in centroid calculation 
51-             for ( int  i  =  0 ;  i  <  clusterIds . Length ;  i ++ ) 
52-             { 
53-                 int  id  =  clusterIds [ i ] ; 
54-                 centroids [ id ]  +=  points [ i ] ; 
55-                 counts [ id ] ++ ; 
56-             } 
30+             // Calculate/Recalculate centroids 
31+             CalculateCentroidsAndPrune ( ref  centroids ,  ref  counts ,  points ,  clusterIds ) ; 
5732
58-             // Prune empty clusters 
59-             // All empty clusters are swapped to the end of the span 
60-             // then a slice is taken with only the remaining populated clusters 
61-             int  pivot  =  counts . Length ; 
62-             for  ( int  i  =  0 ;  i  <  pivot ; ) 
63-             { 
64-                 // Increment and continue if populated 
65-                 if  ( counts [ i ]  !=  0 ) 
66-                 { 
67-                     i ++ ; 
68-                     continue ; 
69-                 } 
70- 
71-                 // The item is not populated. Swap to end and move pivot 
72-                 // NOTE: This is a oneway swap. We're discarding the 0 anyways. 
73-                 pivot -- ; 
74-                 counts [ i ]  =  counts [ pivot ] ; 
75-             } 
76-             
77-             #if ! WINDOWS_UWP 
78-             counts  =  counts [ ..pivot ] ; 
79-             centroids  =  centroids [ ..pivot ] ; 
80-             #elif WINDOWS_UWP 
81-             Array . Resize ( ref  counts ,  pivot ) ; 
82-             centroids  =  centroids . Slice ( 0 ,  pivot ) ; 
83-             #endif
84- 
85-             // Division step in centroid calculation 
86-             for  ( int  i  =  0 ;  i  <  centroids . Length ;  i ++ ) 
87-                 centroids [ i ]  /=  counts [ i ] ; 
88-             
89-             // KMeans Loop Step 2: 
9033            // Move each point's clusterId to the nearest cluster centroid 
9134            for  ( int  i  =  0 ;  i  <  points . Length ;  i ++ ) 
9235            { 
93-                 Vector3  point  =  points [ i ] ; 
94-                 var  oldId  =  clusterIds [ i ] ; 
95- 
96-                 // Track the nearest centroid's distance and the index of that centroid 
97-                 float  nearestDistance  =  float . PositiveInfinity ; 
98-                 int  nearestIndex  =  - 1 ; 
99- 
100-                 for  ( int  j  =  0 ;  j  <  centroids . Length ;  j ++ ) 
101-                 { 
102-                     // Compare the point to the jth centroid 
103-                     float  distance  =  Vector3 . DistanceSquared ( point ,  centroids [ j ] ) ; 
104- 
105-                     // Skip the cluster if further than the nearest seen cluster  
106-                     if  ( nearestDistance  <  distance ) 
107-                         continue ; 
108- 
109-                     // This is the nearest cluster 
110-                     // Update the distance and index 
111-                     nearestDistance  =  distance ; 
112-                     nearestIndex  =  j ; 
113-                 } 
36+                 var  nearestIndex  =  FindNearestClusterIndex ( points [ i ] ,  centroids ) ; 
11437
11538                // The nearest cluster hasn't changed. Do nothing 
116-                 if  ( oldId  ==  nearestIndex ) 
39+                 if  ( clusterIds [ i ]  ==  nearestIndex ) 
11740                    continue ; 
11841
11942                // Update the cluster id and note that we have not converged 
@@ -125,6 +48,105 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
12548        return  centroids . ToArray ( ) ; 
12649    } 
12750
51+     /// <summary> 
52+     /// Assigns arbitrary clusterIds for each point 
53+     /// </summary> 
54+     private  static   void  Split ( int  k ,  int [ ]  clusterIds ) 
55+     { 
56+         // Mathematically true random sampling 
57+ #if NET6_0_OR_GREATER 
58+         var  offset  =  Random . Shared . Next ( k ) ;  
59+ #else
60+         var  rand  =  new  Random ( ) ; 
61+         var  offset  =  rand . Next ( k ) ; 
62+ #endif
63+ 
64+         // Assign each clusters id 
65+         for  ( int  i  =  0 ;  i  <  clusterIds . Length ;  i ++ ) 
66+             clusterIds [ i ]  =  ( i  +  offset )  %  k ; 
67+     } 
68+ 
69+     /// <summary> 
70+     /// Calculates the centroid of each cluster, and prunes empty clusters. 
71+     /// </summary> 
72+     private  static   void  CalculateCentroidsAndPrune ( ref  Span < Vector3 >  centroids ,  ref  int [ ]  counts ,  Span < Vector3 >  points ,  int [ ]  clusterIds ) 
73+     { 
74+         // Clear centroids and counts before recalculation 
75+         for  ( int  i  =  0 ;  i  <  centroids . Length ;  i ++ ) 
76+         { 
77+             centroids [ i ]  =  Vector3 . Zero ; 
78+             counts [ i ]  =  0 ; 
79+         } 
80+         
81+         // Accumulate step in centroid calculation 
82+         for  ( int  i  =  0 ;  i  <  clusterIds . Length ;  i ++ ) 
83+         { 
84+             int  id  =  clusterIds [ i ] ; 
85+             centroids [ id ]  +=  points [ i ] ; 
86+             counts [ id ] ++ ; 
87+         } 
88+         
89+         // Prune empty clusters 
90+         // All empty clusters are swapped to the end of the span 
91+         // then a slice is taken with only the remaining populated clusters 
92+         int  pivot  =  counts . Length ; 
93+         for  ( int  i  =  0 ;  i  <  pivot ; ) 
94+         { 
95+             // Increment and continue if populated 
96+             if  ( counts [ i ]  !=  0 ) 
97+             { 
98+                 i ++ ; 
99+                 continue ; 
100+             } 
101+         
102+             // The item is not populated. Swap to end and move pivot 
103+             // NOTE: This is a one-way "swap". We're discarding the 0s anyways. 
104+             pivot -- ; 
105+             centroids [ i ]  =  centroids [ pivot ] ; 
106+             counts [ i ]  =  counts [ pivot ] ; 
107+         } 
108+         
109+         // Perform slice 
110+ #if ! WINDOWS_UWP 
111+         counts  =  counts [ ..pivot ] ; 
112+         centroids  =  centroids [ ..pivot ] ; 
113+ #elif WINDOWS_UWP 
114+         Array . Resize ( ref  counts ,  pivot ) ; 
115+         centroids  =  centroids . Slice ( 0 ,  pivot ) ; 
116+ #endif
117+         
118+         // Division step in centroid calculation 
119+         for  ( int  i  =  0 ;  i  <  centroids . Length ;  i ++ ) 
120+             centroids [ i ]  /=  counts [ i ] ; 
121+     } 
122+ 
123+     /// <summary> 
124+     /// Finds the index of the centroid nearest the point 
125+     /// </summary> 
126+     private  static   int  FindNearestClusterIndex ( Vector3  point ,  Span < Vector3 >  centroids ) 
127+     { 
128+         // Track the nearest centroid's distance and the index of that centroid 
129+         float  nearestDistance  =  float . PositiveInfinity ; 
130+         int  nearestIndex  =  - 1 ; 
131+         
132+         for  ( int  j  =  0 ;  j  <  centroids . Length ;  j ++ ) 
133+         { 
134+             // Compare the point to the jth centroid 
135+             float  distance  =  Vector3 . DistanceSquared ( point ,  centroids [ j ] ) ; 
136+         
137+             // Skip the cluster if further than the nearest seen cluster  
138+             if  ( nearestDistance  <  distance ) 
139+                 continue ; 
140+         
141+             // This is the nearest cluster 
142+             // Update the distance and index 
143+             nearestDistance  =  distance ; 
144+             nearestIndex  =  j ; 
145+         } 
146+ 
147+         return  nearestIndex ; 
148+     } 
149+ 
128150    private  static   float  FindColorfulness ( Vector3  color ) 
129151    { 
130152        var  rg  =  color . X  -  color . Y ; 
0 commit comments