@@ -6,42 +6,54 @@
 
 INPUT_FILENAME = "./data/city_wikipedia_summaries.csv"
 EXPORT_FILENAME = "./data/city_wikipedia_summaries_with_embeddings.parquet"
-TOKENIZER = 'sentence-transformers/all-MiniLM-L6-v2'
-MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
+TOKENIZER = "sentence-transformers/all-MiniLM-L6-v2"
+MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+
 
 def mean_pooling(model_output, attention_mask):
-    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+    token_embeddings = model_output[
+        0
+    ]  # First element of model_output contains all token embeddings
+    input_mask_expanded = (
+        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    )
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+        input_mask_expanded.sum(1), min=1e-9
+    )
+
 
 def run_model(sentences, tokenizer, model):
-    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+    encoded_input = tokenizer(
+        sentences, padding=True, truncation=True, return_tensors="pt"
+    )
     # Compute token embeddings
     with torch.no_grad():
         model_output = model(**encoded_input)
 
-    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
     sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
     return sentence_embeddings
 
+
 def score_data() -> None:
     if EXPORT_FILENAME not in os.listdir():
         print("scored data not found...generating embeddings...")
         df = pd.read_csv(INPUT_FILENAME)
         tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
         model = AutoModel.from_pretrained(MODEL)
-        embeddings = run_model(df['Wiki Summary'].tolist(), tokenizer, model)
+        embeddings = run_model(df["Wiki Summary"].tolist(), tokenizer, model)
         print(embeddings)
-        print('shape = ', df.shape)
-        df['Embeddings'] = list(embeddings.detach().cpu().numpy())
+        print("shape = ", df.shape)
+        df["Embeddings"] = list(embeddings.detach().cpu().numpy())
         print("embeddings generated...")
-        df['event_timestamp'] = pd.to_datetime('today')
+        df["event_timestamp"] = pd.to_datetime("today")
         df["item_id"] = df.index
         print(df.head())
         df.to_parquet(EXPORT_FILENAME, index=False)
         print("...data exported. job complete")
     else:
         print("scored data found...skipping generating embeddings.")
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     score_data()
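
For reference, a minimal sketch of how the reformatted helpers could be sanity-checked from a REPL in which the definitions above have already been executed (the constants are the ones defined in this file; the sentences are illustrative only):

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
model = AutoModel.from_pretrained(MODEL)

# run_model tokenizes, mean-pools the token embeddings, and L2-normalizes the rows,
# so cosine similarity between two rows reduces to a dot product.
embeddings = run_model(
    ["Paris is the capital of France.", "Tokyo is the largest city in Japan."],
    tokenizer,
    model,
)
print(embeddings.shape)  # torch.Size([2, 384]) for all-MiniLM-L6-v2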