Using Embeddings for Medical Concept Matching

This guide demonstrates how to use the BioLORD model to match symptom descriptions to standard medical concepts through embedding-based similarity.

1. Overview

Objective: Map free-text symptom descriptions to standardized medical concepts (e.g., SNOMED CT, ICD codes) using embedding similarity.

Core Approach:

  • Use a pre-trained biomedical language model (BioLORD) to convert text into vector representations
  • Find the best matching standard medical concepts through vector similarity computation

2. Environment Setup

2.1 Install Dependencies

pip install torch transformers tqdm

2.2 Import Required Libraries

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

3. Load Model

3.1 Load BioLORD Model

BioLORD is a model pre-trained specifically on biomedical text, making it well suited to semantic matching of medical concepts.

# Load model and tokenizer
model = AutoModel.from_pretrained('FremyCompany/BioLORD-2023-C')
tokenizer = AutoTokenizer.from_pretrained('FremyCompany/BioLORD-2023-C')

# Move to GPU if available
# model = model.to('cuda')

4. Prepare Data

4.1 Define Symptom Queries

# List of symptom queries
symptoms = [
    "headache",
    "fever",
    "cough with chest pain",
    "nausea and vomiting",
    "shortness of breath"
]

4.2 Define Medical Concept Database

Instead of loading pre-computed embeddings, we'll define our concept database and compute embeddings on the fly.

# Medical concept database (illustrative examples)
# In practice, this could come from SNOMED CT, ICD, or other medical ontologies
medical_concepts = [
    "Headache (finding)",
    "Fever (finding)",
    "Chest pain (finding)",
    "Cough (finding)",
    "Nausea and vomiting (disorder)",
    "Dyspnea (finding)"
]

# Index mappings between concepts and their positions
concept2id = {concept: idx for idx, concept in enumerate(medical_concepts)}
id2concept = {idx: concept for idx, concept in enumerate(medical_concepts)}

5. Generate Embeddings

5.1 Embedding Generation Function

def generate_embeddings(texts, model, tokenizer, batch_size=30, max_length=128):
    """
    Generate embeddings for a list of texts in batches
    
    Args:
        texts: List of text strings
        model: Pre-trained model
        tokenizer: Tokenizer
        batch_size: Batch size for processing
        max_length: Maximum sequence length
    
    Returns:
        embeddings: Tensor of shape (len(texts), hidden_size)
    """
    all_embeddings = []
    
    # Process in batches
    for i in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
        batch_texts = texts[i:i+batch_size]
        
        # Tokenize and encode
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        # If using GPU: inputs = {k: v.to('cuda') for k, v in inputs.items()}
        
        # Generate embeddings (no gradient computation to save memory)
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Use [CLS] token representation (first token)
        batch_embeddings = outputs.last_hidden_state[:, 0, :]
        all_embeddings.append(batch_embeddings.cpu())
    
    # Concatenate all batches at once (avoids repeatedly growing a tensor,
    # which is slow and relies on deprecated empty-tensor concatenation)
    return torch.cat(all_embeddings, dim=0)

Key Points:

  • last_hidden_state[:, 0, :] extracts the [CLS] token representation for sentence-level embedding
  • torch.no_grad() disables gradient computation to save memory
  • Batch processing improves efficiency
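Note that the [CLS] token is one of several pooling choices. Sentence-embedding models in the BioLORD family are typically trained with mean pooling over token embeddings (masking out padding), which can give better sentence representations; check the model card for the recommended strategy. A minimal mean-pooling sketch, written as a drop-in alternative to the `[:, 0, :]` line above:

```python
import torch

def mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # sum of real tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)         # number of real tokens
    return summed / counts                           # (batch, hidden_size)
```

Inside `generate_embeddings`, this would replace the [CLS] extraction with `batch_embeddings = mean_pool(outputs.last_hidden_state, inputs["attention_mask"])`.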

5.2 Generate Embeddings for Both Symptoms and Concepts

# Generate embeddings for symptom queries
print("Generating embeddings for symptoms...")
symptom_embeddings = generate_embeddings(symptoms, model, tokenizer)

# Generate embeddings for medical concepts
print("Generating embeddings for medical concepts...")
concept_embeddings = generate_embeddings(medical_concepts, model, tokenizer)

print(f"Symptom embeddings shape: {symptom_embeddings.shape}")
print(f"Concept embeddings shape: {concept_embeddings.shape}")

6. Compute Similarity and Match

6.1 Top-K Similarity Matching

def find_top_matches(query_embeddings, concept_embeddings, k=5):
    """
    Find the top-k most similar concepts for each query
    
    Args:
        query_embeddings: Embeddings of query texts
        concept_embeddings: Embeddings of standard concepts
        k: Number of matches to return
    
    Returns:
        topk_indices: Top-k concept indices for each query
        topk_scores: Corresponding similarity scores
    """
    # L2 normalization for cosine similarity
    query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
    concept_embeddings = F.normalize(concept_embeddings, p=2, dim=1)
    
    topk_indices = []
    topk_scores = []
    
    # Compute similarity for each query
    for query_emb in tqdm(query_embeddings, desc="Matching concepts"):
        # Calculate cosine similarity with all concepts
        similarities = torch.matmul(concept_embeddings, query_emb)
        
        # Find top-k most similar concepts
        scores, indices = torch.topk(similarities, k, largest=True)
        
        topk_indices.append(indices.cpu().numpy().tolist())
        topk_scores.append(scores.cpu().numpy().tolist())
    
    return topk_indices, topk_scores

6.2 Execute Matching

# Find top-5 matches for each symptom
top_indices, top_scores = find_top_matches(
    symptom_embeddings,
    concept_embeddings,
    k=5
)

6.3 Map Indices to Concept Names

def map_indices_to_concepts(topk_indices, id2concept):
    """
    Convert concept indices to concept names
    
    Args:
        topk_indices: List of index lists
        id2concept: Mapping from index to concept name
    
    Returns:
        concept_names: List of concept name lists
    """
    concept_names = [
        [id2concept[idx] for idx in indices]
        for indices in topk_indices
    ]
    
    return concept_names

# Get matched concept names
matched_concepts = map_indices_to_concepts(top_indices, id2concept)

7. Threshold-Based Filtering

7.1 Understanding Similarity Thresholds

After L2 normalization, cosine similarity scores range from -1 to 1. Rough interpretation bands for this task:

  • 0.9 - 1.0: Very high confidence match
  • 0.8 - 0.9: Good match, likely relevant
  • 0.4 - 0.8: Moderate match, may be relevant
  • < 0.4: Low confidence, likely not a good match
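As a quick sanity check on these bounds (toy vectors chosen for illustration, not model output), the cosine similarity of L2-normalized vectors is simply their dot product:

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[1.0, 2.0, 3.0]])
b = torch.tensor([[2.0, 4.0, 6.0]])    # same direction as a -> similarity ~1.0
c = torch.tensor([[-1.0, -2.0, -3.0]]) # opposite direction  -> similarity ~-1.0

a_n = F.normalize(a, p=2, dim=1)
b_n = F.normalize(b, p=2, dim=1)
c_n = F.normalize(c, p=2, dim=1)

print((a_n @ b_n.T).item())  # close to 1.0
print((a_n @ c_n.T).item())  # close to -1.0
```

This is why `find_top_matches` normalizes both embedding matrices before the matrix multiplication: the plain dot product then becomes cosine similarity.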

7.2 Filter Matches Based on Threshold

def filter_matches_by_threshold(topk_indices, topk_scores, threshold=0.6):
    """
    Filter matches based on similarity threshold
    
    Args:
        topk_indices: Top-k concept indices for each query
        topk_scores: Corresponding similarity scores
        threshold: Minimum similarity score to consider a match valid
    
    Returns:
        filtered_indices: Filtered concept indices
        filtered_scores: Filtered similarity scores
        match_status: List indicating whether each query has valid matches
    """
    filtered_indices = []
    filtered_scores = []
    match_status = []
    
    for indices, scores in zip(topk_indices, topk_scores):
        # Filter based on threshold
        valid_matches = [(idx, score) for idx, score in zip(indices, scores) 
                        if score >= threshold]
        
        if valid_matches:
            # Has valid matches above threshold
            f_indices, f_scores = zip(*valid_matches)
            filtered_indices.append(list(f_indices))
            filtered_scores.append(list(f_scores))
            match_status.append(True)
        else:
            # No valid matches
            filtered_indices.append([])
            filtered_scores.append([])
            match_status.append(False)
    
    return filtered_indices, filtered_scores, match_status

# Apply threshold filtering
SIMILARITY_THRESHOLD = 0.9  # Adjust based on your requirements

filtered_indices, filtered_scores, match_status = filter_matches_by_threshold(
    top_indices,
    top_scores,
    threshold=SIMILARITY_THRESHOLD
)

8. Display Results with Threshold Information

8.1 Map Indices to Concept Names

The map_indices_to_concepts function defined in Section 6.3 can be reused here as-is, so there is no need to redefine it.
8.2 Enhanced Display with Match Status

def display_matches_with_threshold(symptoms, topk_indices, topk_scores, 
                                   id2concept, threshold=0.6):
    """
    Display matching results with threshold filtering information
    
    Args:
        symptoms: List of symptom queries
        topk_indices: Top-k concept indices for each query
        topk_scores: Corresponding similarity scores
        id2concept: Mapping from index to concept name
        threshold: Similarity threshold
    """
    for i, symptom in enumerate(symptoms):
        print(f"\n{'='*80}")
        print(f"Symptom Query: {symptom}")
        print(f"{'='*80}")
        
        indices = topk_indices[i]
        scores = topk_scores[i]
        
        # Check if any match meets threshold
        valid_matches = [(idx, score) for idx, score in zip(indices, scores) 
                        if score >= threshold]
        
        if not valid_matches:
            print(f"⚠️  NO VALID MATCH FOUND")
            print(f"   Best similarity score: {scores[0]:.4f} (below threshold {threshold})")
            print(f"   This symptom may need manual review or is not in the database.")
            
            # Show top candidates even if below threshold
            print(f"\n   Top candidates (below threshold):")
            for j in range(min(3, len(indices))):
                concept = id2concept[indices[j]]
                score = scores[j]
                print(f"     • {concept:<40} (similarity: {score:.4f})")
        else:
            print(f"✓ MATCHED ({len(valid_matches)} concept(s) above threshold)")
            
            # Show valid matches
            for j, (idx, score) in enumerate(valid_matches, 1):
                concept = id2concept[idx]
                
                # Add confidence indicator
                if score >= 0.9:
                    confidence = "HIGH"
                    symbol = "★★★"
                elif score >= 0.8:
                    confidence = "GOOD"
                    symbol = "★★"
                else:
                    confidence = "MODERATE"
                    symbol = "★"
                
                print(f"  {j}. {concept:<40} {symbol}")
                print(f"     Similarity: {score:.4f} | Confidence: {confidence}")
            
            # Show below-threshold candidates if any
            below_threshold = [(idx, score) for idx, score in zip(indices, scores) 
                              if score < threshold]
            if below_threshold and len(below_threshold) <= 2:
                print(f"\n   Other candidates (below threshold):")
                for idx, score in below_threshold:
                    concept = id2concept[idx]
                    print(f"     • {concept:<40} (similarity: {score:.4f})")

# Display with threshold filtering
display_matches_with_threshold(
    symptoms,
    top_indices,
    top_scores,
    id2concept,
    threshold=SIMILARITY_THRESHOLD
)

8.3 Summary Statistics

def print_matching_summary(symptoms, match_status, topk_scores, threshold):
    """
    Print summary statistics about matching results
    
    Args:
        symptoms: List of symptom queries
        match_status: Boolean list indicating match success
        topk_scores: Similarity scores
        threshold: Threshold used for filtering
    """
    total_queries = len(symptoms)
    matched_queries = sum(match_status)
    unmatched_queries = total_queries - matched_queries
    
    print(f"\n{'='*80}")
    print(f"MATCHING SUMMARY")
    print(f"{'='*80}")
    print(f"Total queries:              {total_queries}")
    print(f"Successfully matched:       {matched_queries} ({matched_queries/total_queries*100:.1f}%)")
    print(f"No match found:             {unmatched_queries} ({unmatched_queries/total_queries*100:.1f}%)")
    print(f"Similarity threshold used:  {threshold}")
    
    # Show unmatched queries
    if unmatched_queries > 0:
        print(f"\nUnmatched queries:")
        for i, (symptom, matched) in enumerate(zip(symptoms, match_status)):
            if not matched:
                best_score = topk_scores[i][0] if topk_scores[i] else 0.0
                print(f"  • {symptom:<40} (best score: {best_score:.4f})")
    
    # Average similarity statistics
    all_best_scores = [scores[0] for scores in topk_scores if scores]
    if all_best_scores:
        avg_score = sum(all_best_scores) / len(all_best_scores)
        max_score = max(all_best_scores)
        min_score = min(all_best_scores)
        print(f"\nSimilarity score statistics:")
        print(f"  Average (best match): {avg_score:.4f}")
        print(f"  Maximum:              {max_score:.4f}")
        print(f"  Minimum:              {min_score:.4f}")

# Print summary
print_matching_summary(symptoms, match_status, top_scores, SIMILARITY_THRESHOLD)

9. Optimization Tips

9.1 Performance Optimization

Use GPU for Faster Processing:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Modify the generate_embeddings function to use GPU
inputs = {k: v.to(device) for k, v in inputs.items()}

Batch Matrix Multiplication:

def find_top_matches_fast(query_embeddings, concept_embeddings, k=5):
    """Faster version using batch matrix multiplication"""
    # Normalize embeddings
    query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
    concept_embeddings = F.normalize(concept_embeddings, p=2, dim=1)
    
    # Compute all similarities at once
    similarities = torch.matmul(query_embeddings, concept_embeddings.T)
    
    # Get top-k for all queries
    scores, indices = torch.topk(similarities, k, dim=1, largest=True)
    
    topk_indices = indices.cpu().numpy().tolist()
    topk_scores = scores.cpu().numpy().tolist()
    
    return topk_indices, topk_scores

9.2 Save Computed Embeddings

If your concept database is large and static, save the embeddings for reuse:

# Save concept embeddings
torch.save(concept_embeddings, 'concept_embeddings.pt')

# Save concept mappings
import json
with open('concept_mappings.json', 'w') as f:
    json.dump({'concept2id': concept2id, 'id2concept': id2concept}, f)

# Load later
concept_embeddings = torch.load('concept_embeddings.pt')
with open('concept_mappings.json', 'r') as f:
    mappings = json.load(f)
    concept2id = mappings['concept2id']
    id2concept = {int(k): v for k, v in mappings['id2concept'].items()}
