|
4 | 4 | import os |
5 | 5 |
|
6 | 6 |
|
7 | | -def vectorise_clinical_concepts(csv_file_path, output_dir, batch_size=32): |
| 7 | +def vectorise_clinical_concepts(csv_file_path, output_path, batch_size=32): |
8 | 8 | """ |
9 | 9 | Vectorises a CSV of clinical concepts using the MedEmbed-Large-v1 model. |
10 | 10 |
|
11 | 11 | Args: |
12 | 12 | csv_file_path (str): The path to the input CSV file. The CSV must |
13 | 13 | have a column named 'concept_name' containing |
14 | 14 | the text to embed. |
15 | | - output_dir (str): The directory to save the output files. |
| 15 | + output_path (str): The full path to save the embeddings .pt file. |
16 | 16 | batch_size (int): The number of concepts to process in each batch. |
17 | 17 | """ |
18 | 18 |
|
@@ -61,18 +61,19 @@ def vectorise_clinical_concepts(csv_file_path, output_dir, batch_size=32): |
61 | 61 | final_embeddings = torch.cat(embeddings) |
62 | 62 |
|
63 | 63 | # Create the output directory if it doesn't exist |
64 | | - os.makedirs(output_dir, exist_ok=True) |
65 | | - |
66 | | - # Save embeddings to a file |
67 | | - embeddings_file_path = os.path.join( |
68 | | - output_dir, |
69 | | - 'concept_embeddings.pt' |
| 64 | + output_dir = os.path.dirname(output_path) |
| 65 | + if output_dir: |
| 66 | + os.makedirs(output_dir, exist_ok=True) |
| 67 | + |
| 68 | + # Save embeddings to the specified file path |
| 69 | + torch.save(final_embeddings, output_path) |
| 70 | + print(f"Embeddings saved to {output_path}") |
| 71 | + |
| 72 | + # Save the original concepts CSV to the same directory |
| 73 | + concepts_output_path = os.path.join( |
| 74 | + output_dir if output_dir else '.', |
| 75 | + 'clinical_concepts.csv' |
70 | 76 | ) |
71 | | - torch.save(final_embeddings, embeddings_file_path) |
72 | | - print(f"Embeddings saved to {embeddings_file_path}") |
73 | | - |
74 | | - # Save the original concepts CSV to the output directory as well |
75 | | - concepts_output_path = os.path.join(output_dir, 'clinical_concepts.csv') |
76 | 77 | df.to_csv(concepts_output_path, index=False) |
77 | 78 | print(f"Original concepts saved to {concepts_output_path}") |
78 | 79 |
|
|
0 commit comments