Skip to content

Commit ee17f08

Browse files
committed
Expose some hyper params in nbsvm
1 parent 37687c0 commit ee17f08

File tree

1 file changed

+5
-4
lines changed
  • sota_extractor2/models/structure

1 file changed

+5
-4
lines changed

sota_extractor2/models/structure/nbsvm.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,10 +39,11 @@ def get_number_of_classes(y):
3939
return y.shape[1]
4040

4141
class NBSVM:
42-
def __init__(self, solver='liblinear', dual=True):
42+
def __init__(self, solver='liblinear', dual=True, C=4, ngram_range=(1, 2)):
4343
self.solver = solver # 'lbfgs' - large, liblinear for small datasets
4444
self.dual = dual
45-
pass
45+
self.C = C
46+
self.ngram_range = ngram_range
4647

4748
re_tok = re.compile(f'([{string.punctuation}“”¨«»®´·º½¾¿¡§£₤‘’])')
4849

@@ -56,13 +57,13 @@ def pr(self, y_i, y):
5657
def get_mdl(self, y):
5758
y = y.values
5859
r = np.log(self.pr(1, y) / self.pr(0, y))
59-
m = LogisticRegression(C=4, dual=self.dual, solver=self.solver, max_iter=1000)
60+
m = LogisticRegression(C=self.C, dual=self.dual, solver=self.solver, max_iter=1000)
6061
x_nb = self.trn_term_doc.multiply(r)
6162
return m.fit(x_nb, y), r
6263

6364
def bow(self, X_train):
6465
self.n = X_train.shape[0]
65-
self.vec = TfidfVectorizer(ngram_range=(1, 2), tokenizer=self.tokenize,
66+
self.vec = TfidfVectorizer(ngram_range=self.ngram_range, tokenizer=self.tokenize,
6667
min_df=3, max_df=0.9, strip_accents='unicode', use_idf=1,
6768
smooth_idf=1, sublinear_tf=1)
6869
return self.vec.fit_transform(X_train)

0 commit comments

Comments
 (0)