-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path main.py
More file actions
38 lines (31 loc) · 1.17 KB
/
main.py
File metadata and controls
38 lines (31 loc) · 1.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from src.preprocess import load_and_clean_data, clean_and_concat_text, vectorize_text
from src.train import train_model, evaluate_model
from src.predict import save_submission
import pandas as pd
from sklearn.model_selection import train_test_split
# Score columns selected from the training frame as regression targets.
# Kept as a list (not a tuple): pandas treats a tuple passed to df[...] as a
# single key rather than a column selection.
TARGET_COLUMNS = (
    "task_achievement coherence_and_cohesion lexical_resource grammatical_range"
).split()
def main(
    train_path="data/df_train.csv",
    test_path="data/df_test.csv",
    *,
    test_size=0.2,
    random_state=42,
    id_column=None,
):
    """Train the scoring model, evaluate it on a hold-out split, and save a submission.

    Pipeline:
      1. Load and clean the training CSV, build its text column with
         ``clean_and_concat_text``, and vectorize it with ``vectorize_text``.
      2. Hold out ``test_size`` of the training rows, fit the model on the
         rest, and evaluate it on the held-out rows.
      3. Prepare the test CSV with the same ``clean_and_concat_text`` step,
         transform it with the vectorizer returned in step 1, predict, and
         pass the predictions to ``save_submission``.

    Args:
        train_path: Path to the training CSV. Must contain the TARGET_COLUMNS.
        test_path: Path to the test CSV to predict on.
        test_size: Fraction of training rows held out for validation.
        random_state: Seed for the train/validation split.
        id_column: Optional name of a test-CSV column whose values are used as
            submission IDs. When ``None`` (default), IDs are
            ``1..len(predictions)``.
    """
    # Load and clean train data, then build the model input text.
    df_train = load_and_clean_data(train_path)
    df_train = clean_and_concat_text(df_train)

    # Vectorize text; the returned vectorizer is reused unchanged for the test set.
    X, vectorizer = vectorize_text(df_train)
    y = df_train[TARGET_COLUMNS]

    # Train/val split and training.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    model = train_model(X_train, y_train)
    evaluate_model(model, X_val, y_val)

    # Predict test set. The test text goes through the SAME preparation step
    # as the training text; an ad-hoc "prompt + essay, lowercased" rebuild
    # would skip whatever cleaning clean_and_concat_text applies and feed the
    # vectorizer differently-prepared text than it was fitted on.
    # NOTE(review): assumes clean_and_concat_text only needs the text columns
    # (not the targets) and does not drop rows. Confirm, otherwise the number
    # of submission rows could change.
    df_test = pd.read_csv(test_path)
    df_test = clean_and_concat_text(df_test)
    X_test = vectorizer.transform(df_test["text"])
    preds = model.predict(X_test)

    # Save submission. Use the real test IDs when a column is named
    # (KeyError if it is missing); otherwise number the rows from 1.
    if id_column is not None:
        ids = df_test[id_column].tolist()
    else:
        ids = range(1, len(preds) + 1)
    save_submission(ids=ids, predictions=preds)


if __name__ == "__main__":
    main()