Skip to content

Commit ef85461

Browse files
author
dillon
committed
adding training for new model
1 parent 7a7d00d commit ef85461

File tree

12 files changed

+170
-85
lines changed

12 files changed

+170
-85
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"C": 1,
3+
"class_weight": "balanced",
4+
"max_iter": 500,
5+
"penalty": "l2",
6+
"solver": "lbfgs"
7+
}

main.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
from src.data.processor import unzip_data_extract_contents
2-
from src.model.train_model import train_bow_logreg, train_tfidf_logreg
3-
from src.model.predict_model import predict_sentiment_tfidf, predict_sentiment_bow
2+
from src.model.train_model import train_bow_logreg, train_tfidf_logreg, finetune_bert
3+
from src.model.predict_model import predict_sentiment_tfidf, predict_sentiment_bow, predict_sentiment_bert
44

55
def main():
    """Run the end-to-end pipeline: extract the IMDB data, train the
    TF-IDF, bag-of-words, and BERT models, then print sample predictions
    from each model for a couple of hand-written reviews.
    """
    test_archive, train_archive, unsup_archive, imdb_vocab, imdb_expected_rating = unzip_data_extract_contents()
    train_tfidf_logreg(test_archive, train_archive, imdb_vocab, imdb_expected_rating)
    train_bow_logreg(test_archive, train_archive, imdb_vocab)
    finetune_bert(train_archive, test_archive)

    # Same three predictions for each sample — loop instead of copy-paste.
    for sample_text in ("this movie was mid", "it wasnt bad"):
        print("TFIDF Prediction:", predict_sentiment_tfidf(sample_text))
        print("BoW Prediction:", predict_sentiment_bow(sample_text, imdb_vocab))
        print("BERT Prediction:", predict_sentiment_bert(sample_text))


# Guard the entry point: the original unconditional main() call ran the whole
# training pipeline as a side effect of merely importing this module.
if __name__ == "__main__":
    main()

models/bow_sentiment_model.joblib

700 KB
Binary file not shown.
700 KB
Binary file not shown.
700 KB
Binary file not shown.

requirements.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
transformers==4.53.3
21
jupyter==1.1.1
32
numpy==2.3.1
43
matplotlib==3.10.3
@@ -7,4 +6,8 @@ joblib==1.5.1
76
flake8==7.0.0
87
pre-commit==4.2.0
98
pytest==8.4.1
10-
coverage==7.4.4
9+
coverage==7.4.4
10+
torch==2.7.1
11+
transformers==4.54.1
12+
datasets==4.0.0
13+
transformers[torch]==4.54.1

src/common/utils.py

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,84 +4,61 @@
44
import joblib
55
from scipy.sparse import save_npz, load_npz
66

7+
8+
def get_project_root():
    """Return the absolute path of the repository root.

    The root is three directory levels above this file
    (src/common/utils.py -> src/common -> src -> root).
    """
    here = os.path.abspath(__file__)
    for _ in range(3):
        here = os.path.dirname(here)
    return here
10+
11+
12+
def get_processed_path(file_name, is_json=False):
    """Return the absolute path of *file_name* under <root>/data/processed.

    Appends '.json' (when is_json) or '.gz' (otherwise) if the name does
    not already carry that extension.
    """
    suffix = '.json' if is_json else '.gz'
    name = file_name if file_name.endswith(suffix) else file_name + suffix
    return os.path.join(get_project_root(), "data", "processed", name)
17+
18+
719
def export_data_to_json(data, file_name, is_json=False):
    """Serialize *data* as JSON into data/processed.

    Plain UTF-8 text when is_json is True, gzip-compressed text otherwise;
    both variants use the same pretty-printed, non-ASCII-preserving dump.
    """
    path = get_processed_path(file_name, is_json)
    if is_json:
        handle = open(path, 'w', encoding='utf-8')
    else:
        handle = gzip.open(path, 'wt', encoding='utf-8')
    with handle as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
1827

19-
with gzip.open(path, 'wt', encoding='utf-8') as f:
20-
json.dump(data, f, ensure_ascii=False, indent=2)
21-
2228

