diff --git a/dedupe/api.py b/dedupe/api.py index e0932c2bd..cca694f97 100644 --- a/dedupe/api.py +++ b/dedupe/api.py @@ -15,7 +15,7 @@ import tempfile import numpy -import sklearn.linear_model +import sklearn.ensemble import sklearn.model_selection import dedupe.core as core @@ -1091,8 +1091,8 @@ def __init__( ] ] self.classifier = sklearn.model_selection.GridSearchCV( - estimator=sklearn.linear_model.LogisticRegression(), - param_grid={"C": [0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10]}, + estimator=sklearn.ensemble.RandomForestClassifier(), + param_grid={"n_estimators": [100, 200, 400, 800]}, scoring="f1", n_jobs=-1, ) diff --git a/dedupe/labeler.py b/dedupe/labeler.py index c95e08aca..d99b85e11 100644 --- a/dedupe/labeler.py +++ b/dedupe/labeler.py @@ -5,7 +5,7 @@ import numpy from typing import List from typing_extensions import Protocol -import sklearn.linear_model +import sklearn.ensemble import dedupe.core as core import dedupe.training as training @@ -38,7 +38,7 @@ class HasDataModel(Protocol): data_model: datamodel.DataModel -class RLRLearner(sklearn.linear_model.LogisticRegression, ActiveLearner): +class RFLearner(sklearn.ensemble.RandomForestClassifier, ActiveLearner): def __init__(self, data_model): super().__init__() self.data_model = data_model @@ -304,7 +304,7 @@ def _sample(self, data_1, data_2, sample_size): class DisagreementLearner(ActiveLearner): - classifier: RLRLearner + classifier: RFLearner blocker: BlockLearner candidates: List[TrainingExample] @@ -409,7 +409,7 @@ def __init__( self.candidates = self.blocker.candidates - self.classifier = RLRLearner(self.data_model) + self.classifier = RFLearner(self.data_model) self.classifier.candidates = self.candidates self._common_init() @@ -441,7 +441,7 @@ def __init__( self.blocker = RecordLinkBlockLearner(data_model, data_1, data_2, index_include) self.candidates = self.blocker.candidates - self.classifier = RLRLearner(self.data_model) + self.classifier = RFLearner(self.data_model) self.classifier.candidates = self.candidates self._common_init() diff --git a/tests/test_labeler.py b/tests/test_labeler.py index 48293ab3a..1bd4e8e4b 100644 --- a/tests/test_labeler.py +++ b/tests/test_labeler.py @@ -20,7 +20,7 @@ def setUp(self): def test_AL(self): random.seed(1111111111110) original_N = len(SAMPLE) - active_learner = dedupe.labeler.RLRLearner(self.data_model) + active_learner = dedupe.labeler.RFLearner(self.data_model) active_learner.candidates = SAMPLE assert len(active_learner) == original_N