@@ -3,6 +3,7 @@ import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
33import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart' ;
44import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart' ;
55import 'package:ml_dataframe/ml_dataframe.dart' ;
6+ import 'package:ml_linalg/distance.dart' ;
67import 'package:ml_linalg/dtype.dart' ;
78import 'package:ml_linalg/vector.dart' ;
89import 'package:test/test.dart' ;
@@ -61,10 +62,10 @@ void main() {
6162 });
6263
6364 test (
64- 'should find the closest neighbours for [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=3, splitStrategy=KDTreeSplitStrategy.inOrder ' ,
65+ 'should find the closest neighbours for [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=3, splitStrategy=KDTreeSplitStrategy.largestVariance ' ,
6566 () {
6667 final kdTree = KDTree (DataFrame (data, headerExists: false ),
67- leafSize: 3 , splitStrategy: KDTreeSplitStrategy .inOrder );
68+ leafSize: 3 , splitStrategy: KDTreeSplitStrategy .largestVariance );
6869 final sample = Vector .fromList ([2.79 , - 9.15 , 6.56 , - 18.59 , 13.53 ]);
6970 final result = kdTree.query (sample, 3 ).toList ();
7071
@@ -84,7 +85,7 @@ void main() {
8485 expect (result[0 ].index, 4 );
8586 expect (result[1 ].index, 12 );
8687 expect (result[2 ].index, 3 );
87- expect (result[3 ].index, 1 );
88+ expect (result[3 ].index, 18 );
8889 expect (result, hasLength (4 ));
8990 });
9091
@@ -98,7 +99,7 @@ void main() {
9899 expect (result[0 ].index, 4 );
99100 expect (result[1 ].index, 12 );
100101 expect (result[2 ].index, 3 );
101- expect (result[3 ].index, 1 );
102+ expect (result[3 ].index, 18 );
102103 expect (result, hasLength (4 ));
103104 }, skip: false );
104105
@@ -112,8 +113,24 @@ void main() {
112113 expect (result[0 ].index, 19 );
113114 expect (result[1 ].index, 11 );
114115 expect (result[2 ].index, 6 );
115- expect (result[4 ].index, 9 );
116- expect (result[3 ].index, 18 );
116+ expect (result[3 ].index, 9 );
117+ expect (result[4 ].index, 18 );
118+ expect (result[5 ].index, 2 );
119+ expect (result[6 ].index, 14 );
120+ expect (result[7 ].index, 10 );
121+ expect (result[8 ].index, 15 );
122+ expect (result[9 ].index, 5 );
123+ expect (result[10 ].index, 7 );
124+ expect (result[10 ].index, 7 );
125+ expect (result[11 ].index, 13 );
126+ expect (result[12 ].index, 3 );
127+ expect (result[13 ].index, 12 );
128+ expect (result[14 ].index, 17 );
129+ expect (result[15 ].index, 1 );
130+ expect (result[16 ].index, 0 );
131+ expect (result[17 ].index, 16 );
132+ expect (result[18 ].index, 4 );
133+ expect (result[19 ].index, 8 );
117134 expect (result, hasLength (20 ));
118135 });
119136
@@ -143,7 +160,7 @@ void main() {
143160
144161 kdTree.query (sample, 1 ).toList ();
145162
146- expect ((kdTree as KDTreeImpl ).searchIterationCount, 14 );
163+ expect ((kdTree as KDTreeImpl ).searchIterationCount, 4 );
147164 });
148165
149166 test (
@@ -154,7 +171,68 @@ void main() {
154171
155172 kdTree.query (sample, 1 ).toList ();
156173
157- expect ((kdTree as KDTreeImpl ).searchIterationCount, 7 );
174+ expect ((kdTree as KDTreeImpl ).searchIterationCount, 6 );
175+ });
176+
177+ test (
178+ 'should find the closest neighbours for [12, 23, 22, 11, -20], k=1, leafSize=1, cosine distance' ,
179+ () {
180+ final kdTree = KDTree (DataFrame (data, headerExists: false ), leafSize: 1 );
181+ final sample = Vector .fromList ([12 , 23 , 22 , 11 , - 20 ]);
182+ final result = kdTree.query (sample, 1 , Distance .cosine).toList ();
183+
184+ expect (result, hasLength (1 ));
185+ expect (result[0 ].index, 17 );
186+ });
187+
188+ test (
189+ 'should find the closest neighbours for [12, 23, 22, 11, -20], k=2, leafSize=1, cosine distance' ,
190+ () {
191+ final kdTree = KDTree (DataFrame (data, headerExists: false ), leafSize: 1 );
192+ final sample = Vector .fromList ([12 , 23 , 22 , 11 , - 20 ]);
193+ final result = kdTree.query (sample, 2 , Distance .cosine).toList ();
194+
195+ expect (result, hasLength (2 ));
196+ expect (result[0 ].index, 17 );
197+ expect (result[1 ].index, 8 );
198+ });
199+
200+ test (
201+ 'should find the closest neighbours for [12, 23, 22, 11, -20], k=3, leafSize=1, cosine distance' ,
202+ () {
203+ final kdTree = KDTree (DataFrame (data, headerExists: false ), leafSize: 1 );
204+ final sample = Vector .fromList ([12 , 23 , 22 , 11 , - 20 ]);
205+ final result = kdTree.query (sample, 3 , Distance .cosine).toList ();
206+
207+ expect (result, hasLength (3 ));
208+ expect (result[0 ].index, 17 );
209+ expect (result[1 ].index, 8 );
210+ expect (result[2 ].index, 0 );
211+ });
212+
213+ test (
214+ 'should find the closest neighbours for [12, 23, 22, 11, -20], k=3, leafSize=3, cosine distance' ,
215+ () {
216+ final kdTree = KDTree (DataFrame (data, headerExists: false ), leafSize: 3 );
217+ final sample = Vector .fromList ([12 , 23 , 22 , 11 , - 20 ]);
218+ final result = kdTree.query (sample, 3 , Distance .cosine).toList ();
219+
220+ expect (result, hasLength (3 ));
221+ expect (result[0 ].index, 17 );
222+ expect (result[1 ].index, 8 );
223+ expect (result[2 ].index, 0 );
224+ });
225+
226+ test (
227+ 'should find the closest neighbours for [12, 23, 22, 11, -20], k=3, leafSize=3, cosine distance for conceivable amount of iterations' ,
228+ () {
229+ final kdTree = KDTree (DataFrame (data, headerExists: false ),
230+ splitStrategy: KDTreeSplitStrategy .largestVariance, leafSize: 1 );
231+ final sample = Vector .fromList ([12 , 23 , 22 , 11 , - 20 ]);
232+
233+ kdTree.query (sample, 3 , Distance .cosine).toList ();
234+
235+ expect ((kdTree as KDTreeImpl ).searchIterationCount, 13 );
158236 });
159237
160238 test ('should throw an exception if the query point is of invalid length' ,
0 commit comments