Skip to content

Commit e80a703

Browse files
authored
feat: 为 Provider 添加 SupportModels 模型过滤功能 (#121)
1 parent 16e1fa0 commit e80a703

File tree

7 files changed

+152
-5
lines changed

7 files changed

+152
-5
lines changed

internal/domain/model.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ type Provider struct {
9696

9797
// 支持的 Client
9898
SupportedClientTypes []ClientType `json:"supportedClientTypes"`
99+
100+
// 支持的模型列表(通配符模式)
101+
// 如果配置了,在 Route 匹配时会检查前置映射后的模型是否在支持列表中
102+
// 空数组表示支持所有模型
103+
SupportModels []string `json:"supportModels,omitempty"`
99104
}
100105

101106
type Project struct {

internal/executor/executor.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,12 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http
166166
}
167167

168168
// Match routes
169-
routes, err := e.router.Match(clientType, projectID)
169+
routes, err := e.router.Match(&router.MatchContext{
170+
ClientType: clientType,
171+
ProjectID: projectID,
172+
RequestModel: requestModel,
173+
APITokenID: apiTokenID,
174+
})
170175
if err != nil {
171176
proxyReq.Status = "FAILED"
172177
proxyReq.Error = "no routes available"
@@ -258,6 +263,7 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http
258263
}
259264

260265
// Determine model mapping
266+
// Model mapping is done in Executor after Router has filtered by SupportModels
261267
clientType := ctxutil.GetClientType(ctx)
262268
mappedModel := e.mapModel(requestModel, matchedRoute.Route, matchedRoute.Provider, clientType, projectID, apiTokenID)
263269
ctx = ctxutil.WithMappedModel(ctx, mappedModel)

internal/repository/sqlite/models.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ type Provider struct {
118118
Name string `gorm:"not null"`
119119
Config string `gorm:"type:longtext"`
120120
SupportedClientTypes string `gorm:"type:text"`
121+
SupportModels string `gorm:"type:text"`
121122
}
122123

123124
func (Provider) TableName() string { return "providers" }

internal/repository/sqlite/provider.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ func (r *ProviderRepository) toModel(p *domain.Provider) *Provider {
8484
Name: p.Name,
8585
Config: toJSON(p.Config),
8686
SupportedClientTypes: toJSON(p.SupportedClientTypes),
87+
SupportModels: toJSON(p.SupportModels),
8788
}
8889
}
8990

@@ -98,5 +99,6 @@ func (r *ProviderRepository) toDomain(m *Provider) *domain.Provider {
9899
Name: m.Name,
99100
Config: fromJSON[*domain.ProviderConfig](m.Config),
100101
SupportedClientTypes: fromJSON[[]domain.ClientType](m.SupportedClientTypes),
102+
SupportModels: fromJSON[[]string](m.SupportModels),
101103
}
102104
}

internal/router/router.go

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ type MatchedRoute struct {
1919
RetryConfig *domain.RetryConfig
2020
}
2121

