Skip to content

Commit 476484c

Browse files
committed
Make knn test better
1 parent 398a07b commit 476484c

File tree

3 files changed

+92
-86
lines changed

3 files changed

+92
-86
lines changed

src/__tests__/ml/classifier.test.ts

Lines changed: 69 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -230,68 +230,73 @@ describe('Classifier tests', () => {
230230
expect(get(confidences).size).toBe(3);
231231
});
232232

233-
test('Classifier should correctly classify', async () => {
234-
const vectors = [
235-
new BaseVector([1, 2, 4], ['x', 'y', 'z']),
236-
new BaseVector([4, 8, 16], ['x', 'y', 'z']),
237-
new BaseVector([10, 20, 40], ['x', 'y', 'z']),
238-
];
239-
const classifierInput = new ClassifierInput(vectors);
240-
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
241-
const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN);
242-
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
243-
const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin]));
244-
245-
let iterations = 0;
246-
247-
const trainingData = new TestTrainingDataRepository().getTrainingData();
248-
const trainedModel = await new LayersModelTrainer(
249-
StaticConfiguration.defaultNeuralNetworkSettings,
250-
() => (iterations += 1),
251-
).trainModel(trainingData);
252-
const model = writable(trainedModel);
253-
254-
const gestureRepository = new TestGestureRepository();
255-
gestureRepository.addGesture({
256-
color: 'blue',
257-
ID: 1,
258-
name: 'test',
259-
output: {},
260-
recordings: [],
261-
});
262-
gestureRepository.addGesture({
263-
color: 'blue',
264-
ID: 2,
265-
name: 'test',
266-
output: {},
267-
recordings: [],
268-
});
269-
gestureRepository.addGesture({
270-
color: 'blue',
271-
ID: 3,
272-
name: 'test',
273-
output: {},
274-
recordings: [],
275-
});
276-
277-
const confidences = new Confidences();
278-
279-
const classifier = new ClassifierFactory().buildClassifier(
280-
model,
281-
async () => void 0,
282-
filters,
283-
gestureRepository,
284-
(gestureId, confidence) => confidences.setConfidence(gestureId, confidence),
285-
new Snackbar(),
286-
);
287-
288-
// This is based on known correct results
289-
await classifier.classify(classifierInput)
290-
291-
expect(get(confidences).get(1)).toBeCloseTo(0);
292-
expect(get(confidences).get(2)).toBeCloseTo(0);
293-
expect(get(confidences).get(3)).toBeCloseTo(1);
294-
}, {
295-
repeats: 20, retry: 2
296-
});
233+
test(
234+
'Classifier should correctly classify',
235+
async () => {
236+
const vectors = [
237+
new BaseVector([1, 2, 4], ['x', 'y', 'z']),
238+
new BaseVector([4, 8, 16], ['x', 'y', 'z']),
239+
new BaseVector([10, 20, 40], ['x', 'y', 'z']),
240+
];
241+
const classifierInput = new ClassifierInput(vectors);
242+
const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX);
243+
const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN);
244+
const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN);
245+
const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin]));
246+
247+
let iterations = 0;
248+
249+
const trainingData = new TestTrainingDataRepository().getTrainingData();
250+
const trainedModel = await new LayersModelTrainer(
251+
StaticConfiguration.defaultNeuralNetworkSettings,
252+
() => (iterations += 1),
253+
).trainModel(trainingData);
254+
const model = writable(trainedModel);
255+
256+
const gestureRepository = new TestGestureRepository();
257+
gestureRepository.addGesture({
258+
color: 'blue',
259+
ID: 1,
260+
name: 'test',
261+
output: {},
262+
recordings: [],
263+
});
264+
gestureRepository.addGesture({
265+
color: 'blue',
266+
ID: 2,
267+
name: 'test',
268+
output: {},
269+
recordings: [],
270+
});
271+
gestureRepository.addGesture({
272+
color: 'blue',
273+
ID: 3,
274+
name: 'test',
275+
output: {},
276+
recordings: [],
277+
});
278+
279+
const confidences = new Confidences();
280+
281+
const classifier = new ClassifierFactory().buildClassifier(
282+
model,
283+
async () => void 0,
284+
filters,
285+
gestureRepository,
286+
(gestureId, confidence) => confidences.setConfidence(gestureId, confidence),
287+
new Snackbar(),
288+
);
289+
290+
// This is based on known correct results
291+
await classifier.classify(classifierInput);
292+
293+
expect(get(confidences).get(1)).toBeCloseTo(0);
294+
expect(get(confidences).get(2)).toBeCloseTo(0);
295+
expect(get(confidences).get(3)).toBeCloseTo(1);
296+
},
297+
{
298+
repeats: 20,
299+
retry: 2,
300+
},
301+
);
297302
});

