Skip to content

Commit fd1718f

Browse files
fix default model loading
1 parent 53701c7 commit fd1718f

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

extractnet/hybrid_extractor.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
from .util import get_and_union_features, convert_segmentation_to_text
55
from .sequence_tagger.models import word2features
66

7+
import os
78
from sklearn.base import BaseEstimator
89
import joblib
910
import numpy as np
1011
import dateparser
1112

13+
# Absolute path of the directory containing this module; the bundled default
# model files are resolved relative to it. os.path is used instead of string
# replacement on __file__ so the result is correct regardless of the path
# separator (e.g. Windows) or the file suffix (e.g. a compiled .pyc).
EXTRACTOR_DIR = os.path.dirname(os.path.abspath(__file__))
14+
1215
def merge_results(r1, r2):
1316

1417
for key in r2.keys():
@@ -34,17 +37,33 @@ class Extractor(BaseEstimator):
3437

3538

3639
def __init__(self,
37-
stage1_classifer='extractnet/models/final_extractor.pkl.gz',
38-
author_classifier='extractnet/models/author_extractor.pkl.gz',
39-
date_classifier='extractnet/models/datePublishedRaw_extractor.pkl.gz',
40-
author_embeddings='extractnet/models/char_embedding.joblib',
41-
author_tagger='extractnet/models/crf.joblib',
40+
stage1_classifer=None,
41+
author_classifier=None,
42+
date_classifier=None,
43+
author_embeddings=None,
44+
author_tagger=None,
4245
data_prob_threshold=0.5,
4346
author_prob_threshold=0.5,
4447
):
4548
'''
4649
For inference use only
4750
'''
51+
if stage1_classifer is None:
52+
stage1_classifer = os.path.join(EXTRACTOR_DIR, 'models/final_extractor.pkl.gz')
53+
if author_classifier is None:
54+
author_classifier = os.path.join(EXTRACTOR_DIR, 'models/author_extractor.pkl.gz')
55+
56+
if date_classifier is None:
57+
date_classifier = os.path.join(EXTRACTOR_DIR, 'models/datePublishedRaw_extractor.pkl.gz')
58+
59+
if author_embeddings is None:
60+
author_embeddings = os.path.join(EXTRACTOR_DIR, 'models/char_embedding.joblib')
61+
62+
if author_tagger is None:
63+
author_tagger = os.path.join(EXTRACTOR_DIR, 'models/crf.joblib')
64+
65+
66+
4867
self.author_clf = joblib.load(author_classifier)
4968
self.date_clf = joblib.load(date_classifier)
5069

0 commit comments

Comments
 (0)