Skip to content

Commit 65fbca5

Browse files
committed
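Bug fixes: pass the embeddings file path (args.embeddings_file) to vectorise_clinical_concepts, which now takes a full output path instead of an output directory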
1 parent 96e665e commit 65fbca5

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def main():
108108

109109
embeddings = vectorise_clinical_concepts(
110110
args.concepts_file,
111-
os.path.dirname(args.concepts_file),
111+
args.embeddings_file,
112112
batch_size=32
113113
)
114114

src/omop_rag/create_embeddings.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import os
55

66

7-
def vectorise_clinical_concepts(csv_file_path, output_dir, batch_size=32):
7+
def vectorise_clinical_concepts(csv_file_path, output_path, batch_size=32):
88
"""
99
Vectorises a CSV of clinical concepts using the MedEmbed-Large-v1 model.
1010
1111
Args:
1212
csv_file_path (str): The path to the input CSV file. The CSV must
1313
have a column named 'concept_name' containing
1414
the text to embed.
15-
output_dir (str): The directory to save the output files.
15+
output_path (str): The full path of the embeddings .pt file to write. The original concepts CSV is saved as 'clinical_concepts.csv' in the same directory.
1616
batch_size (int): The number of concepts to process in each batch.
1717
"""
1818

@@ -61,18 +61,19 @@ def vectorise_clinical_concepts(csv_file_path, output_dir, batch_size=32):
6161
final_embeddings = torch.cat(embeddings)
6262

6363
# Create the parent directory of the output file if it doesn't exist
64-
os.makedirs(output_dir, exist_ok=True)
65-
66-
# Save embeddings to a file
67-
embeddings_file_path = os.path.join(
68-
output_dir,
69-
'concept_embeddings.pt'
64+
output_dir = os.path.dirname(output_path)
65+
if output_dir:
66+
os.makedirs(output_dir, exist_ok=True)
67+
68+
# Save embeddings to the specified file path
69+
torch.save(final_embeddings, output_path)
70+
print(f"Embeddings saved to {output_path}")
71+
72+
# Save the original concepts CSV to the same directory
73+
concepts_output_path = os.path.join(
74+
output_dir if output_dir else '.',
75+
'clinical_concepts.csv'
7076
)
71-
torch.save(final_embeddings, embeddings_file_path)
72-
print(f"Embeddings saved to {embeddings_file_path}")
73-
74-
# Save the original concepts CSV to the output directory as well
75-
concepts_output_path = os.path.join(output_dir, 'clinical_concepts.csv')
7677
df.to_csv(concepts_output_path, index=False)
7778
print(f"Original concepts saved to {concepts_output_path}")
7879

0 commit comments

Comments
 (0)