src/__tests__/ml/model.test.ts

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,30 @@ import KNNModelTrainer from '../../script/mlmodels/KNNModelTrainer';
1313
import KNNNonNormalizedModelTrainer from '../../script/mlmodels/KNNNonNormalizedModelTrainer';
1414

1515
describe('ML Model tests', async () => {
16-
describe("Layers Model", async () => {
17-
test('Model should train the expected number of times', async () => {
18-
let iterations = 0;
16+
describe('Layers Model', async () => {
17+
test('Model should train the expected number of times', async () => {
18+
let iterations = 0;
1919

20-
const trainingData = new TestTrainingDataRepository().getTrainingData();
21-
await new LayersModelTrainer(
22-
StaticConfiguration.defaultNeuralNetworkSettings,
23-
() => (iterations += 1),
24-
).trainModel(trainingData);
20+
const trainingData = new TestTrainingDataRepository().getTrainingData();
21+
await new LayersModelTrainer(
22+
StaticConfiguration.defaultNeuralNetworkSettings,
23+
() => (iterations += 1),
24+
).trainModel(trainingData);
2525

26-
expect(iterations).toBe(StaticConfiguration.defaultNeuralNetworkSettings.noOfEpochs);
27-
});
28-
29-
30-
})
31-
describe("KNN-non normalized Model", async () => {
32-
test('Model should train the expected number of times', async () => {
33-
let iterations = 0;
34-
35-
const trainingData = new TestTrainingDataRepository().getTrainingData();
36-
const knnModel = await new KNNNonNormalizedModelTrainer(2).trainModel(trainingData);
26+
expect(iterations).toBe(
27+
StaticConfiguration.defaultNeuralNetworkSettings.noOfEpochs,
28+
);
29+
});
30+
});
31+
describe('KNN-non normalized Model', async () => {
32+
test('Model should train the expected number of times', async () => {
33+
const trainingData = new TestTrainingDataRepository().getTrainingData();
34+
const knnModel = await new KNNNonNormalizedModelTrainer(2).trainModel(trainingData);
3735

38-
expect(async () => await knnModel.predict([0, 0, 0, 0, 0, 0, 0, 0, 0])).not.throws()
39-
});
36+
const prediction1 = await knnModel.predict([0, 0, 0, 0, 0, 0, 0, 0, 0]);
37+
expect(prediction1).toStrictEqual([0, 1, 0]);
38+
const prediction2 = await knnModel.predict([1, 1, 0, 0, 0, -2, 0, -3, 0]);
39+
expect(prediction2).toStrictEqual([0.5, 0, 0.5]);
4040
});
41+
});
4142
});

src/script/mlmodels/KNNMLModel.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class KNNMLModel implements MLModel {
1414
) {}
1515
public async predict(filteredData: number[]): Promise<number[]> {
1616
const inputTensor = tensor(filteredData);
17-
console.warn(filteredData)
17+
console.warn(filteredData);
1818

1919
try {
2020
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call

0 commit comments

Comments
 (0)