Skip to content

Commit 6e28da8

Browse files
committed
support balanced class_weight #141
1 parent 8466021 commit 6e28da8

File tree

3 files changed

+47
-11
lines changed

3 files changed

+47
-11
lines changed

python/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ The usage of thundersvm scikit interface is similar to sklearn.svm.
6969
*probability*: boolean, optional(default=False)\
7070
whether to train a SVC or SVR model for probability estimates, True or False
7171

72-
*class_weight*: {dict}, optional(default=None)\
73-
set the parameter C of class i to weight*C, for C-SVC
72+
*class_weight*: {dict, 'balanced'}, optional(default=None)\
73+
set the parameter C of class i to weight*C, for C-SVC. If not given, all classes are supposed to have weight one. The “balanced” mode uses the values of y to automatically adjust weights inversely proportional to class frequencies in the input data as ```n_samples / (n_classes * np.bincount(y))```
7474

7575
*shrinking*: boolean, optional (default=False, not supported yet for True)\
7676
whether to use the shrinking heuristic.

python/thundersvm/thundersvm.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,30 @@ def _dense_fit(self, X, y, solver_type, kernel):
185185
if self.class_weight is None:
186186
weight_size = 0
187187
self.class_weight = dict()
188+
weight_label = (c_int * weight_size)()
189+
weight_label[:] = list(self.class_weight.keys())
190+
weight = (c_float * weight_size)()
191+
weight[:] = list(self.class_weight.values())
192+
elif self.class_weight == 'balanced':
193+
y_unique = np.unique(y)
194+
y_count = np.bincount(y.astype(int))
195+
weight_label_list = []
196+
weight_list = []
197+
for n in range(0, len(y_count)):
198+
if y_count[n] != 0:
199+
weight_label_list.append(n)
200+
weight_list.append(samples/(len(y_unique)*y_count[n]))
201+
weight_size=len(weight_list)
202+
weight_label = (c_int * weight_size)()
203+
weight_label[:] = weight_label_list
204+
weight = (c_float * weight_size)()
205+
weight[:] = weight_list
188206
else:
189207
weight_size = len(self.class_weight)
190-
weight_label = (c_int * weight_size)()
191-
weight_label[:] = list(self.class_weight.keys())
192-
weight = (c_float * weight_size)()
193-
weight[:] = list(self.class_weight.values())
208+
weight_label = (c_int * weight_size)()
209+
weight_label[:] = list(self.class_weight.keys())
210+
weight = (c_float * weight_size)()
211+
weight[:] = list(self.class_weight.values())
194212

195213
n_features = (c_int * 1)()
196214
n_classes = (c_int * 1)()
@@ -228,12 +246,30 @@ def _sparse_fit(self, X, y, solver_type, kernel):
228246
if self.class_weight is None:
229247
weight_size = 0
230248
self.class_weight = dict()
249+
weight_label = (c_int * weight_size)()
250+
weight_label[:] = list(self.class_weight.keys())
251+
weight = (c_float * weight_size)()
252+
weight[:] = list(self.class_weight.values())
253+
elif self.class_weight == 'balanced':
254+
y_unique = np.unique(y)
255+
y_count = np.bincount(y.astype(int))
256+
weight_label_list = []
257+
weight_list = []
258+
for n in range(0, len(y_count)):
259+
if y_count[n] != 0:
260+
weight_label_list.append(n)
261+
weight_list.append(X.shape[0]/(len(y_unique)*y_count[n]))
262+
weight_size=len(weight_list)
263+
weight_label = (c_int * weight_size)()
264+
weight_label[:] = weight_label_list
265+
weight = (c_float * weight_size)()
266+
weight[:] = weight_list
231267
else:
232268
weight_size = len(self.class_weight)
233-
weight_label = (c_int * weight_size)()
234-
weight_label[:] = list(self.class_weight.keys())
235-
weight = (c_float * weight_size)()
236-
weight[:] = list(self.class_weight.values())
269+
weight_label = (c_int * weight_size)()
270+
weight_label[:] = list(self.class_weight.keys())
271+
weight = (c_float * weight_size)()
272+
weight[:] = list(self.class_weight.values())
237273

238274
n_features = (c_int * 1)()
239275
n_classes = (c_int * 1)()

src/thundersvm/thundersvm-scikit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ extern "C" {
221221
param_cmd.max_mem_size = static_cast<size_t>(max(max_mem_size, 0)) << 20;
222222
if(weight_size != 0) {
223223
param_cmd.nr_weight = weight_size;
224-
param_cmd.weight = (float_type *) malloc(weight_size * sizeof(float_type));
224+
param_cmd.weight = (float_type *) malloc(weight_size * sizeof(double));
225225
param_cmd.weight_label = (int *) malloc(weight_size * sizeof(int));
226226
for (int i = 0; i < weight_size; i++) {
227227
param_cmd.weight[i] = weight[i];

0 commit comments

Comments
 (0)