Skip to content

Commit daa9154

Browse files
committed
added methods subset_results/rank_results to class AttributeSelection to make results of CV accessible in numeric form (though parsed from textual CV output)
1 parent 3fb5917 commit daa9154

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

CHANGES.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ Changelog
66

77
- moved `Configurable` and `JSONObject` into *configurable-objects* library
88
- moved base flow components into *simple-data-flow* library
9+
- added methods `subset_results`, `rank_results` to class `AttributeSelection`
10+
(module: `weka.attribute_selection`) to give access to cross-validation
11+
output in numeric rather textual form (NB: it has to parse the textual CV output).
912

1013

1114
0.2.12 (2022-12-08)

python/weka/attribute_selection.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def __init__(self):
207207
"""
208208
jobject = AttributeSelection.new_instance("weka.attributeSelection.AttributeSelection")
209209
super(AttributeSelection, self).__init__(jobject)
210+
self._cv_results = None
210211

211212
def evaluator(self, evaluator):
212213
"""
@@ -269,6 +270,7 @@ def select_attributes(self, instances):
269270
:param instances: the data to process
270271
:type instances: Instances
271272
"""
273+
self._cv_results = None
272274
javabridge.call(self.jobject, "SelectAttributes", "(Lweka/core/Instances;)V", instances.jobject)
273275

274276
def select_attributes_cv_split(self, instances):
@@ -312,7 +314,96 @@ def cv_results(self):
312314
:return: the results string
313315
:rtype: str
314316
"""
315-
return javabridge.call(self.jobject, "CVResultsString", "()Ljava/lang/String;")
317+
if self._cv_results is None:
318+
self._cv_results = javabridge.call(self.jobject, "CVResultsString", "()Ljava/lang/String;")
319+
return self._cv_results
320+
321+
@property
322+
def subset_results(self):
323+
"""
324+
Returns the results from the cross-validation subsets, i.e., how often
325+
an attribute was selected.
326+
327+
Unfortunately, the Weka API does not give direct access to underlying
328+
data structures, hence we have to parse the textual output.
329+
330+
:return: the list of results (double)
331+
:rtype: list
332+
"""
333+
if self._cv_results is None:
334+
raise Exception("No attribute selection performed?")
335+
lines = self._cv_results.split("\n")
336+
337+
# ranking or subset eval?
338+
start = 0
339+
for i, l in enumerate(lines):
340+
if "average merit" in l:
341+
raise Exception("Cannot parse output from ranker!")
342+
elif "number of folds" in l:
343+
start = i + 1
344+
break
345+
346+
# parse text
347+
result = []
348+
for i in range(start, len(lines)):
349+
l = lines[i]
350+
if "(" in l:
351+
result.append(float(l[0:l.index("(")].strip()))
352+
353+
return result
354+
355+
@property
356+
def rank_results(self):
357+
"""
358+
Returns the results from the cross-validation for rankers.
359+
360+
Unfortunately, the Weka API does not give direct access to underlying
361+
data structures, hence we have to parse the textual output.
362+
363+
:return: the dictionary of results (mean and stdev for rank and merit)
364+
:rtype: dict
365+
"""
366+
if self._cv_results is None:
367+
raise Exception("No attribute selection performed?")
368+
lines = self.cv_results.split("\n")
369+
370+
# ranking or subset eval?
371+
start = 0
372+
for i, l in enumerate(lines):
373+
if "average merit" in l:
374+
start = i + 1
375+
break
376+
elif "number of folds" in l:
377+
raise Exception("Cannot parse output from non-rankers!")
378+
379+
# parse text
380+
merit_mean = []
381+
merit_stdev = []
382+
rank_mean = []
383+
rank_stdev = []
384+
for i in range(start, len(lines)):
385+
l = lines[i]
386+
if "+-" in l:
387+
parts = l.split(" +- ")
388+
inner = None
389+
right = None
390+
if len(parts) == 3:
391+
inner = [x for x in parts[1].split(" ") if x]
392+
right = [x for x in parts[2].split(" ") if x]
393+
if (len(inner) == 2) and (len(right) > 2):
394+
merit_mean.append(float(parts[0]))
395+
merit_stdev.append(float(inner[0]))
396+
rank_mean.append(float(inner[1]))
397+
rank_stdev.append(float(right[0]))
398+
399+
result = {
400+
"merit_mean": merit_mean,
401+
"merit_stdev": merit_stdev,
402+
"rank_mean": rank_mean,
403+
"rank_stdev": rank_stdev,
404+
}
405+
406+
return result
316407

317408
@property
318409
def number_attributes_selected(self):

0 commit comments

Comments
 (0)