22+
// MatchContext contains all context needed for route matching
23+
type MatchContext struct {
24+
ClientType domain.ClientType
25+
ProjectID uint64
26+
RequestModel string
27+
APITokenID uint64
28+
}
29+
2230
// Router handles route matching and selection
2331
type Router struct {
2432
routeRepo *cached.RouteRepository
@@ -98,7 +106,11 @@ func (r *Router) RemoveAdapter(providerID uint64) {
98106
}
99107

100108
// Match returns matched routes for a client type and project
101-
func (r *Router) Match(clientType domain.ClientType, projectID uint64) ([]*MatchedRoute, error) {
109+
func (r *Router) Match(ctx *MatchContext) ([]*MatchedRoute, error) {
110+
clientType := ctx.ClientType
111+
projectID := ctx.ProjectID
112+
requestModel := ctx.RequestModel
113+
102114
routes := r.routeRepo.GetAll()
103115

104116
// Check if ClientType has custom routes enabled for this project
@@ -175,7 +187,7 @@ func (r *Router) Match(clientType domain.ClientType, projectID uint64) ([]*Match
175187
providers := r.providerRepo.GetAll()
176188

177189
for _, route := range filtered {
178-
provider, ok := providers[route.ProviderID]
190+
prov, ok := providers[route.ProviderID]
179191
if !ok {
180192
continue
181193
}
@@ -190,6 +202,15 @@ func (r *Router) Match(clientType domain.ClientType, projectID uint64) ([]*Match
190202
continue
191203
}
192204

205+
// Check if provider supports the request model
206+
// SupportModels check is done BEFORE mapping
207+
// If SupportModels is configured, check if the request model is supported
208+
if len(prov.SupportModels) > 0 && requestModel != "" {
209+
if !r.isModelSupported(requestModel, prov.SupportModels) {
210+
continue
211+
}
212+
}
213+
193214
var retryConfig *domain.RetryConfig
194215
if route.RetryConfigID != 0 {
195216
retryConfig, _ = r.retryConfigRepo.GetByID(route.RetryConfigID)
@@ -200,7 +221,7 @@ func (r *Router) Match(clientType domain.ClientType, projectID uint64) ([]*Match
200221

201222
matched = append(matched, &MatchedRoute{
202223
Route: route,
203-
Provider: provider,
224+
Provider: prov,
204225
ProviderAdapter: adp,
205226
RetryConfig: retryConfig,
206227
})
@@ -213,6 +234,16 @@ func (r *Router) Match(clientType domain.ClientType, projectID uint64) ([]*Match
213234
return matched, nil
214235
}
215236

237+
// isModelSupported checks if a model matches any pattern in the support list
238+
func (r *Router) isModelSupported(model string, supportModels []string) bool {
239+
for _, pattern := range supportModels {
240+
if domain.MatchWildcard(pattern, model) {
241+
return true
242+
}
243+
}
244+
return false
245+
}
246+
216247
func (r *Router) getRoutingStrategy(projectID uint64) *domain.RoutingStrategy {
217248
// Try project-specific strategy first
218249
if projectID != 0 {

web/src/lib/transport/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export interface Provider {
4949
logo?: string; // Logo URL or data URI
5050
config: ProviderConfig | null;
5151
supportedClientTypes: ClientType[];
52+
supportModels?: string[]; // 支持的模型列表(通配符模式),空数组表示支持所有模型
5253
}
5354

5455
// supportedClientTypes 可选,后端会根据 provider type 自动设置
@@ -57,6 +58,7 @@ export type CreateProviderData = Omit<
5758
'id' | 'createdAt' | 'updatedAt' | 'supportedClientTypes'
5859
> & {
5960
supportedClientTypes?: ClientType[];
61+
supportModels?: string[];
6062
};
6163

6264
// ===== Project =====

web/src/pages/providers/components/provider-edit-flow.tsx

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { useState, useMemo } from 'react'
2-
import { Globe, ChevronLeft, Key, Check, Trash2, Plus, ArrowRight, Zap } from 'lucide-react'
2+
import { Globe, ChevronLeft, Key, Check, Trash2, Plus, ArrowRight, Zap, Filter } from 'lucide-react'
33
import { useTranslation } from 'react-i18next'
44
import {
55
Dialog,
@@ -167,6 +167,97 @@ function ProviderModelMappings({ provider }: { provider: Provider }) {
167167
)
168168
}
169169

170+
// Provider Supported Models Section
171+
function ProviderSupportModels({
172+
supportModels,
173+
onChange,
174+
}: {
175+
supportModels: string[]
176+
onChange: (models: string[]) => void
177+
}) {
178+
const { t } = useTranslation()
179+
const [newModel, setNewModel] = useState('')
180+
181+
const handleAddModel = () => {
182+
if (!newModel.trim()) return
183+
const trimmedModel = newModel.trim()
184+
if (!supportModels.includes(trimmedModel)) {
185+
onChange([...supportModels, trimmedModel])
186+
}
187+
setNewModel('')
188+
}
189+
190+
const handleRemoveModel = (model: string) => {
191+
onChange(supportModels.filter(m => m !== model))
192+
}
193+
194+
return (
195+
<div>
196+
<div className="flex items-center gap-2 mb-4 border-b border-border pb-2">
197+
<Filter size={18} className="text-blue-500" />
198+
<h4 className="text-lg font-semibold text-foreground">
199+
{t('providers.supportModels.title', 'Supported Models')}
200+
</h4>
201+
<span className="text-sm text-muted-foreground">
202+
({supportModels.length})
203+
</span>
204+
</div>
205+
206+
<div className="bg-card border border-border rounded-xl p-4">
207+
<p className="text-xs text-muted-foreground mb-4">
208+
{t('providers.supportModels.desc', 'Configure which models this provider supports. If empty, all models are supported. Supports wildcards like claude-* or gemini-*.')}
209+
</p>
210+
211+
{supportModels.length > 0 && (
212+
<div className="flex flex-wrap gap-2 mb-4">
213+
{supportModels.map((model) => (
214+
<div
215+
key={model}
216+
className="flex items-center gap-1 bg-muted/50 border border-border rounded-lg px-3 py-1.5"
217+
>
218+
<span className="text-sm">{model}</span>
219+
<button
220+
type="button"
221+
onClick={() => handleRemoveModel(model)}
222+
className="text-muted-foreground hover:text-destructive ml-1"
223+
>
224+
<Trash2 className="h-3.5 w-3.5" />
225+
</button>
226+
</div>
227+
))}
228+
</div>
229+
)}
230+
231+
{supportModels.length === 0 && (
232+
<div className="text-center py-6 mb-4">
233+
<p className="text-muted-foreground text-sm">
234+
{t('providers.supportModels.empty', 'No model filter configured. All models will be supported.')}
235+
</p>
236+
</div>
237+
)}
238+
239+
<div className="flex items-center gap-2 pt-4 border-t border-border">
240+
<ModelInput
241+
value={newModel}
242+
onChange={setNewModel}
243+
placeholder={t('providers.supportModels.placeholder', 'e.g. claude-* or gemini-2.5-*')}
244+
className="flex-1 min-w-0 h-8 text-sm"
245+
/>
246+
<Button
247+
variant="outline"
248+
size="sm"
249+
onClick={handleAddModel}
250+
disabled={!newModel.trim()}
251+
>
252+
<Plus className="h-4 w-4 mr-1" />
253+
{t('common.add')}
254+
</Button>
255+
</div>
256+
</div>
257+
</div>
258+
)
259+
}
260+
170261
interface ProviderEditFlowProps {
171262
provider: Provider;
172263
onClose: () => void;
@@ -177,6 +268,7 @@ type EditFormData = {
177268
baseURL: string;
178269
apiKey: string;
179270
clients: ClientConfig[];
271+
supportModels: string[];
180272
};
181273

182274
export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) {
@@ -202,6 +294,7 @@ export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) {
202294
baseURL: provider.config?.custom?.baseURL || '',
203295
apiKey: provider.config?.custom?.apiKey || '',
204296
clients: initClients(),
297+
supportModels: provider.supportModels || [],
205298
});
206299

207300
const updateClient = (clientId: ClientType, updates: Partial<ClientConfig>) => {
@@ -245,6 +338,7 @@ export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) {
245338
},
246339
},
247340
supportedClientTypes,
341+
supportModels: formData.supportModels.length > 0 ? formData.supportModels : undefined,
248342
};
249343

250344
await updateProvider.mutateAsync({ id: Number(provider.id), data });
@@ -420,6 +514,12 @@ export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) {
420514
<ClientsConfigSection clients={formData.clients} onUpdateClient={updateClient} />
421515
</div>
422516

517+
{/* Provider Supported Models Filter */}
518+
<ProviderSupportModels
519+
supportModels={formData.supportModels}
520+
onChange={(models) => setFormData((prev) => ({ ...prev, supportModels: models }))}
521+
/>
522+
423523
{/* Provider Model Mappings */}
424524
<ProviderModelMappings provider={provider} />
425525

0 commit comments

Comments
 (0)