Skip to content

Commit 69b0934

Browse files
committed
added parameters class_for_ir_statistics, attribute_id, pred_target_column to constructor of SimpleExperiment class and derived classes that configure the ClassifierSplitEvaluator accordingly
1 parent d0b2980 commit 69b0934

File tree

3 files changed

+55
-5
lines changed

3 files changed

+55
-5
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ Changelog
1010
train/test tuples as used by cross-validation
1111
- the `Tester` class (module: `weka.experiments`) now has an option to swap columns/rows for comparing
1212
datasets rather than classifiers
13+
- the `SimpleExperiment` class and derived classes (module: `weka.experiments`) now have the additional
14+
parameters in the constructor: class_for_ir_statistics, attribute_id, pred_target_column
1315
- ...
1416

1517

doc/source/examples.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,14 @@ Here is an example for performing a cross-validated classification experiment:
720720
print(tester.multi_resultset_full(1, comparison_col))
721721
722722
723+
Other parameters that can be supplied to the constructor of the `SimpleCrossValidationExperiment` or
724+
`SimpleRandomSplitExperiment` classes:
725+
726+
* `class_for_ir_statistics` - defines the class label to use for computing IR statistics like AUC
727+
* `attribute_id` - the 0-based index of the attribute that identifies rows
728+
* `pred_target_column` - for outputting the predictions and ground truth in separate columns in case of classification, e.g., for calculating confusion matrices manually afterwards
729+
730+
723731
And a setup for performing regression experiments on random splits on the datasets:
724732

725733
.. code-block:: python
@@ -753,6 +761,15 @@ And a setup for performing regression experiments on random splits on the datase
753761
print(tester.multi_resultset_full(0, comparison_col))
754762
755763
764+
The `Tester` class allows you to swap columns and rows, therefore comparing datasets rather than classifiers:
765+
766+
.. code-block:: python
767+
768+
tester = Tester(classname="weka.experiment.PairedCorrectedTTester")
769+
tester.swap_rows_and_cols = True
770+
tester.resultmatrix = matrix
771+
772+
756773
Partial classnames
757774
------------------
758775

python/weka/experiments.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ class SimpleExperiment(OptionHandler):
5353
http://weka.wikispaces.com/Using+the+Experiment+API
5454
"""
5555

56-
def __init__(self, datasets, classifiers, jobject=None, classification=True, runs=10, result=None):
56+
def __init__(self, datasets, classifiers, jobject=None, classification=True, runs=10, result=None,
57+
class_for_ir_statistics=0, attribute_id=-1, pred_target_column=False):
5758
"""
5859
Initializes the experiment.
5960
@@ -69,6 +70,12 @@ def __init__(self, datasets, classifiers, jobject=None, classification=True, run
6970
:type runs: int
7071
:param result: the filename of the file to store the results in
7172
:type result: str
73+
:param class_for_ir_statistics: the class label index to use IR statistics (classification only)
74+
:type class_for_ir_statistics: int
75+
:param attribute_id: the 0-based index of the attribute identifying instances (classification only)
76+
:type attribute_id: int
77+
:param pred_target_column: whether to store the predicted and target columns as well (classification only)
78+
:type pred_target_column: bool
7279
"""
7380

7481
if not jobject is None:
@@ -81,6 +88,9 @@ def __init__(self, datasets, classifiers, jobject=None, classification=True, run
8188
self.datasets = datasets[:]
8289
self.classifiers = classifiers[:]
8390
self.result = result
91+
self.class_for_ir_statistics = class_for_ir_statistics
92+
self.attribute_id = attribute_id
93+
self.pred_target_column = pred_target_column
8494
super(SimpleExperiment, self).__init__(jobject=jobject)
8595

8696
def configure_splitevaluator(self):
@@ -92,8 +102,12 @@ def configure_splitevaluator(self):
92102
"""
93103
if self.classification:
94104
speval = javabridge.make_instance("weka/experiment/ClassifierSplitEvaluator", "()V")
105+
javabridge.call(speval, "setClassForIRStatistics", "(I)V", self.class_for_ir_statistics)
106+
javabridge.call(speval, "setAttributeID", "(I)V", self.attribute_id)
107+
javabridge.call(speval, "setPredTargetColumn", "(Z)V", self.pred_target_column)
95108
else:
96109
speval = javabridge.make_instance("weka/experiment/RegressionSplitEvaluator", "()V")
110+
97111
classifier = javabridge.call(speval, "getClassifier", "()Lweka/classifiers/Classifier;")
98112
return speval, classifier
99113

@@ -221,7 +235,8 @@ class SimpleCrossValidationExperiment(SimpleExperiment):
221235
Performs a simple cross-validation experiment. Can output the results either in ARFF or CSV.
222236
"""
223237

224-
def __init__(self, datasets, classifiers, classification=True, runs=10, folds=10, result=None):
238+
def __init__(self, datasets, classifiers, classification=True, runs=10, folds=10, result=None,
239+
class_for_ir_statistics=0, attribute_id=-1, pred_target_column=False):
225240
"""
226241
Initializes the experiment.
227242
@@ -237,6 +252,12 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, folds=10
237252
:type folds: int
238253
:param result: the filename of the file to store the results in
239254
:type result: str
255+
:param class_for_ir_statistics: the class label index to use IR statistics (classification only)
256+
:type class_for_ir_statistics: int
257+
:param attribute_id: the 0-based index of the attribute identifying instances (classification only)
258+
:type attribute_id: int
259+
:param pred_target_column: whether to store the predicted and target columns as well (classification only)
260+
:type pred_target_column: bool
240261
"""
241262

242263
if runs < 1:
@@ -252,7 +273,9 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, folds=10
252273

253274
super(SimpleCrossValidationExperiment, self).__init__(
254275
classification=classification, runs=runs, datasets=datasets,
255-
classifiers=classifiers, result=result)
276+
classifiers=classifiers, result=result,
277+
class_for_ir_statistics=class_for_ir_statistics, attribute_id=attribute_id,
278+
pred_target_column=pred_target_column)
256279

257280
self.folds = folds
258281

@@ -293,7 +316,7 @@ class SimpleRandomSplitExperiment(SimpleExperiment):
293316
"""
294317

295318
def __init__(self, datasets, classifiers, classification=True, runs=10, percentage=66.6, preserve_order=False,
296-
result=None):
319+
result=None, class_for_ir_statistics=0, attribute_id=-1, pred_target_column=False):
297320
"""
298321
Initializes the experiment.
299322
@@ -311,6 +334,12 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, percenta
311334
:type classifiers: list
312335
:param result: the filename of the file to store the results in
313336
:type result: str
337+
:param class_for_ir_statistics: the class label index to use IR statistics (classification only)
338+
:type class_for_ir_statistics: int
339+
:param attribute_id: the 0-based index of the attribute identifying instances (classification only)
340+
:type attribute_id: int
341+
:param pred_target_column: whether to store the predicted and target columns as well (classification only)
342+
:type pred_target_column: bool
314343
"""
315344

316345
if runs < 1:
@@ -328,7 +357,9 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, percenta
328357

329358
super(SimpleRandomSplitExperiment, self).__init__(
330359
classification=classification, runs=runs, datasets=datasets,
331-
classifiers=classifiers, result=result)
360+
classifiers=classifiers, result=result,
361+
class_for_ir_statistics=class_for_ir_statistics, attribute_id=attribute_id,
362+
pred_target_column=pred_target_column)
332363

333364
self.percentage = percentage
334365
self.preserve_order = preserve_order

0 commit comments

Comments
 (0)