@@ -53,7 +53,8 @@ class SimpleExperiment(OptionHandler):
5353 http://weka.wikispaces.com/Using+the+Experiment+API
5454 """
5555
56- def __init__ (self , datasets , classifiers , jobject = None , classification = True , runs = 10 , result = None ):
56+ def __init__ (self , datasets , classifiers , jobject = None , classification = True , runs = 10 , result = None ,
57+ class_for_ir_statistics = 0 , attribute_id = - 1 , pred_target_column = False ):
5758 """
5859 Initializes the experiment.
5960
@@ -69,6 +70,12 @@ def __init__(self, datasets, classifiers, jobject=None, classification=True, run
6970 :type runs: int
7071 :param result: the filename of the file to store the results in
7172 :type result: str
73+ :param class_for_ir_statistics: the class label index to use IR statistics (classification only)
74+ :type class_for_ir_statistics: int
75+ :param attribute_id: the 0-based index of the attribute identifying instances (classification only)
76+ :type attribute_id: int
77+ :param pred_target_column: whether to store the predicted and target columns as well (classification only)
78+ :type pred_target_column: bool
7279 """
7380
7481 if not jobject is None :
@@ -81,6 +88,9 @@ def __init__(self, datasets, classifiers, jobject=None, classification=True, run
8188 self .datasets = datasets [:]
8289 self .classifiers = classifiers [:]
8390 self .result = result
91+ self .class_for_ir_statistics = class_for_ir_statistics
92+ self .attribute_id = attribute_id
93+ self .pred_target_column = pred_target_column
8494 super (SimpleExperiment , self ).__init__ (jobject = jobject )
8595
8696 def configure_splitevaluator (self ):
@@ -92,8 +102,12 @@ def configure_splitevaluator(self):
92102 """
93103 if self .classification :
94104 speval = javabridge .make_instance ("weka/experiment/ClassifierSplitEvaluator" , "()V" )
105+ javabridge .call (speval , "setClassForIRStatistics" , "(I)V" , self .class_for_ir_statistics )
106+ javabridge .call (speval , "setAttributeID" , "(I)V" , self .attribute_id )
107+ javabridge .call (speval , "setPredTargetColumn" , "(Z)V" , self .pred_target_column )
95108 else :
96109 speval = javabridge .make_instance ("weka/experiment/RegressionSplitEvaluator" , "()V" )
110+
97111 classifier = javabridge .call (speval , "getClassifier" , "()Lweka/classifiers/Classifier;" )
98112 return speval , classifier
99113
@@ -221,7 +235,8 @@ class SimpleCrossValidationExperiment(SimpleExperiment):
221235 Performs a simple cross-validation experiment. Can output the results either in ARFF or CSV.
222236 """
223237
224- def __init__ (self , datasets , classifiers , classification = True , runs = 10 , folds = 10 , result = None ):
238+ def __init__ (self , datasets , classifiers , classification = True , runs = 10 , folds = 10 , result = None ,
239+ class_for_ir_statistics = 0 , attribute_id = - 1 , pred_target_column = False ):
225240 """
226241 Initializes the experiment.
227242
@@ -237,6 +252,12 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, folds=10
237252 :type folds: int
238253 :param result: the filename of the file to store the results in
239254 :type result: str
255+ :param class_for_ir_statistics: the class label index to use IR statistics (classification only)
256+ :type class_for_ir_statistics: int
257+ :param attribute_id: the 0-based index of the attribute identifying instances (classification only)
258+ :type attribute_id: int
259+ :param pred_target_column: whether to store the predicted and target columns as well (classification only)
260+ :type pred_target_column: bool
240261 """
241262
242263 if runs < 1 :
@@ -252,7 +273,9 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, folds=10
252273
253274 super (SimpleCrossValidationExperiment , self ).__init__ (
254275 classification = classification , runs = runs , datasets = datasets ,
255- classifiers = classifiers , result = result )
276+ classifiers = classifiers , result = result ,
277+ class_for_ir_statistics = class_for_ir_statistics , attribute_id = attribute_id ,
278+ pred_target_column = pred_target_column )
256279
257280 self .folds = folds
258281
@@ -293,7 +316,7 @@ class SimpleRandomSplitExperiment(SimpleExperiment):
293316 """
294317
295318 def __init__ (self , datasets , classifiers , classification = True , runs = 10 , percentage = 66.6 , preserve_order = False ,
296- result = None ):
319+ result = None , class_for_ir_statistics = 0 , attribute_id = - 1 , pred_target_column = False ):
297320 """
298321 Initializes the experiment.
299322
@@ -311,6 +334,12 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, percenta
311334 :type classifiers: list
312335 :param result: the filename of the file to store the results in
313336 :type result: str
337+ :param class_for_ir_statistics: the class label index to use IR statistics (classification only)
338+ :type class_for_ir_statistics: int
339+ :param attribute_id: the 0-based index of the attribute identifying instances (classification only)
340+ :type attribute_id: int
341+ :param pred_target_column: whether to store the predicted and target columns as well (classification only)
342+ :type pred_target_column: bool
314343 """
315344
316345 if runs < 1 :
@@ -328,7 +357,9 @@ def __init__(self, datasets, classifiers, classification=True, runs=10, percenta
328357
329358 super (SimpleRandomSplitExperiment , self ).__init__ (
330359 classification = classification , runs = runs , datasets = datasets ,
331- classifiers = classifiers , result = result )
360+ classifiers = classifiers , result = result ,
361+ class_for_ir_statistics = class_for_ir_statistics , attribute_id = attribute_id ,
362+ pred_target_column = pred_target_column )
332363
333364 self .percentage = percentage
334365 self .preserve_order = preserve_order
0 commit comments