2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@

setup(
    name="adaptive-classifier",
    version="0.0.16",
    version="0.0.17",
    author="codelion",
    author_email="[email protected]",
    description="A flexible, adaptive classification system for dynamic text classification",
17 changes: 1 addition & 16 deletions src/adaptive_classifier/__init__.py
@@ -3,22 +3,7 @@
from .memory import PrototypeMemory
from huggingface_hub import ModelHubMixin

import os
import re

def get_version_from_setup():
    try:
        setup_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'setup.py')
        with open(setup_path, 'r') as f:
            content = f.read()
        version_match = re.search(r'version=["\']([^"\']+)["\']', content)
        if version_match:
            return version_match.group(1)
    except Exception:
        pass
    return "unknown"

__version__ = get_version_from_setup()
__version__ = "0.0.17"

__all__ = [
    "AdaptiveClassifier",
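Context for this change: the deleted helper tried to parse setup.py at runtime, but setup.py is not shipped inside installed wheels, so the lookup silently fell back to "unknown" for normal installs; pinning the string fixes that at the cost of one more place to bump on each release. A third option, which this PR does not take, is to read the installed distribution's metadata. A minimal sketch, assuming the distribution name adaptive-classifier:

    from importlib.metadata import PackageNotFoundError, version

    try:
        __version__ = version("adaptive-classifier")
    except PackageNotFoundError:
        # Not installed (e.g. running from a raw source checkout)
        __version__ = "unknown"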
182 changes: 125 additions & 57 deletions src/adaptive_classifier/classifier.py
@@ -83,48 +83,67 @@ def add_examples(self, texts: List[str], labels: List[str]):
raise ValueError("Empty input lists")
if len(texts) != len(labels):
raise ValueError("Mismatched text and label lists")


# Check if classifier has any existing classes (before updating mappings)
has_existing_classes = len(self.label_to_id) > 0

# Check for new classes
new_classes = set(labels) - set(self.label_to_id.keys())
is_adding_new_classes = len(new_classes) > 0

# Update label mappings - sort new classes alphabetically for consistent IDs
for label in sorted(new_classes):
idx = len(self.label_to_id)
self.label_to_id[label] = idx
self.id_to_label[idx] = label

# Get embeddings for all texts
embeddings = self._get_embeddings(texts)

# Add examples to memory and update training history
for text, embedding, label in zip(texts, embeddings, labels):
example = Example(text, label, embedding)
self.memory.add_example(example, label)

# Update training history
if label not in self.training_history:
self.training_history[label] = 0
self.training_history[label] += 1

# Special handling for new classes
if is_adding_new_classes:
# Store old head for EWC

# Determine training strategy: only use special new class handling for incremental learning
is_incremental_learning = is_adding_new_classes and has_existing_classes

if is_incremental_learning:
# Adding new classes to existing classifier - use special handling
# Store old head for EWC before modifying structure
old_head = copy.deepcopy(self.adaptive_head) if self.adaptive_head is not None else None

# Initialize new head with more output classes
self._initialize_adaptive_head()


# Expand existing head to accommodate new classes (preserves weights)
num_classes = len(self.label_to_id)
self.adaptive_head.update_num_classes(num_classes)
# Move to correct device after update
self.adaptive_head = self.adaptive_head.to(self.device)

# Train with focus on new classes
self._train_new_classes(old_head, new_classes)
else:
# Regular training for existing classes
# Initial training or regular updates - use normal training
# Initialize head if needed
if self.adaptive_head is None:
self._initialize_adaptive_head()
elif is_adding_new_classes:
# Edge case: expanding head for new classes but treating as regular training
num_classes = len(self.label_to_id)
self.adaptive_head.update_num_classes(num_classes)
self.adaptive_head = self.adaptive_head.to(self.device)

# Regular training
self._train_adaptive_head()

# Strategic training step if enabled
if self.strategic_mode and self.train_steps % self.config.strategic_training_frequency == 0:
self._perform_strategic_training()

# Ensure FAISS index is up to date after adding examples
self.memory._rebuild_index()

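Note on the control flow above: the special path now triggers only when both flags hold, i.e. new labels arrive and the classifier already has classes; a fresh classifier seeing its first labels takes the regular path, and the elif branch covers the remaining edge case. The incremental path also leans on AdaptiveHead.update_num_classes keeping learned weights while widening the output layer. That method is outside this diff, but a typical weight-preserving expansion looks like the following sketch (a hypothetical helper, not the repo's implementation):

    import torch
    import torch.nn as nn

    def expand_output_layer(old: nn.Linear, num_classes: int) -> nn.Linear:
        """Widen a classification layer, keeping the trained rows intact."""
        new = nn.Linear(old.in_features, num_classes)
        with torch.no_grad():
            new.weight[:old.out_features].copy_(old.weight)  # reuse old class weights
            new.bias[:old.out_features].copy_(old.bias)      # reuse old class biases
        return new  # rows for the new classes keep their fresh random init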
@@ -142,48 +161,94 @@ def _train_new_classes(self, old_head: Optional[nn.Module], new_classes: Set[str
        for label in self.memory.examples:
            examples_per_class[label] = len(self.memory.examples[label])

        # Calculate sampling weights to balance old and new classes
        # Improved sampling strategy for many-class scenarios
        min_examples = min(examples_per_class.values())
        sampling_weights = {}

        for label, count in examples_per_class.items():
            if label in new_classes:
                # Oversample new classes
                sampling_weights[label] = 2.0
            else:
                # Sample old classes proportionally
                sampling_weights[label] = min_examples / count

        # Sample examples with weights
        for label, examples in self.memory.examples.items():
            weight = sampling_weights[label]
            num_samples = max(min_examples, int(len(examples) * weight))

            # Randomly sample with replacement if needed
            indices = np.random.choice(
                len(examples),
                size=num_samples,
                replace=num_samples > len(examples)
            )

            for idx in indices:
                example = examples[idx]
                all_embeddings.append(example.embedding)
                all_labels.append(self.label_to_id[label])
        max_examples = max(examples_per_class.values())

        # For many-class scenarios, use a more balanced approach
        num_classes = len(examples_per_class)
        target_samples_per_class = max(5, min(10, min_examples * 2))  # Adaptive target

        if num_classes > 20:  # Many-class scenario
            # Use stratified sampling to ensure all classes get representation
            for label, examples in self.memory.examples.items():
                if label in new_classes:
                    # Give new classes more representation, but not excessive
                    num_samples = min(len(examples), target_samples_per_class * 2)
                else:
                    # Ensure old classes maintain representation
                    num_samples = min(len(examples), target_samples_per_class)

                # Sample without replacement first, then with if needed
                if num_samples <= len(examples):
                    indices = np.random.choice(len(examples), size=num_samples, replace=False)
                else:
                    indices = np.random.choice(len(examples), size=num_samples, replace=True)

                for idx in indices:
                    example = examples[idx]
                    all_embeddings.append(example.embedding)
                    all_labels.append(self.label_to_id[label])
        else:
            # Original strategy for fewer classes
            sampling_weights = {}

            for label, count in examples_per_class.items():
                if label in new_classes:
                    # Oversample new classes
                    sampling_weights[label] = 2.0
                else:
                    # Sample old classes proportionally
                    sampling_weights[label] = min_examples / count

            # Sample examples with weights
            for label, examples in self.memory.examples.items():
                weight = sampling_weights[label]
                num_samples = max(min_examples, int(len(examples) * weight))

                # Randomly sample with replacement if needed
                indices = np.random.choice(
                    len(examples),
                    size=num_samples,
                    replace=num_samples > len(examples)
                )

                for idx in indices:
                    example = examples[idx]
                    all_embeddings.append(example.embedding)
                    all_labels.append(self.label_to_id[label])

        all_embeddings = torch.stack(all_embeddings)
        all_labels = torch.tensor(all_labels)

        # Create dataset and initialize EWC with lower penalty for new classes
        dataset = torch.utils.data.TensorDataset(all_embeddings, all_labels)


        ewc = None
        if old_head is not None:
            ewc = EWC(
                old_head,
                dataset,
                device=self.device,
                ewc_lambda=10.0  # Lower EWC penalty to allow better learning of new classes
            )
            # Create a dataset for EWC that only includes examples from old classes
            old_embeddings = []
            old_labels = []
            old_label_to_id = {label: idx for idx, label in enumerate(self.id_to_label.values())
                               if label not in new_classes}

            for label, examples in self.memory.examples.items():
                if label not in new_classes:  # Only old classes
                    for example in examples[:5]:  # Limit to representative examples
                        old_embeddings.append(example.embedding)
                        old_labels.append(old_label_to_id[label])

            if old_embeddings:  # Only create EWC if we have old examples
                old_embeddings = torch.stack(old_embeddings)
                old_labels = torch.tensor(old_labels, dtype=torch.long)
                old_dataset = torch.utils.data.TensorDataset(old_embeddings, old_labels)

                ewc = EWC(
                    old_head,
                    old_dataset,
                    device=self.device,
                    ewc_lambda=5.0  # Balanced EWC penalty
                )

        # Training setup
        self.adaptive_head.train()
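For intuition on the adaptive target above: max(5, min(10, min_examples * 2)) clamps twice the rarest class count into the range [5, 10], so sparse classes are padded up and abundant ones are capped. (Incidentally, max_examples is computed but never read anywhere in this hunk.) A quick check of the boundary cases:

    def target_samples(min_examples: int) -> int:
        # Clamp 2x the rarest class count to the range [5, 10]
        return max(5, min(10, min_examples * 2))

    assert target_samples(1) == 5    # sparse data is padded up to 5
    assert target_samples(4) == 8    # mid-range scales with the data
    assert target_samples(50) == 10  # capped at 10 for abundant classes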
@@ -220,7 +285,7 @@ def _train_new_classes(self, old_head: Optional[nn.Module], new_classes: Set[str
                task_loss = criterion(outputs, batch_labels)

                # Add EWC loss if applicable
                if old_head is not None:
                if ewc is not None:
                    ewc_loss = ewc.ewc_loss(batch_size=len(batch_embeddings))
                    loss = task_loss + ewc_loss
                else:
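The guard change in this hunk matters beyond style: after the refactor above, ewc can remain None even when old_head exists (when no old-class examples were found), so the old check on old_head would have called ewc_loss on None. For readers unfamiliar with elastic weight consolidation, a condensed sketch of the penalty a standard diagonal-Fisher EWC computes (the repo's EWC class itself is not shown in this diff):

    import torch

    def ewc_penalty(model, fisher, anchor_params, ewc_lambda):
        """Quadratic penalty for drifting from the old task's weights."""
        loss = torch.zeros(())
        for name, param in model.named_parameters():
            if name in fisher:
                # Fisher information weights how much each parameter mattered before
                loss = loss + (fisher[name] * (param - anchor_params[name]) ** 2).sum()
        return 0.5 * ewc_lambda * loss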
@@ -302,10 +367,12 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
        # Get embedding
        embedding = self._get_embeddings([text])[0]

        # Get prototype predictions
        proto_preds = self.memory.get_nearest_prototypes(embedding, k=k)

        # Get neural predictions if available
        # Get prototype predictions for ALL classes (not limited by k)
        # This ensures complete scoring information for proper combination
        max_classes = len(self.id_to_label) if self.id_to_label else k
        proto_preds = self.memory.get_nearest_prototypes(embedding, k=max_classes)

        # Get neural predictions if available for ALL classes (not limited by k)
        if self.adaptive_head is not None:
            self.adaptive_head.eval()  # Ensure eval mode
            # Add batch dimension and move to device
@@ -314,8 +381,9 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
            # Squeeze batch dimension
            logits = logits.squeeze(0)
            probs = F.softmax(logits, dim=0)

            values, indices = torch.topk(probs, min(k, len(self.id_to_label)))

            # Get predictions for ALL classes for proper scoring combination
            values, indices = torch.topk(probs, len(self.id_to_label))
            head_preds = [
                (self.id_to_label[idx.item()], val.item())
                for val, idx in zip(values, indices)
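The merge step that consumes proto_preds and head_preds sits outside this hunk, but the rationale for scoring ALL classes is visible here: if each source reports only its own top-k, a label that one source ranks k+1 contributes zero to the blended score and can be unfairly dropped. A minimal sketch of such a weighted merge (the 0.7/0.3 split is illustrative, not taken from the repo):

    from collections import defaultdict
    from typing import List, Tuple

    def combine_predictions(
        proto_preds: List[Tuple[str, float]],
        head_preds: List[Tuple[str, float]],
        proto_weight: float = 0.7,
    ) -> List[Tuple[str, float]]:
        """Blend prototype and neural-head scores into one ranked list."""
        combined = defaultdict(float)
        for label, score in proto_preds:
            combined[label] += proto_weight * score
        for label, score in head_preds:
            combined[label] += (1.0 - proto_weight) * score
        return sorted(combined.items(), key=lambda item: item[1], reverse=True)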