Skip to content

Commit 55896dd

Browse files
rootfs, Aias00
authored and committed
feat: add system prompt toggle endpoint (vllm-project#301)
* feat: add system prompt toggle endpoint
  Signed-off-by: Huamin Chen <[email protected]>
* add cli option to explicitly enable the prompt toggle
  Signed-off-by: Huamin Chen <[email protected]>
* fix test failure
  Signed-off-by: Huamin Chen <[email protected]>
* fix test failure
  Signed-off-by: Huamin Chen <[email protected]>
* fix test failure
  Signed-off-by: Huamin Chen <[email protected]>
* adding system prompt endpoint option to makefile target
  Signed-off-by: Huamin Chen <[email protected]>
* update doc
  Signed-off-by: Huamin Chen <[email protected]>
* address review comment
  Signed-off-by: Huamin Chen <[email protected]>
---------
Signed-off-by: Huamin Chen <[email protected]>
Signed-off-by: liuhy <[email protected]>
1 parent c220ea6 commit 55896dd

File tree

8 files changed

+738
-28
lines changed

8 files changed

+738
-28
lines changed

src/semantic-router/cmd/main.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ import (
1515
func main() {
1616
// Parse command-line flags
1717
var (
18-
configPath = flag.String("config", "config/config.yaml", "Path to the configuration file")
19-
port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc")
20-
apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API")
21-
metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics")
22-
enableAPI = flag.Bool("enable-api", true, "Enable Classification API server")
23-
secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS")
24-
certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)")
18+
configPath = flag.String("config", "config/config.yaml", "Path to the configuration file")
19+
port = flag.Int("port", 50051, "Port to listen on for gRPC ExtProc")
20+
apiPort = flag.Int("api-port", 8080, "Port to listen on for Classification API")
21+
metricsPort = flag.Int("metrics-port", 9190, "Port for Prometheus metrics")
22+
enableAPI = flag.Bool("enable-api", true, "Enable Classification API server")
23+
enableSystemPromptAPI = flag.Bool("enable-system-prompt-api", false, "Enable system prompt configuration endpoints (SECURITY: only enable in trusted environments)")
24+
secure = flag.Bool("secure", false, "Enable secure gRPC server with TLS")
25+
certPath = flag.String("cert-path", "", "Path to TLS certificate directory (containing tls.crt and tls.key)")
2526
)
2627
flag.Parse()
2728

@@ -58,7 +59,7 @@ func main() {
5859
if *enableAPI {
5960
go func() {
6061
observability.Infof("Starting Classification API server on port %d", *apiPort)
61-
if err := api.StartClassificationAPI(*configPath, *apiPort); err != nil {
62+
if err := api.StartClassificationAPI(*configPath, *apiPort, *enableSystemPromptAPI); err != nil {
6263
observability.Errorf("Classification API server error: %v", err)
6364
}
6465
}()

src/semantic-router/pkg/api/server.go

Lines changed: 165 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ import (
1717

1818
// ClassificationAPIServer holds the server state and dependencies
1919
type ClassificationAPIServer struct {
20-
classificationSvc *services.ClassificationService
21-
config *config.RouterConfig
20+
classificationSvc *services.ClassificationService
21+
config *config.RouterConfig
22+
enableSystemPromptAPI bool
2223
}
2324

2425
// ModelsInfoResponse represents the response for models info endpoint
@@ -101,7 +102,7 @@ type ClassificationOptions struct {
101102
}
102103

103104
// StartClassificationAPI starts the Classification API server
104-
func StartClassificationAPI(configPath string, port int) error {
105+
func StartClassificationAPI(configPath string, port int, enableSystemPromptAPI bool) error {
105106
// Load configuration
106107
cfg, err := config.LoadConfig(configPath)
107108
if err != nil {
@@ -139,8 +140,9 @@ func StartClassificationAPI(configPath string, port int) error {
139140

140141
// Create server instance
141142
apiServer := &ClassificationAPIServer{
142-
classificationSvc: classificationSvc,
143-
config: cfg,
143+
classificationSvc: classificationSvc,
144+
config: cfg,
145+
enableSystemPromptAPI: enableSystemPromptAPI,
144146
}
145147

146148
// Create HTTP server with routes
@@ -203,6 +205,15 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux {
203205
mux.HandleFunc("GET /config/classification", s.handleGetConfig)
204206
mux.HandleFunc("PUT /config/classification", s.handleUpdateConfig)
205207

208+
// System prompt configuration endpoints (only if explicitly enabled)
209+
if s.enableSystemPromptAPI {
210+
observability.Infof("System prompt configuration endpoints enabled")
211+
mux.HandleFunc("GET /config/system-prompts", s.handleGetSystemPrompts)
212+
mux.HandleFunc("PUT /config/system-prompts", s.handleUpdateSystemPrompts)
213+
} else {
214+
observability.Infof("System prompt configuration endpoints disabled for security")
215+
}
216+
206217
return mux
207218
}
208219

@@ -705,3 +716,152 @@ func (s *ClassificationAPIServer) calculateUnifiedStatistics(unifiedResults *ser
705716
LowConfidenceCount: lowConfidenceCount,
706717
}
707718
}
719+
720+
// SystemPromptInfo is the JSON view of one category's system prompt settings,
// as returned by the /config/system-prompts endpoints.
type SystemPromptInfo struct {
	// Category is the name of the category the prompt belongs to.
	Category string `json:"category"`
	// Prompt is the system prompt text configured for the category.
	Prompt string `json:"prompt"`
	// Enabled reports whether the system prompt is currently enabled.
	Enabled bool `json:"enabled"`
	// Mode is the system prompt mode: "replace" or "insert".
	Mode string `json:"mode"`
}
727+
728+
// SystemPromptsResponse represents the response for GET /config/system-prompts
729+
type SystemPromptsResponse struct {
730+
SystemPrompts []SystemPromptInfo `json:"system_prompts"`
731+
}
732+
733+
// SystemPromptUpdateRequest is the JSON body accepted by
// PUT /config/system-prompts.
type SystemPromptUpdateRequest struct {
	// Category names the category to update. If empty, the update applies to
	// all categories.
	Category string `json:"category,omitempty"`
	// Enabled is true to enable and false to disable the system prompt. It is
	// a pointer so an absent field (nil) can be told apart from false.
	Enabled *bool `json:"enabled,omitempty"`
	// Mode is the system prompt mode: "replace" or "insert".
	Mode string `json:"mode,omitempty"`
}
739+
740+
// handleGetSystemPrompts handles GET /config/system-prompts
741+
func (s *ClassificationAPIServer) handleGetSystemPrompts(w http.ResponseWriter, r *http.Request) {
742+
cfg := s.config
743+
if cfg == nil {
744+
http.Error(w, "Configuration not available", http.StatusInternalServerError)
745+
return
746+
}
747+
748+
var systemPrompts []SystemPromptInfo
749+
for _, category := range cfg.Categories {
750+
systemPrompts = append(systemPrompts, SystemPromptInfo{
751+
Category: category.Name,
752+
Prompt: category.SystemPrompt,
753+
Enabled: category.IsSystemPromptEnabled(),
754+
Mode: category.GetSystemPromptMode(),
755+
})
756+
}
757+
758+
response := SystemPromptsResponse{
759+
SystemPrompts: systemPrompts,
760+
}
761+
762+
w.Header().Set("Content-Type", "application/json")
763+
if err := json.NewEncoder(w).Encode(response); err != nil {
764+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
765+
return
766+
}
767+
}
768+
769+
// handleUpdateSystemPrompts handles PUT /config/system-prompts
770+
func (s *ClassificationAPIServer) handleUpdateSystemPrompts(w http.ResponseWriter, r *http.Request) {
771+
var req SystemPromptUpdateRequest
772+
if err := s.parseJSONRequest(r, &req); err != nil {
773+
http.Error(w, err.Error(), http.StatusBadRequest)
774+
return
775+
}
776+
777+
if req.Enabled == nil && req.Mode == "" {
778+
http.Error(w, "either enabled or mode field is required", http.StatusBadRequest)
779+
return
780+
}
781+
782+
// Validate mode if provided
783+
if req.Mode != "" && req.Mode != "replace" && req.Mode != "insert" {
784+
http.Error(w, "mode must be either 'replace' or 'insert'", http.StatusBadRequest)
785+
return
786+
}
787+
788+
cfg := s.config
789+
if cfg == nil {
790+
http.Error(w, "Configuration not available", http.StatusInternalServerError)
791+
return
792+
}
793+
794+
// Create a copy of the config to modify
795+
newCfg := *cfg
796+
newCategories := make([]config.Category, len(cfg.Categories))
797+
copy(newCategories, cfg.Categories)
798+
newCfg.Categories = newCategories
799+
800+
updated := false
801+
if req.Category == "" {
802+
// Update all categories
803+
for i := range newCfg.Categories {
804+
if newCfg.Categories[i].SystemPrompt != "" {
805+
if req.Enabled != nil {
806+
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
807+
}
808+
if req.Mode != "" {
809+
newCfg.Categories[i].SystemPromptMode = req.Mode
810+
}
811+
updated = true
812+
}
813+
}
814+
} else {
815+
// Update specific category
816+
for i := range newCfg.Categories {
817+
if newCfg.Categories[i].Name == req.Category {
818+
if newCfg.Categories[i].SystemPrompt == "" {
819+
http.Error(w, fmt.Sprintf("Category '%s' has no system prompt configured", req.Category), http.StatusBadRequest)
820+
return
821+
}
822+
if req.Enabled != nil {
823+
newCfg.Categories[i].SystemPromptEnabled = req.Enabled
824+
}
825+
if req.Mode != "" {
826+
newCfg.Categories[i].SystemPromptMode = req.Mode
827+
}
828+
updated = true
829+
break
830+
}
831+
}
832+
if !updated {
833+
http.Error(w, fmt.Sprintf("Category '%s' not found", req.Category), http.StatusNotFound)
834+
return
835+
}
836+
}
837+
838+
if !updated {
839+
http.Error(w, "No categories with system prompts found to update", http.StatusBadRequest)
840+
return
841+
}
842+
843+
// Update the configuration
844+
s.config = &newCfg
845+
s.classificationSvc.UpdateConfig(&newCfg)
846+
847+
// Return the updated system prompts
848+
var systemPrompts []SystemPromptInfo
849+
for _, category := range newCfg.Categories {
850+
systemPrompts = append(systemPrompts, SystemPromptInfo{
851+
Category: category.Name,
852+
Prompt: category.SystemPrompt,
853+
Enabled: category.IsSystemPromptEnabled(),
854+
Mode: category.GetSystemPromptMode(),
855+
})
856+
}
857+
858+
response := SystemPromptsResponse{
859+
SystemPrompts: systemPrompts,
860+
}
861+
862+
w.Header().Set("Content-Type", "application/json")
863+
if err := json.NewEncoder(w).Encode(response); err != nil {
864+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
865+
return
866+
}
867+
}

0 commit comments

Comments
 (0)