Skip to content

Commit 398a07b

Browse files
committed
Add classifier and model test
1 parent 0e3b8c9 commit 398a07b

File tree

10 files changed

+560
-14
lines changed

10 files changed

+560
-14
lines changed

src/__tests__/ml/classifier.test.ts

Lines changed: 191 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,22 @@
77
* SPDX-License-Identifier: MIT
88
*/
99

10-
import { writable } from 'svelte/store';
10+
import { get, writable } from 'svelte/store';
1111
import BaseVector from '../../script/domain/BaseVector';
1212
import { ClassifierInput } from '../../script/domain/ClassifierInput';
1313
import Filters from '../../script/domain/Filters';
1414
import { stores } from '../../script/stores/Stores';
1515
import TestMLModelTrainer from '../mocks/mlmodel/TestMLModelTrainer';
1616
import type { Filter } from '../../script/domain/Filter';
1717
import FilterTypes, { FilterType } from '../../script/domain/FilterTypes';
18+
import ClassifierFactory from '../../script/domain/ClassifierFactory';
19+
import LayersModelTrainer from '../../script/mlmodels/LayersModelTrainer';
20+
import StaticConfiguration from '../../StaticConfiguration';
21+
import TestTrainingDataRepository from '../mocks/TestTrainingDataRepository';
22+
import TestGestureRepository from '../mocks/TestGestureRepository';
23+
import Confidences from '../../script/domain/stores/Confidences';
24+
import Snackbar from '../../components/snackbar/Snackbar';
25+
import { repeat } from '../testUtils';
1826

1927
describe('Classifier tests', () => {
2028
test('Changing matrix does not mark model as untrained', async () => {
@@ -104,4 +112,186 @@ describe('Classifier tests', () => {
104112
20, 2,
105113
]);
106114
});
115+
116+
test('Classifying Should Not Throw', async () => {
117+
const vectors = [
118+
new BaseVector([1, 2, 4], ['x', 'y', 'z']),
119+
new BaseVector([4, 8, 16], ['x', 'y', 'z']),
120+
new BaseVector([10, 20, 40], ['x', 'y', 'z']),
121+
];
122+
const classifierInput = new ClassifierInput(vectors);
123+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
124+
const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN);
125+
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
126+
const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin]));
127+
128+
let iterations = 0;
129+
130+
const trainingData = new TestTrainingDataRepository().getTrainingData();
131+
const trainedModel = await new LayersModelTrainer(
132+
StaticConfiguration.defaultNeuralNetworkSettings,
133+
() => (iterations += 1),
134+
).trainModel(trainingData);
135+
const model = writable(trainedModel);
136+
137+
const gestureRepository = new TestGestureRepository();
138+
gestureRepository.addGesture({
139+
color: 'blue',
140+
ID: 1,
141+
name: 'test',
142+
output: {},
143+
recordings: [],
144+
});
145+
gestureRepository.addGesture({
146+
color: 'blue',
147+
ID: 2,
148+
name: 'test',
149+
output: {},
150+
recordings: [],
151+
});
152+
gestureRepository.addGesture({
153+
color: 'blue',
154+
ID: 3,
155+
name: 'test',
156+
output: {},
157+
recordings: [],
158+
});
159+
160+
const confidences = new Confidences();
161+
const classifier = new ClassifierFactory().buildClassifier(
162+
model,
163+
async () => void 0,
164+
filters,
165+
gestureRepository,
166+
(gestureId, confidence) => confidences.setConfidence(gestureId, confidence),
167+
new Snackbar(),
168+
);
169+
170+
expect(async () => await classifier.classify(classifierInput)).not.throws();
171+
});
172+
173+
test('Classifier should set confidence', async () => {
174+
const vectors = [
175+
new BaseVector([1, 2, 4], ['x', 'y', 'z']),
176+
new BaseVector([4, 8, 16], ['x', 'y', 'z']),
177+
new BaseVector([10, 20, 40], ['x', 'y', 'z']),
178+
];
179+
const classifierInput = new ClassifierInput(vectors);
180+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
181+
const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN);
182+
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
183+
const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin]));
184+
185+
let iterations = 0;
186+
187+
const trainingData = new TestTrainingDataRepository().getTrainingData();
188+
const trainedModel = await new LayersModelTrainer(
189+
StaticConfiguration.defaultNeuralNetworkSettings,
190+
() => (iterations += 1),
191+
).trainModel(trainingData);
192+
const model = writable(trainedModel);
193+
194+
const gestureRepository = new TestGestureRepository();
195+
gestureRepository.addGesture({
196+
color: 'blue',
197+
ID: 1,
198+
name: 'test',
199+
output: {},
200+
recordings: [],
201+
});
202+
gestureRepository.addGesture({
203+
color: 'blue',
204+
ID: 2,
205+
name: 'test',
206+
output: {},
207+
recordings: [],
208+
});
209+
gestureRepository.addGesture({
210+
color: 'blue',
211+
ID: 3,
212+
name: 'test',
213+
output: {},
214+
recordings: [],
215+
});
216+
217+
const confidences = new Confidences();
218+
219+
const classifier = new ClassifierFactory().buildClassifier(
220+
model,
221+
async () => void 0,
222+
filters,
223+
gestureRepository,
224+
(gestureId, confidence) => confidences.setConfidence(gestureId, confidence),
225+
new Snackbar(),
226+
);
227+
228+
await classifier.classify(classifierInput);
229+
230+
expect(get(confidences).size).toBe(3);
231+
});
232+
233+
test('Classifier should correctly classify', async () => {
234+
const vectors = [
235+
new BaseVector([1, 2, 4], ['x', 'y', 'z']),
236+
new BaseVector([4, 8, 16], ['x', 'y', 'z']),
237+
new BaseVector([10, 20, 40], ['x', 'y', 'z']),
238+
];
239+
const classifierInput = new ClassifierInput(vectors);
240+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
241+
const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN);
242+
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
243+
const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin]));
244+
245+
let iterations = 0;
246+
247+
const trainingData = new TestTrainingDataRepository().getTrainingData();
248+
const trainedModel = await new LayersModelTrainer(
249+
StaticConfiguration.defaultNeuralNetworkSettings,
250+
() => (iterations += 1),
251+
).trainModel(trainingData);
252+
const model = writable(trainedModel);
253+
254+
const gestureRepository = new TestGestureRepository();
255+
gestureRepository.addGesture({
256+
color: 'blue',
257+
ID: 1,
258+
name: 'test',
259+
output: {},
260+
recordings: [],
261+
});
262+
gestureRepository.addGesture({
263+
color: 'blue',
264+
ID: 2,
265+
name: 'test',
266+
output: {},
267+
recordings: [],
268+
});
269+
gestureRepository.addGesture({
270+
color: 'blue',
271+
ID: 3,
272+
name: 'test',
273+
output: {},
274+
recordings: [],
275+
});
276+
277+
const confidences = new Confidences();
278+
279+
const classifier = new ClassifierFactory().buildClassifier(
280+
model,
281+
async () => void 0,
282+
filters,
283+
gestureRepository,
284+
(gestureId, confidence) => confidences.setConfidence(gestureId, confidence),
285+
new Snackbar(),
286+
);
287+
288+
// This is based on known correct results
289+
await classifier.classify(classifierInput)
290+
291+
expect(get(confidences).get(1)).toBeCloseTo(0);
292+
expect(get(confidences).get(2)).toBeCloseTo(0);
293+
expect(get(confidences).get(3)).toBeCloseTo(1);
294+
}, {
295+
repeats: 20, retry: 2
296+
});
107297
});

