Skip to content

Commit 0ff09d3

Browse files
committed
Add UT for method classify
There was not UT to check that #classify returns the correct value. Add UT #canGetClassDistributionMaxValueIndex to check that #classify returns index of max value.
1 parent 8fc2d17 commit 0ff09d3

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

openml-generic-python/src/test/java/com/feedzai/openml/python/PythonModelProviderTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.feedzai.openml.data.Instance;
2020
import com.feedzai.openml.data.schema.DatasetSchema;
2121
import com.feedzai.openml.mocks.MockDataset;
22+
import com.feedzai.openml.model.ClassificationMLModel;
2223
import com.feedzai.openml.provider.exception.ModelLoadingException;
2324
import com.feedzai.openml.util.algorithm.GenericAlgorithm;
2425
import com.feedzai.openml.util.algorithm.MLAlgorithmEnum;
@@ -104,6 +105,24 @@ public class PythonModelProviderTest extends AbstractProviderModelLoadTest<Class
104105
@Rule
105106
public final ExpectedException exception = ExpectedException.none();
106107

108+
/**
109+
* Verifies that the {@link ClassificationMLModel#classify(Instance)} " returns the index of the greatest value in
110+
* the class probability distribution produced by the calling
111+
* {@link ClassificationMLModel#getClassDistribution(Instance)} on the model
112+
*
113+
* @see ClassificationMLModel
114+
*/
115+
@Test
116+
public void canGetClassDistributionMaxValueIndex() throws Exception {
117+
118+
final ClassificationPythonModel model = getSecondModel();
119+
120+
final Instance instance = getDummyInstance();
121+
122+
this.canGetClassDistributionMaxValueIndex(model, instance);
123+
124+
}
125+
107126
/**
108127
* Tests loading a model that is not compatible with the given schema.
109128
* Should throw a ModelLoadingException.

openml-scikit/src/test/java/com/feedzai/openml/scikit/ScikitModelProviderTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.feedzai.openml.data.Instance;
2020
import com.feedzai.openml.data.schema.DatasetSchema;
2121
import com.feedzai.openml.mocks.MockDataset;
22+
import com.feedzai.openml.model.ClassificationMLModel;
2223
import com.feedzai.openml.provider.exception.ModelLoadingException;
2324
import com.feedzai.openml.python.ClassificationPythonModel;
2425
import com.feedzai.openml.util.algorithm.MLAlgorithmEnum;
@@ -95,6 +96,24 @@ public class ScikitModelProviderTest extends AbstractProviderModelLoadTest<Class
9596
@Rule
9697
public final ExpectedException exception = ExpectedException.none();
9798

99+
/**
100+
* Verifies that the {@link ClassificationMLModel#classify(Instance)} " returns the index of the greatest value in
101+
* the class probability distribution produced by the calling
102+
* {@link ClassificationMLModel#getClassDistribution(Instance)} on the model
103+
*
104+
* @see ClassificationMLModel
105+
*/
106+
@Test
107+
public void canGetClassDistributionMaxValueIndex() throws Exception {
108+
109+
final ClassificationPythonModel model = getFirstModel();
110+
111+
final Instance instance = getDummyInstance();
112+
113+
this.canGetClassDistributionMaxValueIndex(model, instance);
114+
115+
}
116+
98117
/**
99118
* Tests loading a model that does not support class distribution classification.
100119
* Should throw a ModelLoadingException.

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
<assertj.version>3.7.0</assertj.version>
8080
<jackson.version>2.6.7</jackson.version>
8181
<jackson-databind.version>2.6.7</jackson-databind.version>
82-
<openml-api.version>0.4.2</openml-api.version>
82+
<openml-api.version>0.4.3</openml-api.version>
8383
<jep.version>3.7.0</jep.version>
8484
</properties>
8585

0 commit comments

Comments
 (0)