Skip to content

Commit a977b35

Browse files
authored
Merge pull request #572 from microbit-foundation/571-add-classifierinput-test
Add test for classifier input
2 parents 19dac1f + 30a5c9a commit a977b35

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

src/__tests__/ml/classifier.test.ts

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
* SPDX-License-Identifier: MIT
88
*/
99

10+
import { writable } from 'svelte/store';
11+
import BaseVector from '../../script/domain/BaseVector';
12+
import { ClassifierInput } from '../../script/domain/ClassifierInput';
13+
import Filters from '../../script/domain/Filters';
1014
import { stores } from '../../script/stores/Stores';
1115
import TestMLModelTrainer from '../mocks/mlmodel/TestMLModelTrainer';
16+
import type { Filter } from '../../script/domain/Filter';
17+
import FilterTypes, { FilterType } from '../../script/domain/FilterTypes';
1218

1319
describe('Classifier tests', () => {
1420
test('Changing matrix does not mark model as untrained', async () => {
@@ -40,4 +46,62 @@ describe('Classifier tests', () => {
4046

4147
expect(stores.getClassifier().getModel().isTrained()).toBe(false);
4248
});
49+
50+
test('Classifier input should be correct size', () => {
51+
const vectors = [
52+
new BaseVector([1, 1, 1], ['x', 'y', 'z']),
53+
new BaseVector([2, 2, 2], ['x', 'y', 'z']),
54+
new BaseVector([3, 3, 3], ['x', 'y', 'z']),
55+
];
56+
const input = new ClassifierInput(vectors);
57+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
58+
const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN);
59+
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
60+
const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin]));
61+
expect(input.getInput(filters).length).toBe(3 * 3);
62+
});
63+
64+
test('Max Filter should return max of two vectors', () => {
65+
const vectors = [
66+
new BaseVector([1, 2, 3], ['x', 'y', 'z']),
67+
new BaseVector([4, 5, 6], ['x', 'y', 'z']),
68+
];
69+
const input = new ClassifierInput(vectors);
70+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
71+
const filters: Filters = new Filters(writable([filterMax]));
72+
expect(input.getInput(filters)).toStrictEqual([4, 5, 6]);
73+
});
74+
75+
test('Filters should correctly consider all vectors 1d', () => {
76+
const vectors = [
77+
new BaseVector([1], ['x']),
78+
new BaseVector([4], ['x']),
79+
new BaseVector([10], ['x']),
80+
];
81+
const input = new ClassifierInput(vectors);
82+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
83+
const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN);
84+
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
85+
const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin]));
86+
expect(input.getInput(filters).length).toBe(3);
87+
expect(input.getInput(filters)).toStrictEqual([10, 5, 1]);
88+
});
89+
90+
test('Filters should correctly consider all vectors 2d', () => {
91+
const vectors = [
92+
new BaseVector([1, 2], ['x', 'y']),
93+
new BaseVector([4, 8], ['x', 'y']),
94+
new BaseVector([10, 20], ['x', 'y']),
95+
];
96+
const input = new ClassifierInput(vectors);
97+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
98+
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
99+
const filters: Filters = new Filters(writable([filterMax, filterMin]));
100+
expect(input.getInput(filters)).toStrictEqual([
101+
// x value max/min
102+
1, 10,
103+
// y value max/min
104+
20, 2,
105+
]);
106+
});
43107
});

src/__tests__/testUtils.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
*
44
* SPDX-License-Identifier: MIT
55
*/
6+
7+
import exampleDataset from '../exampleDataset.json';
68
export const repeat = (func: (a?: any) => any, n: number) => {
79
for (let i = 0; i < n; i++) {
810
func();
911
}
1012
};
13+
14+
export const generateRecordings = () => {
15+
return exampleDataset[0].recordings;
16+
};

0 commit comments

Comments
 (0)