Skip to content
This repository was archived by the owner on Aug 2, 2022. It is now read-only.

Commit bf69177

Browse files
committed
Closes #3
Updated use of bidict to take advantage of new API.
1 parent 1e0edaa commit bf69177

File tree

1 file changed

+9
-58
lines changed

1 file changed

+9
-58
lines changed

biosppy/biometrics.py

Lines changed: 9 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -57,55 +57,6 @@ def __str__(self):
5757
return str("Combination of empty array.")
5858

5959

60-
class SubjectDict(bidict):
61-
"""Adaptation of bidirectional dictionary to return default values
62-
on KeyError.
63-
64-
Attributes
65-
----------
66-
LEFT : hashable
67-
Left default token.
68-
Right : hashable
69-
Right default token.
70-
71-
"""
72-
73-
LEFT = ''
74-
RIGHT = ''
75-
76-
def __getitem__(self, keyorslice):
77-
"""Get an item; based on the bidict source."""
78-
79-
try:
80-
start = keyorslice.start
81-
stop = keyorslice.stop
82-
step = keyorslice.step
83-
except AttributeError:
84-
# keyorslice is a key, e.g. b[key]
85-
try:
86-
return self._fwd[keyorslice]
87-
except KeyError:
88-
return self.RIGHT
89-
90-
# keyorslice is a slice
91-
if (not ((start is None) ^ (stop is None))) or step is not None:
92-
raise TypeError('Slice must only specify either start or stop')
93-
94-
if start is not None:
95-
# forward lookup (by key), e.g. b[key:]
96-
try:
97-
return self._fwd[start]
98-
except KeyError:
99-
return self.RIGHT
100-
101-
# inverse lookup (by val), e.g. b[:val]
102-
assert stop is not None
103-
try:
104-
return self._bwd[stop]
105-
except KeyError:
106-
return self.LEFT
107-
108-
10960
class BaseClassifier(object):
11061
"""Base biometric classifier class.
11162
@@ -133,7 +84,7 @@ class BaseClassifier(object):
13384
def __init__(self):
13485
# generic self things
13586
self.is_trained = False
136-
self._subject2label = SubjectDict()
87+
self._subject2label = bidict()
13788
self._nbSubjects = 0
13889
self._thresholds = {}
13990
self._autoThresholds = None
@@ -296,7 +247,7 @@ def list_subjects(self):
296247
297248
"""
298249

299-
subjects = [self._subject2label[:i] for i in xrange(self._nbSubjects)]
250+
subjects = self._subject2label.keys()
300251

301252
return subjects
302253

@@ -448,7 +399,7 @@ def update_thresholds(self, fraction=1.):
448399

449400
# gather data to test
450401
data = {}
451-
for subject, label in self._subject2label.items():
402+
for subject, label in self._subject2label.iteritems():
452403
# select a random fraction of the training data
453404
aux = self.io_load(label)
454405
indx = range(len(aux))
@@ -460,7 +411,7 @@ def update_thresholds(self, fraction=1.):
460411
_, res = self.evaluate(data, ths)
461412

462413
# choose thresholds at EER
463-
for subject, label in self._subject2label.items():
414+
for subject, label in self._subject2label.iteritems():
464415
EER_auth = res['subject'][subject]['authentication']['rates']['EER']
465416
self.set_auth_thr(label, EER_auth[self.EER_IDX, 0], ready=True)
466417

@@ -653,7 +604,7 @@ def identify(self, data, threshold=None):
653604
labels = self._identify(aux, threshold)
654605

655606
# translate class labels
656-
subjects = [self._subject2label[:item] for item in labels]
607+
subjects = [self._subject2label.inv.get(item, '') for item in labels]
657608

658609
return subjects
659610

@@ -1216,7 +1167,7 @@ class SVM(BaseClassifier):
12161167
Degree of the polynomial kernel function (‘poly’). Ignored by all other
12171168
kernels.
12181169
gamma : float, optional
1219-
Kernel coefficient for ‘rbf’, ‘poly’ and ‘sigmoid’. If gamma is 0.0
1170+
Kernel coefficient for ‘rbf’, ‘poly’ and ‘sigmoid’. If gamma is 'auto'
12201171
then 1/n_features will be used instead.
12211172
coef0 : float, optional
12221173
Independent term in kernel function. It is only significant in ‘poly’
@@ -1246,7 +1197,7 @@ def __init__(self,
12461197
C=1.0,
12471198
kernel='linear',
12481199
degree=3,
1249-
gamma=0.0,
1200+
gamma='auto',
12501201
coef0=0.0,
12511202
shrinking=True,
12521203
tol=0.001,
@@ -1843,7 +1794,7 @@ def get_subject_results(results=None,
18431794
Classifier thresholds.
18441795
subjects : list
18451796
Target subject classes.
1846-
subject_dict : SubjectDict
1797+
subject_dict : bidict
18471798
Subject-label conversion dictionary.
18481799
subject_idx : list
18491800
Subject index.
@@ -1927,7 +1878,7 @@ def get_subject_results(results=None,
19271878
misses = np.logical_not(np.logical_or(hits, rejects))
19281879
nmisses = ns - (nhits + nrejects)
19291880
missCounts = {
1930-
subject_dict[:ms]: np.sum(res == ms)
1881+
subject_dict.inv[ms]: np.sum(res == ms)
19311882
for ms in np.unique(res[misses])
19321883
}
19331884

0 commit comments

Comments
 (0)