Skip to content

Commit ad920fd

Browse files
committed
fix: Correct MLflow tracking function name typo
1 parent 71712db commit ad920fd

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/components/model_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self,
3939
except Exception as e:
4040
raise CustomException(e)
4141

42-
def track_mlfow(self, best_model, classificationmetric: ClassificationMetricArtifact):
42+
def track_mlflow(self, best_model, classificationmetric: ClassificationMetricArtifact):
4343
with mlflow.start_run():
4444
f1_score = classificationmetric.f1_score
4545
precision_score = classificationmetric.precision_score
@@ -105,13 +105,13 @@ def train_model(self, x_train: np.ndarray, y_train: np.ndarray, x_test: np.ndarr
105105

106106
classification_train_metric = get_classification_score(y_true=y_train,y_pred=y_train_pred)
107107

108-
self.track_mlfow(best_model, classification_train_metric)
108+
self.track_mlflow(best_model, classification_train_metric)
109109

110110
y_test_pred = best_model.predict(x_test)
111111
classification_test_metric = get_classification_score(y_true=y_test,y_pred=y_test_pred)
112112

113113
# TODO: track with mlflow -> me
114-
self.track_mlfow(best_model, classification_test_metric)
114+
self.track_mlflow(best_model, classification_test_metric)
115115

116116
preprocessor = load_object(file_path=self.data_transformation_artifact.transformed_object_file_path)
117117

0 commit comments

Comments
 (0)