Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ Adaptive Classifier is a PyTorch-based machine learning library that revolutioni

### 🎯 **Core Capabilities**
- **🚀 Universal Compatibility** - Works with any HuggingFace transformer model
- **⚡ Optimized Inference** - Built-in ONNX Runtime for 2-4x faster CPU predictions
- **📈 Continuous Learning** - Add new examples without catastrophic forgetting
- **🔄 Dynamic Classes** - Add new classes at runtime without retraining
- ** Zero Downtime** - Update models in production without service interruption
- **⏱️ Zero Downtime** - Update models in production without service interruption

### 🛡️ **Advanced Defense**
- **🎮 Strategic Classification** - Game-theoretic defense against adversarial manipulation
Expand Down Expand Up @@ -99,6 +100,8 @@ Tested on arena-hard-auto-v0.1 dataset (500 queries):
pip install adaptive-classifier
```

**Includes:** ONNX Runtime for 2-4x faster CPU inference out-of-the-box

### 🛠️ Development Setup
```bash
# Clone the repository
Expand Down Expand Up @@ -191,6 +194,74 @@ predictions = strategic_classifier.predict("This product has amazing quality fea
# Returns predictions that consider potential gaming attempts
```

### ⚡ Optimized CPU Inference with ONNX

Adaptive Classifier includes **built-in ONNX Runtime support** for **2-4x faster CPU inference** with zero code changes required.

#### Automatic Optimization (Default)

ONNX Runtime is automatically used on CPU for optimal performance:

```python
# Automatically uses ONNX on CPU, PyTorch on GPU
classifier = AdaptiveClassifier("bert-base-uncased")

# That's it! Predictions are 2-4x faster on CPU
predictions = classifier.predict("Fast inference!")
```

#### Performance Comparison

| Configuration | Speed | Use Case |
|--------------|-------|----------|
| PyTorch (GPU) | Fastest | GPU servers |
| **ONNX (CPU)** | **2-4x faster** | **Production CPU deployments** |
| PyTorch (CPU) | Baseline | Development, training |

#### Save & Deploy with ONNX

```python
# Save with ONNX export (both quantized & unquantized versions)
classifier.save("./model")

# Push to Hub with ONNX (both versions included by default)
classifier.push_to_hub("username/model")

# Load automatically uses quantized ONNX on CPU (fastest, 4x smaller)
fast_classifier = AdaptiveClassifier.load("./model")

# Choose unquantized ONNX for maximum accuracy
accurate_classifier = AdaptiveClassifier.load("./model", prefer_quantized=False)

# Force PyTorch (no ONNX)
pytorch_classifier = AdaptiveClassifier.load("./model", use_onnx=False)

# Opt-out of ONNX export when saving
classifier.save("./model", include_onnx=False)
```

**ONNX Model Versions:**
- **Quantized (default)**: INT8 quantized, 4x smaller, ~1.14x faster on ARM, 2-4x faster on x86
- **Unquantized**: Full precision, maximum accuracy, larger file size

By default, models are saved with both versions, and the quantized version is automatically loaded for best performance. Use `prefer_quantized=False` if you need maximum accuracy.

#### Benchmark Your Model

```bash
# Compare PyTorch vs ONNX performance
python scripts/benchmark_onnx.py --model bert-base-uncased --runs 100
```

**Example Results:**
```
Model: bert-base-uncased (CPU)
PyTorch: 8.3ms/query (baseline)
ONNX: 2.1ms/query (4.0x faster) ✓
```

> **Note:** ONNX optimization is included by default. For GPU inference, PyTorch is automatically used for best performance.

## Advanced Usage

### Adding New Classes Dynamically
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ tqdm>=4.65.0
setuptools>=65.0.0
wheel>=0.40.0
scikit-learn
huggingface_hub>=0.17.0
huggingface_hub>=0.17.0
optimum[onnxruntime]>=1.14.0
178 changes: 178 additions & 0 deletions scripts/benchmark_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""Benchmark script comparing PyTorch vs ONNX vs Quantized ONNX performance."""

import time
import argparse
import tempfile
from pathlib import Path
import numpy as np
from adaptive_classifier import AdaptiveClassifier


def check_optimum_installed():
"""Check if optimum is installed."""
try:
import optimum.onnxruntime
return True
except ImportError:
return False


def benchmark_inference(classifier, texts, num_runs=100):
"""Benchmark inference speed."""
# Warmup
for _ in range(5):
classifier.predict(texts[0])

# Benchmark
start_time = time.time()
for _ in range(num_runs):
for text in texts:
classifier.predict(text)

end_time = time.time()
total_time = end_time - start_time
avg_time_per_query = (total_time / (num_runs * len(texts))) * 1000 # ms

return avg_time_per_query, total_time


def main():
parser = argparse.ArgumentParser(description="Benchmark ONNX vs PyTorch performance")
parser.add_argument("--model", type=str, default="prajjwal1/bert-tiny",
help="HuggingFace model name to benchmark")
parser.add_argument("--runs", type=int, default=100,
help="Number of benchmark runs")
parser.add_argument("--skip-quantized", action="store_true",
help="Skip quantized ONNX benchmarking")
args = parser.parse_args()

if not check_optimum_installed():
print("⚠️ optimum[onnxruntime] not installed. Skipping ONNX benchmarks.")
print("Install with: pip install optimum[onnxruntime]")
return

print("=" * 70)
print("ONNX Runtime Benchmark for Adaptive Classifier")
print("=" * 70)
print(f"Model: {args.model}")
print(f"Runs per test: {args.runs}")
print()

# Prepare test data
test_texts = [
"This is a positive example",
"This seems negative to me",
"A neutral statement here",
"Another test case for benchmarking performance",
"The quick brown fox jumps over the lazy dog"
]

print("Preparing classifiers...")
print()

# Train a baseline classifier
classifier_base = AdaptiveClassifier(args.model, use_onnx=False, device="cpu")
training_texts = [
"great product", "terrible experience", "okay item",
"loved it", "hated it", "it's fine",
"amazing quality", "poor service", "average performance"
]
training_labels = [
"positive", "negative", "neutral",
"positive", "negative", "neutral",
"positive", "negative", "neutral"
]
classifier_base.add_examples(training_texts, training_labels)

# Save and create ONNX versions
with tempfile.TemporaryDirectory() as tmpdir:
save_path = Path(tmpdir) / "classifier"

# Save with ONNX versions
print("Exporting ONNX models...")
classifier_base._save_pretrained(
save_path,
include_onnx=True,
quantize_onnx=not args.skip_quantized
)

# Load PyTorch version
print("Loading PyTorch model...")
classifier_pytorch = AdaptiveClassifier._from_pretrained(
str(save_path),
use_onnx=False
)

# Load ONNX version
print("Loading ONNX model...")
classifier_onnx = AdaptiveClassifier._from_pretrained(
str(save_path),
use_onnx=True
)

print()
print("Starting benchmarks...")
print("-" * 70)

# Benchmark PyTorch
print("\n1. PyTorch Baseline")
print(" Running benchmark...")
pytorch_avg, pytorch_total = benchmark_inference(
classifier_pytorch, test_texts, args.runs
)
print(f" ✓ Average time per query: {pytorch_avg:.2f}ms")
print(f" ✓ Total time: {pytorch_total:.2f}s")

# Benchmark ONNX
print("\n2. ONNX Runtime")
print(" Running benchmark...")
onnx_avg, onnx_total = benchmark_inference(
classifier_onnx, test_texts, args.runs
)
print(f" ✓ Average time per query: {onnx_avg:.2f}ms")
print(f" ✓ Total time: {onnx_total:.2f}s")
speedup = pytorch_avg / onnx_avg
print(f" ✓ Speedup: {speedup:.2f}x faster than PyTorch")

# Test prediction accuracy
print("\n3. Accuracy Verification")
test_text = "This is amazing!"
pred_pytorch = classifier_pytorch.predict(test_text)
pred_onnx = classifier_onnx.predict(test_text)

print(f" PyTorch top prediction: {pred_pytorch[0]}")
print(f" ONNX top prediction: {pred_onnx[0]}")

if pred_pytorch[0][0] == pred_onnx[0][0]:
print(" ✓ Predictions match!")
else:
print(" ⚠️ Predictions differ slightly")

print()
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"PyTorch: {pytorch_avg:.2f}ms/query (baseline)")
print(f"ONNX: {onnx_avg:.2f}ms/query ({speedup:.2f}x faster)")
print()

if speedup > 2.0:
print("🚀 ONNX provides significant speedup! (>2x)")
elif speedup > 1.2:
print("⚡ ONNX provides moderate speedup")
else:
print("ℹ️ ONNX provides marginal speedup")

print()
print("=" * 70)
print("\nRecommendation:")
if speedup > 1.5:
print("✓ Use ONNX for CPU inference for better performance!")
print(" classifier = AdaptiveClassifier(model_name, use_onnx=True)")
else:
print("ℹ️ ONNX speedup is modest for this model.")
print(" Consider using smaller models (distilbert, MiniLM) for better gains.")


if __name__ == "__main__":
main()
Loading