Skip to content

Commit e0f8ede

Browse files
authored
Merge pull request #45 from m3dev/dev_kawai
add OptimizeClassificationModel
2 parents 48c4dcb + 75a62c5 commit e0f8ede

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

redshells/train/train_clasification_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,20 @@ def run(self):
5151
redshells.train.utils.fit_model(self)
5252

5353

54+
class OptimizeClassificationModel(_ClassificationModelTask):
55+
"""
56+
Optimize classification model. Please see `_ClassificationModelTask` for more detail and required parameters.
57+
"""
58+
task_namespace = 'redshells'
59+
test_size = luigi.FloatParameter() # type: float
60+
optuna_param_name = luigi.Parameter(description='The key of "redshells.factory.get_optuna_param".')
61+
output_file_path = luigi.Parameter(default='model/classification_model.pkl') # type: str
62+
63+
def run(self):
64+
redshells.train.utils.optimize_model(
65+
self, param_name=self.optuna_param_name, test_size=self.test_size)
66+
67+
5468
class ValidateClassificationModel(_ClassificationModelTask):
5569
"""
5670
Train classification model. Please see `_ClassificationModelTask` for more detail and required parameters.

0 commit comments

Comments
 (0)