Skip to content

Commit d1e7754

Browse files
rootfsAias00
authored andcommitted
fix: use both unified and legacy classifier to prevent failure (vllm-project#332)
Signed-off-by: Huamin Chen <[email protected]> Signed-off-by: liuhy <[email protected]>
1 parent febcbe2 commit d1e7754

File tree

1 file changed

+63
-9
lines changed

1 file changed

+63
-9
lines changed

src/semantic-router/pkg/services/classification.go

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package services
33
import (
44
"fmt"
55
"os"
6+
"strings"
67
"sync"
78
"time"
89

@@ -35,9 +36,9 @@ func NewClassificationService(classifier *classification.Classifier, config *con
3536
}
3637

3738
// NewUnifiedClassificationService creates a new service with unified classifier
38-
func NewUnifiedClassificationService(unifiedClassifier *classification.UnifiedClassifier, config *config.RouterConfig) *ClassificationService {
39+
func NewUnifiedClassificationService(unifiedClassifier *classification.UnifiedClassifier, legacyClassifier *classification.Classifier, config *config.RouterConfig) *ClassificationService {
3940
service := &ClassificationService{
40-
classifier: nil, // Legacy classifier not used
41+
classifier: legacyClassifier,
4142
unifiedClassifier: unifiedClassifier,
4243
config: config,
4344
}
@@ -54,16 +55,69 @@ func NewClassificationServiceWithAutoDiscovery(config *config.RouterConfig) (*Cl
5455
observability.Debugf("Debug: Attempting to discover models in: ./models")
5556

5657
// Always try to auto-discover and initialize unified classifier for batch processing
57-
unifiedClassifier, err := classification.AutoInitializeUnifiedClassifier("./models")
58+
// Use model path from config, fallback to "./models" if not specified
59+
modelsPath := "./models"
60+
if config != nil && config.Classifier.CategoryModel.ModelID != "" {
61+
// Extract the models directory from the model path
62+
// e.g., "models/category_classifier_modernbert-base_model" -> "models"
63+
if idx := strings.Index(config.Classifier.CategoryModel.ModelID, "/"); idx > 0 {
64+
modelsPath = config.Classifier.CategoryModel.ModelID[:idx]
65+
}
66+
}
67+
unifiedClassifier, ucErr := classification.AutoInitializeUnifiedClassifier(modelsPath)
68+
if ucErr != nil {
69+
observability.Infof("Unified classifier auto-discovery failed: %v", ucErr)
70+
}
71+
// create legacy classifier
72+
legacyClassifier, lcErr := createLegacyClassifier(config)
73+
if lcErr != nil {
74+
observability.Warnf("Legacy classifier initialization failed: %v", lcErr)
75+
}
76+
if unifiedClassifier == nil && legacyClassifier == nil {
77+
observability.Warnf("No classifier initialized. Using placeholder service.")
78+
}
79+
return NewUnifiedClassificationService(unifiedClassifier, legacyClassifier, config), nil
80+
}
81+
82+
// createLegacyClassifier creates a legacy classifier with proper model loading
83+
func createLegacyClassifier(config *config.RouterConfig) (*classification.Classifier, error) {
84+
// Load category mapping
85+
var categoryMapping *classification.CategoryMapping
86+
if config.Classifier.CategoryModel.CategoryMappingPath != "" {
87+
var err error
88+
categoryMapping, err = classification.LoadCategoryMapping(config.Classifier.CategoryModel.CategoryMappingPath)
89+
if err != nil {
90+
return nil, fmt.Errorf("failed to load category mapping: %w", err)
91+
}
92+
}
93+
94+
// Load PII mapping
95+
var piiMapping *classification.PIIMapping
96+
if config.Classifier.PIIModel.PIIMappingPath != "" {
97+
var err error
98+
piiMapping, err = classification.LoadPIIMapping(config.Classifier.PIIModel.PIIMappingPath)
99+
if err != nil {
100+
return nil, fmt.Errorf("failed to load PII mapping: %w", err)
101+
}
102+
}
103+
104+
// Load jailbreak mapping
105+
var jailbreakMapping *classification.JailbreakMapping
106+
if config.PromptGuard.JailbreakMappingPath != "" {
107+
var err error
108+
jailbreakMapping, err = classification.LoadJailbreakMapping(config.PromptGuard.JailbreakMappingPath)
109+
if err != nil {
110+
return nil, fmt.Errorf("failed to load jailbreak mapping: %w", err)
111+
}
112+
}
113+
114+
// Create classifier
115+
classifier, err := classification.NewClassifier(config, categoryMapping, piiMapping, jailbreakMapping)
58116
if err != nil {
59-
// Log the discovery failure but don't fail - fall back to legacy processing
60-
observability.Infof("Unified classifier auto-discovery failed: %v. Using legacy processing.", err)
61-
return NewClassificationService(nil, config), nil
117+
return nil, fmt.Errorf("failed to create classifier: %w", err)
62118
}
63119

64-
// Success! Create service with unified classifier
65-
observability.Infof("Unified classifier auto-discovered and initialized. Using batch processing.")
66-
return NewUnifiedClassificationService(unifiedClassifier, config), nil
120+
return classifier, nil
67121
}
68122

69123
// GetGlobalClassificationService returns the global classification service instance

0 commit comments

Comments
 (0)