@@ -61,9 +61,11 @@ class KDTreeImpl with SerializableMixin implements KDTree {
6161 return neighbours.toList ().reversed;
6262 }
6363
64- void _findKNNRecursively (KDTreeNode node, Vector point, int k,
64+ void _findKNNRecursively (KDTreeNode ? node, Vector point, int k,
6565 HeapPriorityQueue <KDTreeNeighbour > neighbours) {
66- searchIterationCount++ ;
66+ if (node == null ) {
67+ return ;
68+ }
6769
6870 if (node.isLeaf) {
6971 _knnSearch (point, node.pointIndices, neighbours, k);
@@ -72,31 +74,30 @@ class KDTreeImpl with SerializableMixin implements KDTree {
7274 }
7375
7476 final nodePoint = points[node.pointIndices[0 ]];
75- final boundaryPoint = _getBoundary (point, node);
7677 final isNodeTooFar = neighbours.length > 0 &&
77- boundaryPoint.distanceTo (nodePoint) > neighbours.first.distance;
78+ (point[node.splitIndex] - nodePoint[node.splitIndex]).abs () >
79+ neighbours.first.distance;
7880 final isQueueFilled = neighbours.length == k;
7981
8082 if (isQueueFilled && isNodeTooFar) {
8183 return ;
8284 }
8385
84- if (node.left == null ) {
85- _findKNNRecursively (node.right! , point, k, neighbours);
86- } else if (node.right == null ) {
87- _findKNNRecursively (node.left! , point, k, neighbours);
88- } else if (point[node.splitIndex! ] < nodePoint[node.splitIndex! ]) {
89- _findKNNRecursively (node.left! , point, k, neighbours);
90- _findKNNRecursively (node.right! , point, k, neighbours);
86+ _knnSearch (point, node.pointIndices, neighbours, k);
87+
88+ if (point[node.splitIndex] < nodePoint[node.splitIndex]) {
89+ _findKNNRecursively (node.left, point, k, neighbours);
90+ _findKNNRecursively (node.right, point, k, neighbours);
9191 } else {
92- _findKNNRecursively (node.right! , point, k, neighbours);
93- _findKNNRecursively (node.left! , point, k, neighbours);
92+ _findKNNRecursively (node.right, point, k, neighbours);
93+ _findKNNRecursively (node.left, point, k, neighbours);
9494 }
9595 }
9696
9797 void _knnSearch (Vector point, List <int > pointIndices,
9898 HeapPriorityQueue <KDTreeNeighbour > neighbours, int k) {
9999 pointIndices.forEach ((candidateIdx) {
100+ searchIterationCount++ ;
100101 final candidate = points[candidateIdx];
101102 final candidateDistance = candidate.distanceTo (point);
102103 final lastNeighbourDistance =
@@ -113,13 +114,4 @@ class KDTreeImpl with SerializableMixin implements KDTree {
113114 }
114115 });
115116 }
116-
117- Vector _getBoundary (Vector point, KDTreeNode node) {
118- final nodePoint = points[node.pointIndices[0 ]];
119- final boundarySrc = [...nodePoint];
120-
121- boundarySrc[node.splitIndex! ] = point[node.splitIndex! ];
122-
123- return Vector .fromList (boundarySrc, dtype: point.dtype);
124- }
125117}
0 commit comments