diff --git a/README.md b/README.md index ef1a932..99603b3 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,37 @@ print(predictions) # Output: [('positive', 0.85), ('neutral', 0.12), ('negative', 0.03)] ``` +### 🏷️ Multi-Label Classification + +Classify texts into multiple categories simultaneously with automatic threshold adaptation: + +```python +from adaptive_classifier import MultiLabelAdaptiveClassifier + +# Initialize multi-label classifier +classifier = MultiLabelAdaptiveClassifier( + "bert-base-uncased", + min_predictions=1, # Ensure at least 1 prediction + max_predictions=5 # Limit to top 5 predictions +) + +# Multi-label training data (each text can have multiple labels) +texts = [ + "AI researchers study climate change using machine learning", + "Tech startup develops healthcare solutions" +] +labels = [ + ["technology", "science", "climate", "ai"], + ["technology", "business", "healthcare"] +] + +classifier.add_examples(texts, labels) + +# Make multi-label predictions +predictions = classifier.predict_multilabel("Medical AI breakthrough announced") +# Output: [('healthcare', 0.72), ('technology', 0.68), ('ai', 0.45)] +``` + ### 💾 Save & Load Models ```python @@ -188,6 +219,46 @@ more_labels = ["positive"] * 2 classifier.add_examples(more_examples, more_labels) ``` +### Multi-Label Classification with Advanced Configuration + +```python +from adaptive_classifier import MultiLabelAdaptiveClassifier + +# Configure advanced multi-label settings +classifier = MultiLabelAdaptiveClassifier( + "bert-base-uncased", + default_threshold=0.5, # Base threshold for predictions + min_predictions=1, # Minimum labels to return + max_predictions=10 # Maximum labels to return +) + +# Training with diverse multi-label examples +texts = [ + "Scientists develop AI for medical diagnosis and climate research", + "Tech company launches sustainable energy and healthcare products", + "Olympic athletes use sports science and nutrition technology" +] +labels = [ + ["science", "ai", "healthcare", "research"], + ["technology", "business", "environment", "healthcare"], + ["sports", "science", "health", "technology"] +] + +classifier.add_examples(texts, labels) + +# Advanced prediction options +predictions = classifier.predict_multilabel( + "New research on AI applications in environmental science", + threshold=0.3, # Custom threshold + max_labels=5 # Limit results +) + +# Get detailed statistics +stats = classifier.get_label_statistics() +print(f"Adaptive threshold: {stats['adaptive_threshold']}") +print(f"Label-specific thresholds: {stats['label_thresholds']}") +``` + ### Strategic Classification (Anti-Gaming) ```python @@ -224,6 +295,69 @@ print(f"Strategic: {strategic_preds}") print(f"Robust: {robust_preds}") ``` +## 🏷️ Multi-Label Classification + +The `MultiLabelAdaptiveClassifier` extends adaptive classification to handle scenarios where each text can belong to multiple categories simultaneously. It automatically handles threshold adaptation for scenarios with many labels. + +### Key Features + +- **🎯 Automatic Threshold Adaptation**: Dynamically adjusts thresholds based on the number of labels to prevent empty predictions +- **📊 Sigmoid Activation**: Uses proper multi-label architecture with BCE loss instead of softmax +- **⚙️ Configurable Limits**: Set minimum and maximum number of predictions per input +- **📈 Label-Specific Thresholds**: Automatically adjusts thresholds based on label frequency +- **🔄 Incremental Learning**: Add new labels and examples without retraining from scratch + +### Usage + +```python +from adaptive_classifier import MultiLabelAdaptiveClassifier + +# Initialize with configuration +classifier = MultiLabelAdaptiveClassifier( + "distilbert/distilbert-base-cased", + default_threshold=0.5, + min_predictions=1, + max_predictions=5 +) + +# Multi-label training data +texts = [ + "Breaking: Scientists discover AI can help predict climate change patterns", + "Tech giant announces breakthrough in quantum computing for healthcare", + "Olympic committee adopts new sports technology for athlete performance" +] + +labels = [ + ["science", "technology", "climate", "news"], + ["technology", "healthcare", "quantum", "business"], + ["sports", "technology", "performance", "news"] +] + +# Train the classifier +classifier.add_examples(texts, labels) + +# Make predictions +predictions = classifier.predict_multilabel( + "Revolutionary medical AI system launched by tech startup" +) + +# Results: [('technology', 0.85), ('healthcare', 0.72), ('business', 0.45)] +``` + +### Adaptive Thresholds + +The classifier automatically adjusts prediction thresholds based on the number of labels: + +| Number of Labels | Threshold | Benefit | +|-----------------|-----------|---------| +| 2-4 labels | 0.5 (default) | Standard precision | +| 5-9 labels | 0.4 (20% lower) | Balanced recall | +| 10-19 labels | 0.3 (40% lower) | Better coverage | +| 20-29 labels | 0.2 (60% lower) | Prevents empty results | +| 30+ labels | 0.1 (80% lower) | Ensures predictions | + +This solves the common "No labels met the threshold criteria" issue when dealing with many-label scenarios. + --- ## 🏢 Enterprise Use Cases diff --git a/examples/multilabel_usage.py b/examples/multilabel_usage.py new file mode 100644 index 0000000..b680ce3 --- /dev/null +++ b/examples/multilabel_usage.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 +""" +Multi-Label Adaptive Classifier Example + +This example demonstrates how to use the MultiLabelAdaptiveClassifier +for text classification tasks where each text can belong to multiple categories. + +Key features demonstrated: +1. Training with multi-label data +2. Making multi-label predictions +3. Adaptive threshold handling for many labels +4. Label-specific threshold customization +5. Saving and loading multi-label models +""" + +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from adaptive_classifier import MultiLabelAdaptiveClassifier +import torch + + +def create_sample_data(): + """Create sample multi-label training data.""" + + # Sample texts with multiple labels each + training_data = [ + { + "text": "Scientists discover new species of butterfly in Amazon rainforest with unique wing patterns", + "labels": ["science", "nature", "discovery", "biology"] + }, + { + "text": "Tech startup raises $50M in Series A funding to develop AI-powered healthcare solutions", + "labels": ["technology", "business", "healthcare", "funding"] + }, + { + "text": "Climate change impacts ocean temperature causing coral bleaching in Great Barrier Reef", + "labels": ["environment", "climate", "nature", "science"] + }, + { + "text": "NBA playoffs feature exciting games with record-breaking performances by star players", + "labels": ["sports", "entertainment", "basketball"] + }, + { + "text": "New renewable energy technology could reduce costs by 40% according to MIT research", + "labels": ["technology", "science", "environment", "energy"] + }, + { + "text": "Archaeological team uncovers 2000-year-old Roman artifacts in excavation site", + "labels": ["history", "science", "discovery", "archaeology"] + }, + { + "text": "Stock market reaches new highs as investors show confidence in economic recovery", + "labels": ["business", "finance", "economy"] + }, + { + "text": "Machine learning breakthrough helps doctors diagnose rare diseases more accurately", + "labels": ["technology", "healthcare", "science", "ai"] + }, + { + "text": "Wildlife conservation efforts show success in protecting endangered tiger populations", + "labels": ["nature", "environment", "conservation", "wildlife"] + }, + { + "text": "Olympic athletes prepare for upcoming games with intensive training programs", + "labels": ["sports", "olympics", "training", "fitness"] + }, + { + "text": "Quantum computing research makes progress toward solving complex optimization problems", + "labels": ["technology", "science", "computing", "research"] + }, + { + "text": "Sustainable agriculture practices help farmers reduce environmental impact while increasing yield", + "labels": ["environment", "agriculture", "sustainability", "farming"] + }, + { + "text": "Music festival features artists from diverse genres attracting thousands of fans", + "labels": ["entertainment", "music", "culture", "events"] + }, + { + "text": "Space agency announces plans for Mars mission with new rocket technology", + "labels": ["science", "space", "technology", "exploration"] + }, + { + "text": "Educational technology helps students learn programming through interactive online courses", + "labels": ["education", "technology", "programming", "learning"] + } + ] + + # Extract texts and labels + texts = [item["text"] for item in training_data] + labels = [item["labels"] for item in training_data] + + return texts, labels + + +def demonstrate_basic_usage(): + """Demonstrate basic multi-label classification.""" + + print("=" * 60) + print("MULTI-LABEL ADAPTIVE CLASSIFIER - BASIC USAGE") + print("=" * 60) + + # Create classifier + classifier = MultiLabelAdaptiveClassifier( + model_name="distilbert/distilbert-base-cased", + default_threshold=0.5, + min_predictions=1, # Ensure at least 1 prediction + max_predictions=5 # Limit to top 5 predictions + ) + + # Load training data + texts, labels = create_sample_data() + + print(f"Training with {len(texts)} examples") + print(f"Example text: {texts[0][:60]}...") + print(f"Example labels: {labels[0]}") + + # Train the classifier + classifier.add_examples(texts, labels) + + # Get statistics + stats = classifier.get_label_statistics() + print(f"\nTraining completed:") + print(f"- Total labels: {stats['num_classes']}") + print(f"- Total examples: {stats['total_examples']}") + print(f"- Adaptive threshold: {stats['adaptive_threshold']:.3f}") + + return classifier + + +def demonstrate_predictions(classifier): + """Demonstrate making predictions.""" + + print("\n" + "=" * 60) + print("MAKING PREDICTIONS") + print("=" * 60) + + # Test texts + test_texts = [ + "Researchers develop new AI algorithm for medical diagnosis", + "Football team wins championship in exciting final match", + "Solar panel efficiency increases with new manufacturing technique", + "Ancient civilization discovered through satellite imagery analysis" + ] + + for text in test_texts: + print(f"\nText: {text}") + + # Make multi-label prediction + predictions = classifier.predict_multilabel(text) + + print("Predictions:") + if predictions: + for label, confidence in predictions: + print(f" {label}: {confidence:.4f}") + else: + print(" No predictions above threshold") + + return test_texts + + +def demonstrate_threshold_adjustment(classifier): + """Demonstrate threshold adjustment for different scenarios.""" + + print("\n" + "=" * 60) + print("THRESHOLD ADJUSTMENT") + print("=" * 60) + + test_text = "AI researchers publish breakthrough study on climate modeling using machine learning" + + print(f"Test text: {test_text}") + + # Try different thresholds + thresholds = [0.1, 0.3, 0.5, 0.7, 0.9] + + print(f"\n{'Threshold':<10} {'Predictions':<12} {'Labels'}") + print("-" * 50) + + for threshold in thresholds: + predictions = classifier.predict_multilabel(test_text, threshold=threshold) + labels_str = ", ".join([label for label, _ in predictions[:3]]) + + print(f"{threshold:<10.1f} {len(predictions):<12} {labels_str}") + + +def demonstrate_saving_loading(classifier): + """Demonstrate saving and loading the model.""" + + print("\n" + "=" * 60) + print("SAVING AND LOADING") + print("=" * 60) + + # Save the model + save_path = "./multilabel_classifier" + print(f"Saving classifier to {save_path}") + classifier.save(save_path) + + # Load the model + print("Loading classifier...") + loaded_classifier = MultiLabelAdaptiveClassifier.load(save_path) + + # Verify it works + test_text = "New medical technology helps treat cancer patients" + + print(f"\nTesting loaded classifier:") + print(f"Text: {test_text}") + + predictions = loaded_classifier.predict_multilabel(test_text) + print("Predictions:") + for label, confidence in predictions: + print(f" {label}: {confidence:.4f}") + + return loaded_classifier + + +def demonstrate_incremental_learning(classifier): + """Demonstrate adding new labels incrementally.""" + + print("\n" + "=" * 60) + print("INCREMENTAL LEARNING - ADDING NEW LABELS") + print("=" * 60) + + # Add new examples with new labels + new_texts = [ + "Chef creates innovative fusion cuisine combining Asian and European flavors", + "Food delivery service expands to new cities with sustainable packaging", + "Restaurant industry adapts to new dining trends post-pandemic", + "Cooking show features celebrity chefs competing in culinary challenges" + ] + + new_labels = [ + ["food", "cuisine", "cooking", "culture"], + ["business", "food", "sustainability"], + ["business", "food", "trends"], + ["entertainment", "food", "cooking", "tv"] + ] + + print("Adding new examples with 'food' and 'cooking' labels...") + classifier.add_examples(new_texts, new_labels) + + # Test with food-related text + food_text = "Nutritionist recommends healthy meal planning for busy professionals" + + print(f"\nTesting with food-related text:") + print(f"Text: {food_text}") + + predictions = classifier.predict_multilabel(food_text) + print("Predictions:") + for label, confidence in predictions: + print(f" {label}: {confidence:.4f}") + + # Show updated statistics + stats = classifier.get_label_statistics() + print(f"\nUpdated statistics:") + print(f"- Total labels: {stats['num_classes']}") + print(f"- Total examples: {stats['total_examples']}") + + +def main(): + """Main function to run all demonstrations.""" + + print("Multi-Label Adaptive Classifier Example") + print("Fixing the 'No labels met the threshold criteria' issue\n") + + try: + # Basic usage + classifier = demonstrate_basic_usage() + + # Making predictions + demonstrate_predictions(classifier) + + # Threshold adjustment + demonstrate_threshold_adjustment(classifier) + + # Saving and loading + loaded_classifier = demonstrate_saving_loading(classifier) + + # Incremental learning + demonstrate_incremental_learning(loaded_classifier) + + print("\n" + "=" * 60) + print("EXAMPLE COMPLETED SUCCESSFULLY") + print("=" * 60) + + # Final statistics + final_stats = loaded_classifier.get_label_statistics() + print(f"\nFinal Model Statistics:") + print(f"- Labels: {final_stats['num_classes']}") + print(f"- Examples: {final_stats['total_examples']}") + print(f"- Default threshold: {final_stats['default_threshold']}") + print(f"- Adaptive threshold: {final_stats['adaptive_threshold']:.3f}") + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/setup.py b/setup.py index 137719a..2613747 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="adaptive-classifier", - version="0.0.17", + version="0.0.18", author="codelion", author_email="codelion@okyasoft.com", description="A flexible, adaptive classification system for dynamic text classification", diff --git a/src/adaptive_classifier/__init__.py b/src/adaptive_classifier/__init__.py index 0f6c1aa..c6231dc 100644 --- a/src/adaptive_classifier/__init__.py +++ b/src/adaptive_classifier/__init__.py @@ -1,12 +1,15 @@ from .classifier import AdaptiveClassifier from .models import Example, AdaptiveHead, ModelConfig from .memory import PrototypeMemory +from .multilabel import MultiLabelAdaptiveClassifier, MultiLabelAdaptiveHead from huggingface_hub import ModelHubMixin -__version__ = "0.0.17" +__version__ = "0.0.18" __all__ = [ "AdaptiveClassifier", + "MultiLabelAdaptiveClassifier", + "MultiLabelAdaptiveHead", "Example", "AdaptiveHead", "ModelConfig", diff --git a/src/adaptive_classifier/multilabel.py b/src/adaptive_classifier/multilabel.py new file mode 100644 index 0000000..d740879 --- /dev/null +++ b/src/adaptive_classifier/multilabel.py @@ -0,0 +1,426 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from typing import List, Dict, Optional, Tuple, Any, Set, Union +import logging +from collections import defaultdict + +from .classifier import AdaptiveClassifier +from .models import AdaptiveHead + +logger = logging.getLogger(__name__) + + +class MultiLabelAdaptiveHead(nn.Module): + """Multi-label version of adaptive head using sigmoid activation.""" + + def __init__(self, input_dim: int, num_classes: int, hidden_dims: List[int] = None): + super().__init__() + + if hidden_dims is None: + hidden_dims = [input_dim // 2] + + layers = [] + prev_dim = input_dim + + for dim in hidden_dims: + layers.extend([ + nn.Linear(prev_dim, dim), + nn.ReLU(), + nn.Dropout(0.1) + ]) + prev_dim = dim + + # Final layer with sigmoid for multi-label + layers.append(nn.Linear(prev_dim, num_classes)) + + self.model = nn.Sequential(*layers) + self.num_classes = num_classes + + def forward(self, x): + logits = self.model(x) + # Apply sigmoid for multi-label prediction + return torch.sigmoid(logits) + + def update_num_classes(self, new_num_classes: int): + """Update the number of output classes while preserving existing weights.""" + if new_num_classes <= self.num_classes: + return + + # Get the final layer + final_layer = self.model[-1] + + # Create new final layer + new_final_layer = nn.Linear(final_layer.in_features, new_num_classes) + + # Copy existing weights + with torch.no_grad(): + new_final_layer.weight[:self.num_classes] = final_layer.weight + new_final_layer.bias[:self.num_classes] = final_layer.bias + + # Initialize new class weights with small random values + nn.init.xavier_uniform_(new_final_layer.weight[self.num_classes:]) + nn.init.zeros_(new_final_layer.bias[self.num_classes:]) + + # Replace the final layer + self.model[-1] = new_final_layer + self.num_classes = new_num_classes + + +class MultiLabelAdaptiveClassifier(AdaptiveClassifier): + """ + Multi-label extension of AdaptiveClassifier that can predict multiple labels per input. + + Handles the "No labels met the threshold criteria" issue by implementing: + 1. Adaptive thresholds based on number of labels + 2. Minimum predictions per sample + 3. Label-specific threshold adjustments + """ + + def __init__( + self, + model_name: str, + device: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + seed: int = 42, + default_threshold: float = 0.5, + min_predictions: int = 1, + max_predictions: Optional[int] = None + ): + super().__init__(model_name, device, config, seed) + + # Multi-label specific configuration + self.default_threshold = default_threshold + self.min_predictions = min_predictions + self.max_predictions = max_predictions + self.label_thresholds = {} # Per-label thresholds + + # Override adaptive head with multi-label version + self.adaptive_head = None + + def _initialize_adaptive_head(self): + """Initialize multi-label adaptive head.""" + num_classes = len(self.label_to_id) + hidden_dims = [self.embedding_dim, self.embedding_dim // 2] + + self.adaptive_head = MultiLabelAdaptiveHead( + self.embedding_dim, + num_classes, + hidden_dims=hidden_dims + ).to(self.device) + + def _get_adaptive_threshold(self, num_labels: int) -> float: + """ + Calculate adaptive threshold based on number of labels. + + With more labels, individual prediction scores tend to be lower, + so we need a lower threshold to avoid "No labels met the threshold criteria". + """ + if num_labels <= 2: + return self.default_threshold + elif num_labels <= 5: + return self.default_threshold * 0.8 + elif num_labels <= 10: + return self.default_threshold * 0.6 + elif num_labels <= 20: + return self.default_threshold * 0.4 + else: + # For many labels (20+), use very low threshold + return self.default_threshold * 0.2 + + def predict_multilabel( + self, + text: str, + threshold: Optional[float] = None, + max_labels: Optional[int] = None + ) -> List[Tuple[str, float]]: + """ + Predict multiple labels for input text. + + Args: + text: Input text to classify + threshold: Confidence threshold for predictions (adaptive if None) + max_labels: Maximum number of labels to return + + Returns: + List of (label, confidence) tuples for labels above threshold + """ + if not text: + raise ValueError("Empty input text") + + num_labels = len(self.label_to_id) + if num_labels == 0: + return [] + + # Use adaptive threshold if not specified + if threshold is None: + threshold = self._get_adaptive_threshold(num_labels) + + max_labels = max_labels or self.max_predictions + + with torch.no_grad(): + # Get embedding + embedding = self._get_embeddings([text])[0] + + # Get predictions from neural head + if self.adaptive_head is not None: + self.adaptive_head.eval() + input_embedding = embedding.unsqueeze(0).to(self.device) + probabilities = self.adaptive_head(input_embedding).squeeze(0) + + # Convert to label predictions + predictions = [] + for i, prob in enumerate(probabilities): + if i < len(self.id_to_label): + label = self.id_to_label[i] + # Use label-specific threshold if available + label_threshold = self.label_thresholds.get(label, threshold) + if prob.item() >= label_threshold: + predictions.append((label, prob.item())) + + # Sort by confidence + predictions.sort(key=lambda x: x[1], reverse=True) + + # Apply max_labels limit + if max_labels and len(predictions) > max_labels: + predictions = predictions[:max_labels] + + else: + # Fallback to prototype-based prediction + proto_predictions = self.memory.get_nearest_prototypes( + embedding, + k=min(num_labels, max_labels) if max_labels else num_labels + ) + + # Filter by threshold + predictions = [ + (label, score) for label, score in proto_predictions + if score >= threshold + ] + + # Ensure minimum predictions if required + if len(predictions) < self.min_predictions and self.adaptive_head is not None: + # Add top predictions even if below threshold + with torch.no_grad(): + input_embedding = embedding.unsqueeze(0).to(self.device) + probabilities = self.adaptive_head(input_embedding).squeeze(0) + + # Get top predictions + values, indices = torch.topk( + probabilities, + min(self.min_predictions, len(self.id_to_label)) + ) + + additional_predictions = [] + for val, idx in zip(values, indices): + if idx.item() < len(self.id_to_label): + label = self.id_to_label[idx.item()] + score = val.item() + + # Only add if not already included + if not any(pred[0] == label for pred in predictions): + additional_predictions.append((label, score)) + + # Add additional predictions to meet minimum + predictions.extend(additional_predictions[:self.min_predictions - len(predictions)]) + predictions.sort(key=lambda x: x[1], reverse=True) + + return predictions + + def predict(self, text: str, k: int = 5) -> List[Tuple[str, float]]: + """ + Override base predict to use multi-label prediction. + Falls back to single-label prediction if needed. + """ + # Use multi-label prediction but limit to k results + multilabel_preds = self.predict_multilabel(text, max_labels=k) + + if multilabel_preds: + return multilabel_preds[:k] + else: + # Fallback to base prediction if no multi-label predictions + return super().predict(text, k) + + def add_examples(self, texts: List[str], labels: List[List[str]]): + """ + Add multi-label training examples. + + Args: + texts: List of input texts + labels: List of label lists (each text can have multiple labels) + """ + if not texts or not labels: + raise ValueError("Empty input lists") + if len(texts) != len(labels): + raise ValueError("Mismatched text and label lists") + + # Flatten labels for single-label training approach + # We'll train one example per text-label pair + flattened_texts = [] + flattened_labels = [] + + for text, text_labels in zip(texts, labels): + if not text_labels: # Skip texts with no labels + continue + + # For multi-label, we create multiple training examples + # Each example represents the text with one of its labels + for label in text_labels: + flattened_texts.append(text) + flattened_labels.append(label) + + if flattened_texts: + # Use parent class method with flattened examples + super().add_examples(flattened_texts, flattened_labels) + + # Update label-specific thresholds based on training data + self._update_label_thresholds() + + def _update_label_thresholds(self): + """Update per-label thresholds based on training data distribution.""" + if not self.memory.examples: + return + + # Calculate label frequencies + label_counts = defaultdict(int) + total_examples = 0 + + for label, examples in self.memory.examples.items(): + label_counts[label] = len(examples) + total_examples += len(examples) + + # Adjust thresholds based on label frequency + # Rare labels get lower thresholds, common labels get higher thresholds + for label, count in label_counts.items(): + frequency = count / total_examples + + if frequency < 0.05: # Very rare labels (< 5%) + self.label_thresholds[label] = self.default_threshold * 0.3 + elif frequency < 0.1: # Rare labels (< 10%) + self.label_thresholds[label] = self.default_threshold * 0.5 + elif frequency > 0.3: # Very common labels (> 30%) + self.label_thresholds[label] = self.default_threshold * 1.2 + else: # Normal frequency labels + self.label_thresholds[label] = self.default_threshold + + logger.debug(f"Updated label thresholds: {self.label_thresholds}") + + def _train_adaptive_head(self, epochs: int = 10): + """Train multi-label adaptive head with BCE loss.""" + if not self.memory.examples: + return + + # Prepare multi-label training data + all_embeddings = [] + all_labels = [] + + # Create label matrix for multi-label training + num_classes = len(self.label_to_id) + + # Collect unique texts and their labels + text_to_labels = defaultdict(set) + for label, examples in self.memory.examples.items(): + for example in examples: + text_to_labels[example.text].add(label) + + # Create training data with proper multi-label targets + for text, labels in text_to_labels.items(): + # Get embedding for this text (take first occurrence) + embedding = None + for label in labels: + for example in self.memory.examples[label]: + if example.text == text: + embedding = example.embedding + break + if embedding is not None: + break + + if embedding is not None: + all_embeddings.append(embedding) + + # Create multi-hot encoded label vector + label_vector = torch.zeros(num_classes) + for label in labels: + if label in self.label_to_id: + label_vector[self.label_to_id[label]] = 1.0 + + all_labels.append(label_vector) + + if not all_embeddings: + return + + all_embeddings = torch.stack(all_embeddings) + all_labels = torch.stack(all_labels) + + # Normalize embeddings + all_embeddings = F.normalize(all_embeddings, p=2, dim=1) + + # Create data loader + dataset = torch.utils.data.TensorDataset(all_embeddings, all_labels) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=min(32, len(all_embeddings)), + shuffle=True, + generator=torch.Generator().manual_seed(42) + ) + + # Training setup + self.adaptive_head.train() + criterion = nn.BCELoss() # Binary Cross Entropy for multi-label + optimizer = torch.optim.AdamW( + self.adaptive_head.parameters(), + lr=0.001, + weight_decay=0.01 + ) + + best_loss = float('inf') + patience_counter = 0 + patience = 3 + + for epoch in range(epochs): + total_loss = 0 + for batch_embeddings, batch_labels in loader: + batch_embeddings = batch_embeddings.to(self.device) + batch_labels = batch_labels.to(self.device) + + optimizer.zero_grad() + outputs = self.adaptive_head(batch_embeddings) + + loss = criterion(outputs, batch_labels) + loss.backward() + + torch.nn.utils.clip_grad_norm_( + self.adaptive_head.parameters(), + max_norm=1.0 + ) + optimizer.step() + + total_loss += loss.item() + + avg_loss = total_loss / len(loader) + + # Early stopping + if avg_loss < best_loss: + best_loss = avg_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + logger.debug(f"Early stopping at epoch {epoch + 1}") + break + + self.train_steps += 1 + + def get_label_statistics(self) -> Dict[str, Any]: + """Get statistics about label distribution and thresholds.""" + stats = super().get_example_statistics() + + # Add multi-label specific stats + stats['label_thresholds'] = dict(self.label_thresholds) + stats['adaptive_threshold'] = self._get_adaptive_threshold(len(self.label_to_id)) + stats['default_threshold'] = self.default_threshold + stats['min_predictions'] = self.min_predictions + stats['max_predictions'] = self.max_predictions + + return stats \ No newline at end of file diff --git a/tests/test_multilabel.py b/tests/test_multilabel.py new file mode 100644 index 0000000..ca637c7 --- /dev/null +++ b/tests/test_multilabel.py @@ -0,0 +1,251 @@ +import pytest +import torch +import tempfile +import os +from pathlib import Path +from adaptive_classifier import MultiLabelAdaptiveClassifier, MultiLabelAdaptiveHead + + +@pytest.fixture +def sample_multilabel_data(): + """Sample multi-label training data.""" + texts = [ + "Scientists study climate change effects on polar ice caps", + "Tech company develops AI for medical diagnosis", + "Athletes train for upcoming Olympic games", + "Researchers discover new species in Amazon rainforest", + "Startup raises funding for sustainable energy project" + ] + + labels = [ + ["science", "climate", "environment"], + ["technology", "healthcare", "ai"], + ["sports", "fitness", "olympics"], + ["science", "nature", "discovery"], + ["business", "technology", "environment"] + ] + + return texts, labels + + +@pytest.fixture +def multilabel_classifier(): + """Create a MultiLabelAdaptiveClassifier instance.""" + return MultiLabelAdaptiveClassifier( + "distilbert/distilbert-base-cased", + default_threshold=0.5, + min_predictions=1, + max_predictions=5 + ) + + +def test_multilabel_classifier_initialization(multilabel_classifier): + """Test MultiLabelAdaptiveClassifier initialization.""" + assert multilabel_classifier.default_threshold == 0.5 + assert multilabel_classifier.min_predictions == 1 + assert multilabel_classifier.max_predictions == 5 + assert multilabel_classifier.adaptive_head is None + + +def test_multilabel_head_initialization(): + """Test MultiLabelAdaptiveHead initialization.""" + head = MultiLabelAdaptiveHead(768, 5) + assert head.num_classes == 5 + assert isinstance(head.model, torch.nn.Sequential) + + # Test forward pass + input_tensor = torch.randn(1, 768) + output = head(input_tensor) + assert output.shape == (1, 5) + assert torch.all(output >= 0) and torch.all(output <= 1) # Sigmoid output + + +def test_multilabel_head_update_classes(): + """Test updating number of classes in MultiLabelAdaptiveHead.""" + head = MultiLabelAdaptiveHead(768, 3) + original_weight = head.model[-1].weight.data.clone() + original_bias = head.model[-1].bias.data.clone() + + # Update to more classes + head.update_num_classes(5) + assert head.num_classes == 5 + + # Check that original weights are preserved + assert torch.equal(head.model[-1].weight[:3], original_weight) + assert torch.equal(head.model[-1].bias[:3], original_bias) + + +def test_adaptive_threshold_calculation(multilabel_classifier): + """Test adaptive threshold calculation for different numbers of labels.""" + # Test threshold scaling with number of labels + assert multilabel_classifier._get_adaptive_threshold(2) == 0.5 + assert multilabel_classifier._get_adaptive_threshold(5) == 0.4 + assert multilabel_classifier._get_adaptive_threshold(10) == 0.3 + assert multilabel_classifier._get_adaptive_threshold(20) == 0.2 + assert multilabel_classifier._get_adaptive_threshold(30) == 0.1 + + +def test_multilabel_training(multilabel_classifier, sample_multilabel_data): + """Test training with multi-label data.""" + texts, labels = sample_multilabel_data + + # Train classifier + multilabel_classifier.add_examples(texts, labels) + + # Check that labels were added correctly + expected_labels = set() + for label_list in labels: + expected_labels.update(label_list) + + assert len(multilabel_classifier.label_to_id) == len(expected_labels) + assert set(multilabel_classifier.label_to_id.keys()) == expected_labels + + # Check that adaptive head was initialized + assert multilabel_classifier.adaptive_head is not None + assert isinstance(multilabel_classifier.adaptive_head, MultiLabelAdaptiveHead) + + +def test_multilabel_prediction(multilabel_classifier, sample_multilabel_data): + """Test multi-label prediction.""" + texts, labels = sample_multilabel_data + + # Train classifier + multilabel_classifier.add_examples(texts, labels) + + # Make prediction + test_text = "AI researchers study climate change using machine learning" + predictions = multilabel_classifier.predict_multilabel(test_text) + + # Check predictions format + assert isinstance(predictions, list) + for label, confidence in predictions: + assert isinstance(label, str) + assert isinstance(confidence, float) + assert 0 <= confidence <= 1 + + # Check that we get at least min_predictions + assert len(predictions) >= multilabel_classifier.min_predictions + + +def test_threshold_filtering(multilabel_classifier, sample_multilabel_data): + """Test that threshold filtering works correctly.""" + texts, labels = sample_multilabel_data + multilabel_classifier.add_examples(texts, labels) + + test_text = "Scientific research on environmental issues" + + # Test with different thresholds + high_threshold_preds = multilabel_classifier.predict_multilabel(test_text, threshold=0.9) + low_threshold_preds = multilabel_classifier.predict_multilabel(test_text, threshold=0.1) + + # Lower threshold should give more predictions (or at least not fewer) + assert len(low_threshold_preds) >= len(high_threshold_preds) + + # With min_predictions=1, we should always get at least 1 prediction + assert len(high_threshold_preds) >= 1 + assert len(low_threshold_preds) >= 1 + + +def test_many_labels_scenario(multilabel_classifier): + """Test the specific scenario that caused 'No labels met the threshold criteria'.""" + # Create many labels + num_labels = 25 + texts = [] + labels = [] + + for i in range(num_labels): + for j in range(3): # 3 examples per label + texts.append(f"This is example {j} about topic {i}") + labels.append([f"label_{i:02d}"]) + + # Train with many labels + multilabel_classifier.add_examples(texts, labels) + + # Test prediction + test_text = "This is a general text about various topics" + predictions = multilabel_classifier.predict_multilabel(test_text) + + # Should not return empty result + assert len(predictions) > 0 + assert not isinstance(predictions, str) # Should not be error message + + # Should respect adaptive threshold + adaptive_threshold = multilabel_classifier._get_adaptive_threshold(num_labels) + assert adaptive_threshold < 0.5 # Should be lower for many labels + + +def test_label_specific_thresholds(multilabel_classifier, sample_multilabel_data): + """Test label-specific threshold updates.""" + texts, labels = sample_multilabel_data + + # Add more examples for some labels to make them common + additional_texts = ["More science content"] * 10 + additional_labels = [["science"]] * 10 + + all_texts = texts + additional_texts + all_labels = labels + additional_labels + + multilabel_classifier.add_examples(all_texts, all_labels) + + # Check that thresholds were updated + assert len(multilabel_classifier.label_thresholds) > 0 + + # Science should have a different threshold due to higher frequency + if "science" in multilabel_classifier.label_thresholds: + science_threshold = multilabel_classifier.label_thresholds["science"] + # Should be adjusted based on frequency + assert isinstance(science_threshold, float) + + +def test_save_load_multilabel(multilabel_classifier, sample_multilabel_data): + """Test saving and loading multi-label classifier.""" + texts, labels = sample_multilabel_data + multilabel_classifier.add_examples(texts, labels) + + test_text = "Test prediction text" + original_predictions = multilabel_classifier.predict_multilabel(test_text, max_labels=5) + + with tempfile.TemporaryDirectory() as temp_dir: + save_path = Path(temp_dir) / "multilabel_classifier" + + # Save classifier + multilabel_classifier.save(str(save_path)) + + # Load classifier with same configuration + loaded_classifier = MultiLabelAdaptiveClassifier.load( + str(save_path), + device=multilabel_classifier.device + ) + # Set same max_predictions for fair comparison + loaded_classifier.max_predictions = multilabel_classifier.max_predictions + + # Test that predictions are similar + loaded_predictions = loaded_classifier.predict_multilabel(test_text, max_labels=5) + + assert len(loaded_predictions) <= 5 # Should respect max_labels + assert len(original_predictions) <= 5 + # Check that we get predictions from both + assert len(loaded_predictions) > 0 + assert len(original_predictions) > 0 + + +def test_statistics_reporting(multilabel_classifier, sample_multilabel_data): + """Test that statistics are reported correctly.""" + texts, labels = sample_multilabel_data + multilabel_classifier.add_examples(texts, labels) + + stats = multilabel_classifier.get_label_statistics() + + # Check required fields + assert 'label_thresholds' in stats + assert 'adaptive_threshold' in stats + assert 'default_threshold' in stats + assert 'min_predictions' in stats + assert 'max_predictions' in stats + assert 'num_classes' in stats + assert 'total_examples' in stats + + # Check values + assert stats['default_threshold'] == multilabel_classifier.default_threshold + assert stats['min_predictions'] == multilabel_classifier.min_predictions + assert stats['max_predictions'] == multilabel_classifier.max_predictions \ No newline at end of file