2329
def import_processed_json(file_name, is_json=False):
    """Load a JSON payload from data/processed.

    Returns the parsed object, or None when the file does not exist.
    Reads plain text for .json files and gzip-compressed text for .gz.
    """
    path = get_processed_path(file_name, is_json)
    if not os.path.exists(path):
        return None
    opener = open if is_json else gzip.open
    mode = 'r' if is_json else 'rt'
    with opener(path, mode, encoding='utf-8') as f:
        return json.load(f)
4539

46-
def export_models(model, file_name):
    """Persist a fitted model to <root>/models/<file_name> via joblib."""
    destination = os.path.join(get_project_root(), "models", file_name)
    joblib.dump(model, destination)
5444

5545

5646
def import_models(file_name):
    """Load a joblib-serialized model from <root>/models.

    Returns None when the file is absent instead of raising.
    """
    path = os.path.join(get_project_root(), "models", file_name)
    return joblib.load(path) if os.path.exists(path) else None
6751

6852

6953
def export_processed_data(matrix, filename):
    """Save a sparse matrix as .npz under data/processed, creating the
    directory first if it does not exist yet."""
    out_dir = os.path.join(get_project_root(), "data", "processed")
    os.makedirs(out_dir, exist_ok=True)
    save_npz(os.path.join(out_dir, filename), matrix)
7858

7959

8060
def import_processed_data(filename):
81-
project_root = os.path.dirname(
82-
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
83-
)
84-
processed_dir = os.path.join(project_root, "data", "processed")
61+
processed_dir = os.path.join(get_project_root(), "data", "processed")
8562
path = os.path.join(processed_dir, filename)
8663
if not os.path.exists(path):
8764
return None

src/data/processor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import html
88
import numpy as np
99
from scipy.sparse import lil_matrix
10+
from datasets import Dataset
1011

1112
def unzip_data_extract_contents():
1213
project_root = os.path.dirname(
@@ -142,4 +143,13 @@ def parse_bow_line(line):
142143
if ':' in part:
143144
idx, val = part.split(':')
144145
bow[int(idx)] = int(val)
145-
return bow
146+
return bow
147+
148+
def prepare_bert_dataset(archive):
    """Convert an IMDB review archive into a Hugging Face Dataset.

    Produces 'text' and 'label' columns, mapping review type 'neg' to
    label 0 and anything else to 1.
    """
    texts, labels = [], []
    for review in archive.reviews:
        texts.append(review['contents'])
        labels.append(1 if review['type'] != 'neg' else 0)
    return Dataset.from_dict({"text": texts, "label": labels})
152+
153+
154+
def tokenize_function(examples, tokenizer):
    """Tokenize one batch of 'text' examples to fixed-length model inputs."""
    batch_texts = examples["text"]
    return tokenizer(batch_texts, truncation=True, padding="max_length")

src/model/predict_model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from src.common.utils import import_models
22
from src.data.processor import clean_review_text, bow_dicts_to_matrix
33
from src.common.utils import import_models
4+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
5+
import torch
46

57

68
def predict_sentiment_tfidf(text):
@@ -35,4 +37,15 @@ def text_to_bow_dict(text, vocab_list):
3537
idx = vocab_index.get(word)
3638
if idx is not None:
3739
bow[idx] = bow.get(idx, 0) + 1
38-
return bow
40+
return bow
41+
42+
43+
# Cache of loaded (tokenizer, model) pairs keyed by model directory.
# from_pretrained reads the full checkpoint from disk, which is far too
# expensive to repeat on every single prediction call.
_BERT_CACHE = {}


def predict_sentiment_bert(text, model_dir="./bert_finetuned"):
    """Classify *text* with the fine-tuned BERT sentiment model.

    Args:
        text: Raw review text to classify.
        model_dir: Directory holding the fine-tuned model and tokenizer
            (default "./bert_finetuned").

    Returns:
        Tuple (sentiment, probs): the argmax class index and the softmax
        probability distribution over classes as a list of floats.
    """
    if model_dir not in _BERT_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        _BERT_CACHE[model_dir] = (tokenizer, model)
    tokenizer, model = _BERT_CACHE[model_dir]

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    sentiment = torch.argmax(probs, dim=1).item()
    return sentiment, probs.squeeze().tolist()

0 commit comments

Comments
 (0)