Skip to content

Commit 87a2918

Browse files
authored
Bump OpenML API version to 1.0.1 (#30)
* Bump OpenML API version to 1.0.1 Summary: Since OpenML 1.0.0 had some missing changes, this commits bumps the version to 1.0.1 and adapts the code the new changes * Simplify implementation * Test invalid classes on ClassificationPythonModel * Address codacy reports
1 parent 17cda05 commit 87a2918

File tree

4 files changed

+91
-13
lines changed

4 files changed

+91
-13
lines changed

openml-python-common/src/main/java/com/feedzai/openml/python/ClassificationPythonModel.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import com.feedzai.openml.data.schema.AbstractValueSchema;
2121
import com.feedzai.openml.data.schema.CategoricalValueSchema;
2222
import com.feedzai.openml.data.schema.DatasetSchema;
23+
import com.feedzai.openml.data.schema.FieldSchema;
2324
import com.feedzai.openml.model.ClassificationMLModel;
2425
import com.feedzai.openml.provider.exception.ModelLoadingException;
2526
import com.feedzai.openml.python.jep.instance.JepInstance;
@@ -191,7 +192,11 @@ public int classify(final Instance instance) {
191192
asNotNullable = this.classToIndexConverter.apply(classValue);
192193
} catch (final NullPointerException e) {
193194

194-
final AbstractValueSchema targetVarSchema = this.schema.getTargetFieldSchema().getValueSchema();
195+
//noinspection OptionalGetWithoutIsPresent
196+
final AbstractValueSchema targetVarSchema = this.schema.getTargetFieldSchema()
197+
.map(FieldSchema::getValueSchema)
198+
// since the dataset schema is immutable and the target variable existence was already checked in construction
199+
.get();
195200
final Function<CategoricalValueSchema, String> block = targetSchema -> String.format(
196201
"Unexpected class provided by model: %s. Expected values: %s",
197202
classValue,
@@ -201,7 +206,7 @@ public int classify(final Instance instance) {
201206
final String msg = ClassificationDatasetSchemaUtil.withCategoricalValueSchema(targetVarSchema, block)
202207
.orElseThrow(() -> new RuntimeException("The target variable is not a categorical value: " + targetVarSchema));
203208

204-
logger.warn(msg, e);
209+
logger.error(msg, e);
205210
throw e;
206211
}
207212
return asNotNullable;
@@ -255,9 +260,14 @@ public void close() {
255260
* @return The conversion function.
256261
*/
257262
private Function<Serializable, Integer> getClassToIndexConverter(final DatasetSchema schema) {
258-
final AbstractValueSchema targetVariableSchema = schema.getTargetFieldSchema().getValueSchema();
263+
//noinspection OptionalGetWithoutIsPresent
264+
final AbstractValueSchema targetVariableSchema = schema.getTargetFieldSchema()
265+
.map(FieldSchema::getValueSchema)
266+
// since the dataset schema is immutable and the target variable existence was already checked in construction
267+
.get();
268+
259269
if (!(targetVariableSchema instanceof CategoricalValueSchema)) {
260-
logger.warn("Provided schema's target field is not categorical: {}", schema);
270+
logger.error("Provided schema's target field is not categorical: {}", schema);
261271
throw new IllegalArgumentException("Classification models require Categorical target fields. Got " + targetVariableSchema);
262272
}
263273
return EncodingHelper.classToIndexConverter((CategoricalValueSchema) targetVariableSchema);

openml-python-common/src/test/java/com/feedzai/openml/python/ClassificationPythonModelTest.java

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,60 @@
1-
package com.feedzai.openml.python;/*
2-
* The copyright of this file belongs to Feedzai. The file cannot be
3-
* reproduced in whole or in part, stored in a retrieval system,
4-
* transmitted in any form, or by any means electronic, mechanical,
5-
* photocopying, or otherwise, without the prior permission of the owner.
1+
/*
2+
* Copyright (c) 2019 Feedzai
63
*
7-
* © 2019 Feedzai, Strictly Confidential
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
815
*/
916

17+
package com.feedzai.openml.python;
18+
19+
import com.feedzai.openml.data.schema.CategoricalValueSchema;
1020
import com.feedzai.openml.data.schema.DatasetSchema;
1121
import com.feedzai.openml.data.schema.FieldSchema;
1222
import com.feedzai.openml.data.schema.NumericValueSchema;
23+
import com.feedzai.openml.mocks.MockInstance;
1324
import com.feedzai.openml.python.jep.instance.JepInstance;
1425
import com.google.common.collect.ImmutableList;
26+
import com.google.common.collect.ImmutableSet;
1527
import org.junit.Before;
1628
import org.junit.Test;
1729

30+
import java.net.URISyntaxException;
31+
import java.nio.file.Path;
32+
import java.nio.file.Paths;
33+
import java.util.List;
34+
import java.util.Random;
35+
import java.util.concurrent.ExecutionException;
36+
1837
import static org.assertj.core.api.Assertions.assertThatThrownBy;
1938

2039
/**
2140
* Tests the behaviour of a {@link ClassificationPythonModel}.
2241
*
2342
* @author Joao Sousa (joao.sousa@feedzai.com)
24-
* @since @@@feedzai.next.release@@@
43+
* @since 0.2.0
2544
*/
2645
public class ClassificationPythonModelTest {
2746

2847
/**
29-
* A field to use in the tests.
48+
* A numerical field to use in the tests.
3049
*/
3150
private static final FieldSchema FIELD_SCHEMA =
3251
new FieldSchema("field", 0, new NumericValueSchema(false));
3352

53+
/**
54+
* A categorical field to use in the tests.
55+
*/
56+
private static final FieldSchema CATEGORICAL_FIELD_SCHEMA =
57+
new FieldSchema("categorical", 1, new CategoricalValueSchema(false, ImmutableSet.of("this", "that")));
3458

3559
/**
3660
* The wrapper for the Jep object used in the tests.
@@ -59,4 +83,36 @@ public final void testSchemaWithoutTargetVariable() {
5983
.isInstanceOf(IllegalArgumentException.class);
6084
}
6185

86+
/**
87+
* Tests that a classifier which classifies an instance with an invalid target value will cause the model to return
88+
* NullPointerException.
89+
*/
90+
@Test
91+
public final void testInvalidClass() throws URISyntaxException, ExecutionException, InterruptedException {
92+
final Path modelPath = Paths.get(getClass().getResource("/dummy_model").toURI());
93+
final String id = "classificationModel";
94+
final List<FieldSchema> fields = ImmutableList.of(FIELD_SCHEMA, CATEGORICAL_FIELD_SCHEMA);
95+
final String illegalTargetValue = "those";
96+
final DatasetSchema schema = new DatasetSchema(1, fields);
97+
final Random random = new Random();
98+
99+
this.jepInstance.submitEvaluation(jep -> {
100+
// Add the model folder to the python import path
101+
jep.eval("import sys");
102+
jep.eval(String.format("sys.path.append(\"%s\")", modelPath.toAbsolutePath()));
103+
104+
// Import the Classifier custom class and store an instance of it in a variable with the name passed in "id"
105+
jep.eval("from classifier import Classifier");
106+
jep.eval(String.format("%s = Classifier('%s')", id, illegalTargetValue));
107+
108+
return null;
109+
}).get();
110+
111+
final ClassificationPythonModel model = new ClassificationPythonModel(this.jepInstance, schema, id, "classify", "getClassDistribution");
112+
113+
assertThatThrownBy(() -> model.classify(new MockInstance(schema, random)))
114+
.as("A classifier that does not return a valid target value will fail with a null pointer exception")
115+
.isInstanceOf(NullPointerException.class);
116+
}
117+
62118
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
class Classifier(object):
2+
3+
def __init__(self, target_value):
4+
self.target_value = target_value
5+
self.multiplier = [1, 0, 0]
6+
7+
def classify(self, instances):
8+
return self.target_value
9+
10+
11+
def getClassDistribution(self, instances):
12+
return [self.multiplier] * len(instances)

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
<assertj.version>3.7.0</assertj.version>
8080
<jackson.version>2.6.7</jackson.version>
8181
<jackson-databind.version>2.6.7</jackson-databind.version>
82-
<openml-api.version>1.0.0</openml-api.version>
82+
<openml-api.version>1.0.1</openml-api.version>
8383
<jep.version>3.7.0</jep.version>
8484
</properties>
8585

0 commit comments

Comments
 (0)