|
1 | | -package com.feedzai.openml.python;/* |
2 | | - * The copyright of this file belongs to Feedzai. The file cannot be |
3 | | - * reproduced in whole or in part, stored in a retrieval system, |
4 | | - * transmitted in any form, or by any means electronic, mechanical, |
5 | | - * photocopying, or otherwise, without the prior permission of the owner. |
| 1 | +/* |
| 2 | + * Copyright (c) 2019 Feedzai |
6 | 3 | * |
7 | | - * © 2019 Feedzai, Strictly Confidential |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
8 | 15 | */ |
9 | 16 |
|
| 17 | +package com.feedzai.openml.python; |
| 18 | + |
| 19 | +import com.feedzai.openml.data.schema.CategoricalValueSchema; |
10 | 20 | import com.feedzai.openml.data.schema.DatasetSchema; |
11 | 21 | import com.feedzai.openml.data.schema.FieldSchema; |
12 | 22 | import com.feedzai.openml.data.schema.NumericValueSchema; |
| 23 | +import com.feedzai.openml.mocks.MockInstance; |
13 | 24 | import com.feedzai.openml.python.jep.instance.JepInstance; |
14 | 25 | import com.google.common.collect.ImmutableList; |
| 26 | +import com.google.common.collect.ImmutableSet; |
15 | 27 | import org.junit.Before; |
16 | 28 | import org.junit.Test; |
17 | 29 |
|
| 30 | +import java.net.URISyntaxException; |
| 31 | +import java.nio.file.Path; |
| 32 | +import java.nio.file.Paths; |
| 33 | +import java.util.List; |
| 34 | +import java.util.Random; |
| 35 | +import java.util.concurrent.ExecutionException; |
| 36 | + |
18 | 37 | import static org.assertj.core.api.Assertions.assertThatThrownBy; |
19 | 38 |
|
20 | 39 | /** |
21 | 40 | * Tests the behaviour of a {@link ClassificationPythonModel}. |
22 | 41 | * |
23 | 42 | * @author Joao Sousa (joao.sousa@feedzai.com) |
24 | | - * @since @@@feedzai.next.release@@@ |
| 43 | + * @since 0.2.0 |
25 | 44 | */ |
26 | 45 | public class ClassificationPythonModelTest { |
27 | 46 |
|
28 | 47 | /** |
29 | | - * A field to use in the tests. |
| 48 | + * A numerical field to use in the tests. |
30 | 49 | */ |
31 | 50 | private static final FieldSchema FIELD_SCHEMA = |
32 | 51 | new FieldSchema("field", 0, new NumericValueSchema(false)); |
33 | 52 |
|
| 53 | + /** |
| 54 | + * A categorical field to use in the tests. |
| 55 | + */ |
| 56 | + private static final FieldSchema CATEGORICAL_FIELD_SCHEMA = |
| 57 | + new FieldSchema("categorical", 1, new CategoricalValueSchema(false, ImmutableSet.of("this", "that"))); |
34 | 58 |
|
35 | 59 | /** |
36 | 60 | * The wrapper for the Jep object used in the tests. |
@@ -59,4 +83,36 @@ public final void testSchemaWithoutTargetVariable() { |
59 | 83 | .isInstanceOf(IllegalArgumentException.class); |
60 | 84 | } |
61 | 85 |
|
| 86 | + /** |
| 87 | + * Tests that a classifier which classifies an instance with an invalid target value will cause the model to return |
| 88 | + * NullPointerException. |
| 89 | + */ |
| 90 | + @Test |
| 91 | + public final void testInvalidClass() throws URISyntaxException, ExecutionException, InterruptedException { |
| 92 | + final Path modelPath = Paths.get(getClass().getResource("/dummy_model").toURI()); |
| 93 | + final String id = "classificationModel"; |
| 94 | + final List<FieldSchema> fields = ImmutableList.of(FIELD_SCHEMA, CATEGORICAL_FIELD_SCHEMA); |
| 95 | + final String illegalTargetValue = "those"; |
| 96 | + final DatasetSchema schema = new DatasetSchema(1, fields); |
| 97 | + final Random random = new Random(); |
| 98 | + |
| 99 | + this.jepInstance.submitEvaluation(jep -> { |
| 100 | + // Add the model folder to the python import path |
| 101 | + jep.eval("import sys"); |
| 102 | + jep.eval(String.format("sys.path.append(\"%s\")", modelPath.toAbsolutePath())); |
| 103 | + |
| 104 | + // Import the Classifier custom class and store an instance of it in a variable with the name passed in "id" |
| 105 | + jep.eval("from classifier import Classifier"); |
| 106 | + jep.eval(String.format("%s = Classifier('%s')", id, illegalTargetValue)); |
| 107 | + |
| 108 | + return null; |
| 109 | + }).get(); |
| 110 | + |
| 111 | + final ClassificationPythonModel model = new ClassificationPythonModel(this.jepInstance, schema, id, "classify", "getClassDistribution"); |
| 112 | + |
| 113 | + assertThatThrownBy(() -> model.classify(new MockInstance(schema, random))) |
| 114 | + .as("A classifier that does not return a valid target value will fail with a null pointer exception") |
| 115 | + .isInstanceOf(NullPointerException.class); |
| 116 | + } |
| 117 | + |
62 | 118 | } |
0 commit comments