src/__tests__/ml/model.test.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/**
2+
* @vitest-environment jsdom
3+
*/
4+
/**
5+
* (c) 2023, Center for Computational Thinking and Design at Aarhus University and contributors
6+
*
7+
* SPDX-License-Identifier: MIT
8+
*/
9+
import TestTrainingDataRepository from '../mocks/TestTrainingDataRepository';
10+
import LayersModelTrainer from '../../script/mlmodels/LayersModelTrainer';
11+
import StaticConfiguration from '../../StaticConfiguration';
12+
import KNNModelTrainer from '../../script/mlmodels/KNNModelTrainer';
13+
import KNNNonNormalizedModelTrainer from '../../script/mlmodels/KNNNonNormalizedModelTrainer';
14+
15+
describe('ML Model tests', async () => {
16+
describe("Layers Model", async () => {
17+
test('Model should train the expected number of times', async () => {
18+
let iterations = 0;
19+
20+
const trainingData = new TestTrainingDataRepository().getTrainingData();
21+
await new LayersModelTrainer(
22+
StaticConfiguration.defaultNeuralNetworkSettings,
23+
() => (iterations += 1),
24+
).trainModel(trainingData);
25+
26+
expect(iterations).toBe(StaticConfiguration.defaultNeuralNetworkSettings.noOfEpochs);
27+
});
28+
29+
30+
})
31+
describe("KNN-non normalized Model", async () => {
32+
test('Model should train the expected number of times', async () => {
33+
let iterations = 0;
34+
35+
const trainingData = new TestTrainingDataRepository().getTrainingData();
36+
const knnModel = await new KNNNonNormalizedModelTrainer(2).trainModel(trainingData);
37+
38+
expect(async () => await knnModel.predict([0, 0, 0, 0, 0, 0, 0, 0, 0])).not.throws()
39+
});
40+
});
41+
});
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/**
2+
* (c) 2023, Center for Computational Thinking and Design at Aarhus University and contributors
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*/
6+
7+
import {
8+
type Subscriber,
9+
type Invalidator,
10+
type Unsubscriber,
11+
writable,
12+
get,
13+
} from 'svelte/store';
14+
import type { GestureRepository } from '../../script/domain/GestureRepository';
15+
import Gesture from '../../script/domain/stores/gesture/Gesture';
16+
import type { PersistantGestureData } from '../../script/domain/stores/gesture/Gestures';
17+
import GestureConfidence from '../../script/domain/stores/gesture/GestureConfidence';
18+
19+
class TestGestureRepository implements GestureRepository {
20+
private gestures = writable<Gesture[]>([]);
21+
22+
getGesture(gestureId: number): Gesture {
23+
const foundGesture = get(this.gestures).find(g => g.getId() === gestureId);
24+
if (!foundGesture) {
25+
throw new Error('Could not find gesture with id ' + gestureId);
26+
}
27+
return foundGesture;
28+
}
29+
30+
clearGestures(): void {
31+
this.gestures.set([]);
32+
}
33+
34+
addGesture(gestureData: PersistantGestureData): Gesture {
35+
const gesture = new Gesture(
36+
writable(gestureData),
37+
new GestureConfidence(0.5, writable(0)),
38+
() => void 0,
39+
);
40+
this.gestures.update(s => {
41+
return [...s, gesture];
42+
});
43+
return gesture;
44+
}
45+
46+
removeGesture(gestureId: number): void {
47+
this.gestures.update(s => {
48+
return s.filter(g => g.getId() !== gestureId);
49+
});
50+
}
51+
52+
subscribe(
53+
run: Subscriber<Gesture[]>,
54+
invalidate?: Invalidator<Gesture[]> | undefined,
55+
): Unsubscriber {
56+
return this.gestures.subscribe(run, invalidate);
57+
}
58+
}
59+
60+
export default TestGestureRepository;

0 commit comments

Comments
 (0)