Skip to content

Commit 5c4d356

Browse files
authored
fix(inference): forbid invalid endpoint and body content (#352)
1 parent 2b77eb4 commit 5c4d356

File tree

4 files changed

+272
-4
lines changed

4 files changed

+272
-4
lines changed

api/inference/const/const.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ var (
1818
"/audio/transcriptions": {},
1919
}
2020

21+
// FreePrefixes defines path prefixes that can be accessed without charging
22+
// These are typically metadata or system endpoints that don't consume GPU resources
23+
// Note: Paths here should NOT include /v1/proxy prefix (it's already stripped)
24+
FreePrefixes = []string{
25+
"/attestation", // TEE attestation endpoints (e.g., /attestation/report)
26+
"/signature", // TEE signature endpoints (e.g., /signature/{chatID})
27+
}
28+
2129
// Keep this as to remove duplicate headers from incoming request
2230
RequestMetaDataDuplicate = map[string]struct{}{
2331
"Address": {},

api/inference/internal/ctrl/proxy.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package ctrl
33
import (
44
"bytes"
55
"encoding/json"
6+
"fmt"
67
"io"
78
"net/http"
89
"net/url"
@@ -24,6 +25,19 @@ func (c *Ctrl) PrepareHTTPRequest(ctx *gin.Context, targetURL string, reqBody []
2425
return nil, errors.Wrap(err, "ensure stream options")
2526
}
2627
reqBody = modifiedBody
28+
29+
// Enforce configured model to prevent users from requesting more expensive models
30+
// Pass user address from context for rate limiting
31+
userAddr, _ := ctx.Get("userAddress")
32+
userAddrStr, _ := userAddr.(string)
33+
modifiedBody, err = c.EnforceConfiguredModel(reqBody, userAddrStr)
34+
if err != nil {
35+
// Model validation failure is a user input error (similar to invalid token, bad request body)
36+
// Mark as expected error to prevent polluting error monitoring
37+
ctx.Set("ignoreError", true)
38+
return nil, errors.Wrap(err, "enforce configured model")
39+
}
40+
reqBody = modifiedBody
2741
}
2842

2943
// For text-to-image requests, ensure wait=true query parameter is set
@@ -310,3 +324,85 @@ func (c *Ctrl) EnsureStreamOptions(body []byte) ([]byte, error) {
310324

311325
return modifiedBody, nil
312326
}
327+
328+
// EnforceConfiguredModel ensures that requests use the configured model from the service config.
329+
// This prevents users from requesting more expensive models while paying for cheaper ones.
330+
//
331+
// Security rationale:
332+
// - Provider advertises a specific model in the service configuration
333+
// - Pricing is based on that specific model
334+
// - Allowing users to change the model could result in:
335+
// 1. Provider paying more to backend service than they charge users
336+
// 2. Users getting access to premium models at cheaper prices
337+
//
338+
// This function forcibly overwrites any "model" field in the request body with the
339+
// configured model from c.Service.ModelType.
340+
func (c *Ctrl) EnforceConfiguredModel(body []byte, userAddr string) ([]byte, error) {
341+
// Return original body if empty (e.g., GET requests)
342+
if len(body) == 0 {
343+
return body, nil
344+
}
345+
346+
// Return original body if no model is configured
347+
if c.Service.ModelType == "" {
348+
c.logger.Warnf("Model enforcement skipped: c.Service.ModelType is empty (Type=%s)", c.Service.Type)
349+
return body, nil
350+
}
351+
352+
// Debug log to verify configuration
353+
c.logger.Debugf("EnforceConfiguredModel: Service.Type=%s, Service.ModelType=%s",
354+
c.Service.Type, c.Service.ModelType)
355+
356+
var bodyMap map[string]interface{}
357+
358+
err := json.Unmarshal(body, &bodyMap)
359+
if err != nil {
360+
// Return original body for non-JSON requests
361+
return body, nil
362+
}
363+
364+
// Check if request contains a model field
365+
requestModel, hasModel := bodyMap["model"]
366+
if !hasModel {
367+
// No model specified, add the configured model
368+
c.logger.Infof("No model specified in request, adding configured model: %s", c.Service.ModelType)
369+
bodyMap["model"] = c.Service.ModelType
370+
} else {
371+
// Model specified in request, check if it matches configured model
372+
requestModelStr, ok := requestModel.(string)
373+
if !ok {
374+
// Invalid model type, reject request
375+
return nil, errors.New(fmt.Sprintf("invalid model type in request (expected string), configured model is: %s", c.Service.ModelType))
376+
}
377+
378+
if requestModelStr != c.Service.ModelType {
379+
// Model mismatch detected - record in rate limiter and REJECT
380+
c.logger.Warnf("Model mismatch detected and REJECTED: user=%s, requested=%s, configured=%s",
381+
userAddr, requestModelStr, c.Service.ModelType)
382+
383+
// Record this attempt in rate limiter if user address is available
384+
if userAddr != "" {
385+
rateLimiter := GetRateLimiter()
386+
shouldBlock, blockedUntil := rateLimiter.RecordModelMismatch(userAddr)
387+
if shouldBlock {
388+
c.logger.Warnf("User will be blocked due to excessive model mismatch: user=%s, blocked_until=%s",
389+
userAddr, blockedUntil.Format("2006-01-02 15:04:05"))
390+
}
391+
}
392+
393+
return nil, errors.New(fmt.Sprintf("model not supported: requested '%s', only '%s' is available for this service",
394+
requestModelStr, c.Service.ModelType))
395+
}
396+
397+
// Model matches - log for audit
398+
c.logger.Debugf("Model validation passed: requested=%s matches configured=%s", requestModelStr, c.Service.ModelType)
399+
}
400+
401+
// Marshal back to JSON
402+
modifiedBody, err := json.Marshal(bodyMap)
403+
if err != nil {
404+
return body, errors.Wrap(err, "failed to marshal modified JSON body after enforcing model")
405+
}
406+
407+
return modifiedBody, nil
408+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package ctrl
2+
3+
import (
4+
"sync"
5+
"time"
6+
)
7+
8+
// ModelMismatchLimiter tracks model mismatch attempts for users
9+
// Similar to common/middleware/RateLimiter but with count-based blocking instead of token bucket
10+
type ModelMismatchLimiter struct {
11+
mu sync.RWMutex
12+
users map[string]*UserMismatchInfo
13+
// Configuration
14+
limit int // Max mismatches allowed
15+
window time.Duration // Time window for counting mismatches
16+
block time.Duration // Block duration after exceeding limit
17+
}
18+
19+
// UserMismatchInfo stores model mismatch information for a user
20+
type UserMismatchInfo struct {
21+
Count int // Number of model mismatch attempts
22+
LastAttempt time.Time // Last time model mismatch occurred
23+
BlockedUntil time.Time // Time until which user is blocked
24+
}
25+
26+
var (
27+
globalMismatchLimiter *ModelMismatchLimiter
28+
mismatchLimiterOnce sync.Once
29+
)
30+
31+
// GetRateLimiter returns the global model mismatch limiter instance
32+
func GetRateLimiter() *ModelMismatchLimiter {
33+
mismatchLimiterOnce.Do(func() {
34+
globalMismatchLimiter = &ModelMismatchLimiter{
35+
users: make(map[string]*UserMismatchInfo),
36+
limit: 5, // Max 5 mismatches
37+
window: 5 * time.Minute, // Within 5 minutes
38+
block: 1 * time.Hour, // Block for 1 hour
39+
}
40+
// Start cleanup goroutine (similar to common/middleware/RateLimiter)
41+
go globalMismatchLimiter.cleanup()
42+
})
43+
return globalMismatchLimiter
44+
}
45+
46+
// RecordModelMismatch records a model mismatch attempt for a user
47+
// Returns true if the user should be blocked
48+
func (ml *ModelMismatchLimiter) RecordModelMismatch(userAddr string) (shouldBlock bool, blockedUntil time.Time) {
49+
ml.mu.Lock()
50+
defer ml.mu.Unlock()
51+
52+
now := time.Now()
53+
54+
// Get or create user info
55+
info, exists := ml.users[userAddr]
56+
if !exists {
57+
info = &UserMismatchInfo{}
58+
ml.users[userAddr] = info
59+
}
60+
61+
// Check if user is already blocked
62+
if now.Before(info.BlockedUntil) {
63+
return true, info.BlockedUntil
64+
}
65+
66+
// Reset count if outside the time window
67+
if now.Sub(info.LastAttempt) > ml.window {
68+
info.Count = 0
69+
}
70+
71+
// Increment count
72+
info.Count++
73+
info.LastAttempt = now
74+
75+
// Check if user should be blocked
76+
if info.Count >= ml.limit {
77+
info.BlockedUntil = now.Add(ml.block)
78+
return true, info.BlockedUntil
79+
}
80+
81+
return false, time.Time{}
82+
}
83+
84+
// IsBlocked checks if a user is currently blocked
85+
func (ml *ModelMismatchLimiter) IsBlocked(userAddr string) (blocked bool, blockedUntil time.Time) {
86+
ml.mu.RLock()
87+
defer ml.mu.RUnlock()
88+
89+
info, exists := ml.users[userAddr]
90+
if !exists {
91+
return false, time.Time{}
92+
}
93+
94+
now := time.Now()
95+
if now.Before(info.BlockedUntil) {
96+
return true, info.BlockedUntil
97+
}
98+
99+
return false, time.Time{}
100+
}
101+
102+
// cleanup removes old entries periodically (similar to common/middleware/RateLimiter.cleanupVisitors)
103+
func (ml *ModelMismatchLimiter) cleanup() {
104+
ticker := time.NewTicker(10 * time.Minute)
105+
defer ticker.Stop()
106+
107+
for range ticker.C {
108+
ml.mu.Lock()
109+
now := time.Now()
110+
for addr, info := range ml.users {
111+
// Remove entries that are old (> 24 hours) and not blocked
112+
if now.Sub(info.LastAttempt) > 24*time.Hour && now.After(info.BlockedUntil) {
113+
delete(ml.users, addr)
114+
}
115+
}
116+
ml.mu.Unlock()
117+
}
118+
}

api/inference/internal/proxy/proxy.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package proxy
22

33
import (
4+
"fmt"
45
"io"
56
"net/http"
67
"strings"
78
"sync"
9+
"time"
810

911
"github.com/gin-contrib/cors"
1012
"github.com/google/uuid"
@@ -172,16 +174,42 @@ func (p *Proxy) proxyHTTPRequest(ctx *gin.Context) {
172174

173175
// handle endpoints not need to be charged
174176
if _, ok := constant.TargetRoute[targetPath]; !ok {
177+
// Check if this is a signature endpoint with special handling (targetSeparated=false)
178+
// handleSignatureRoute returns true if it handled the request from broker cache
179+
// returns false if it should be forwarded to backend (targetSeparated=true)
175180
if p.handleSignatureRoute(ctx, targetPath) {
176181
return
177182
}
178183

179-
httpReq, err := p.ctrl.PrepareHTTPRequest(ctx, targetURL, reqBody, svcType)
180-
if err != nil {
181-
p.handleBrokerError(ctx, err, "prepare HTTP request")
184+
// Check if this path matches any free prefixes (attestation, signature, etc.)
185+
isFree := false
186+
for _, prefix := range constant.FreePrefixes {
187+
if strings.HasPrefix(strings.ToLower(targetPath), prefix) {
188+
isFree = true
189+
break
190+
}
191+
}
192+
193+
if isFree {
194+
// Log free endpoint access for audit purposes
195+
p.logger.Infof("Free endpoint access: path=%s, method=%s, remote=%s, user_agent=%s",
196+
targetPath, ctx.Request.Method, ctx.Request.RemoteAddr, ctx.Request.UserAgent())
197+
198+
httpReq, err := p.ctrl.PrepareHTTPRequest(ctx, targetURL, reqBody, svcType)
199+
if err != nil {
200+
p.handleBrokerError(ctx, err, "prepare HTTP request")
201+
return
202+
}
203+
p.ctrl.ProcessHTTPRequest(ctx, svcType, httpReq, model.Request{}, "0", false)
182204
return
183205
}
184-
p.ctrl.ProcessHTTPRequest(ctx, svcType, httpReq, model.Request{}, "0", false)
206+
207+
// Reject all other endpoints that are not in TargetRoute or FreePrefixes
208+
// This prevents unauthorized access to unknown endpoints
209+
p.logger.Warnf("Blocked unsupported endpoint: path=%s, method=%s, remote=%s, user_agent=%s",
210+
targetPath, ctx.Request.Method, ctx.Request.RemoteAddr, ctx.Request.UserAgent())
211+
ctx.Set("ignoreError", true)
212+
p.handleBrokerError(ctx, errors.New("endpoint not supported"), "unsupported endpoint")
185213
return
186214
}
187215

@@ -193,6 +221,24 @@ func (p *Proxy) proxyHTTPRequest(ctx *gin.Context) {
193221
return
194222
}
195223

224+
// Store user address in context for rate limiting
225+
ctx.Set("userAddress", userAddress)
226+
227+
// Check if user is rate-limited due to excessive model mismatch attempts
228+
rateLimiter := ctrl.GetRateLimiter()
229+
if blocked, blockedUntil := rateLimiter.IsBlocked(userAddress); blocked {
230+
// User is blocked - return error immediately without processing
231+
ctx.Set("ignoreError", true)
232+
remainingTime := blockedUntil.Sub(time.Now())
233+
p.logger.Warnf("User blocked due to excessive model mismatch attempts: user=%s, blocked_until=%s, remaining=%v",
234+
userAddress, blockedUntil.Format(time.RFC3339), remainingTime)
235+
236+
ctx.JSON(http.StatusTooManyRequests, gin.H{
237+
"error": fmt.Sprintf("Rate limit exceeded: too many invalid model requests. Please try again in %v", remainingTime.Round(time.Minute)),
238+
})
239+
return
240+
}
241+
196242
// Record unique user for DAU tracking
197243
monitor.RecordUniqueUser(userAddress)
198244

0 commit comments

Comments
 (0)