|
7 | 7 | * SPDX-License-Identifier: MIT |
8 | 8 | */ |
9 | 9 |
|
| 10 | +import { writable } from 'svelte/store'; |
| 11 | +import BaseVector from '../../script/domain/BaseVector'; |
| 12 | +import { ClassifierInput } from '../../script/domain/ClassifierInput'; |
| 13 | +import Filters from '../../script/domain/Filters'; |
10 | 14 | import { stores } from '../../script/stores/Stores'; |
11 | 15 | import TestMLModelTrainer from '../mocks/mlmodel/TestMLModelTrainer'; |
| 16 | +import type { Filter } from '../../script/domain/Filter'; |
| 17 | +import FilterTypes, { FilterType } from '../../script/domain/FilterTypes'; |
12 | 18 |
|
13 | 19 | describe('Classifier tests', () => { |
14 | 20 | test('Changing matrix does not mark model as untrained', async () => { |
@@ -40,4 +46,62 @@ describe('Classifier tests', () => { |
40 | 46 |
|
41 | 47 | expect(stores.getClassifier().getModel().isTrained()).toBe(false); |
42 | 48 | }); |
| 49 | + |
| 50 | + test('Classifier input should be correct size', () => { |
| 51 | + const vectors = [ |
| 52 | + new BaseVector([1, 1, 1], ['x', 'y', 'z']), |
| 53 | + new BaseVector([2, 2, 2], ['x', 'y', 'z']), |
| 54 | + new BaseVector([3, 3, 3], ['x', 'y', 'z']), |
| 55 | + ]; |
| 56 | + const input = new ClassifierInput(vectors); |
| 57 | + const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX); |
| 58 | + const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN); |
| 59 | + const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN); |
| 60 | + const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin])); |
| 61 | + expect(input.getInput(filters).length).toBe(3 * 3); |
| 62 | + }); |
| 63 | + |
| 64 | + test('Max Filter should return max of two vectors', () => { |
| 65 | + const vectors = [ |
| 66 | + new BaseVector([1, 2, 3], ['x', 'y', 'z']), |
| 67 | + new BaseVector([4, 5, 6], ['x', 'y', 'z']), |
| 68 | + ]; |
| 69 | + const input = new ClassifierInput(vectors); |
| 70 | + const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX); |
| 71 | + const filters: Filters = new Filters(writable([filterMax])); |
| 72 | + expect(input.getInput(filters)).toStrictEqual([4, 5, 6]); |
| 73 | + }); |
| 74 | + |
| 75 | + test('Filters should correctly consider all vectors 1d', () => { |
| 76 | + const vectors = [ |
| 77 | + new BaseVector([1], ['x']), |
| 78 | + new BaseVector([4], ['x']), |
| 79 | + new BaseVector([10], ['x']), |
| 80 | + ]; |
| 81 | + const input = new ClassifierInput(vectors); |
| 82 | + const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX); |
| 83 | + const filterMean: Filter = FilterTypes.createFilter(FilterType.MEAN); |
| 84 | + const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN); |
| 85 | + const filters: Filters = new Filters(writable([filterMax, filterMean, filterMin])); |
| 86 | + expect(input.getInput(filters).length).toBe(3); |
| 87 | + expect(input.getInput(filters)).toStrictEqual([10, 5, 1]); |
| 88 | + }); |
| 89 | + |
| 90 | + test('Filters should correctly consider all vectors 2d', () => { |
| 91 | + const vectors = [ |
| 92 | + new BaseVector([1, 2], ['x', 'y']), |
| 93 | + new BaseVector([4, 8], ['x', 'y']), |
| 94 | + new BaseVector([10, 20], ['x', 'y']), |
| 95 | + ]; |
| 96 | + const input = new ClassifierInput(vectors); |
| 97 | + const filterMax: Filter = FilterTypes.createFilter(FilterType.MAX); |
| 98 | + const filterMin: Filter = FilterTypes.createFilter(FilterType.MIN); |
| 99 | + const filters: Filters = new Filters(writable([filterMax, filterMin])); |
| 100 | + expect(input.getInput(filters)).toStrictEqual([ |
| 101 | + // x value max/min |
| 102 | + 1, 10, |
| 103 | + // y value max/min |
| 104 | + 20, 2, |
| 105 | + ]); |
| 106 | + }); |
43 | 107 | }); |
0 commit comments