Skip to content

Commit 2f2b9e6

Browse files
authored
fix: 完善导入导出配置回放还原 (#192)
1 parent 2d6d508 commit 2f2b9e6

File tree

10 files changed

+682
-17
lines changed

10 files changed

+682
-17
lines changed

cmd/maxx/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ func main() {
293293
settingRepo,
294294
cachedAPITokenRepo,
295295
cachedModelMappingRepo,
296+
modelPriceRepo,
296297
r, // Router implements ProviderAdapterRefresher interface
297298
)
298299

internal/core/database.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ func InitializeServerComponents(
342342
repos.SettingRepo,
343343
repos.CachedAPITokenRepo,
344344
repos.CachedModelMappingRepo,
345+
repos.ModelPriceRepo,
345346
r,
346347
)
347348

internal/domain/backup.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type BackupData struct {
2323
RoutingStrategies []BackupRoutingStrategy `json:"routingStrategies,omitempty"`
2424
APITokens []BackupAPIToken `json:"apiTokens,omitempty"`
2525
ModelMappings []BackupModelMapping `json:"modelMappings,omitempty"`
26+
ModelPrices []BackupModelPrice `json:"modelPrices,omitempty"`
2627
}
2728

2829
// BackupSystemSetting represents a system setting for backup
@@ -35,6 +36,7 @@ type BackupSystemSetting struct {
3536
type BackupProvider struct {
3637
Name string `json:"name"`
3738
Type string `json:"type"`
39+
Logo string `json:"logo,omitempty"`
3840
Config *ProviderConfig `json:"config,omitempty"`
3941
SupportedClientTypes []ClientType `json:"supportedClientTypes,omitempty"`
4042
SupportModels []string `json:"supportModels,omitempty"`
@@ -100,6 +102,22 @@ type BackupModelMapping struct {
100102
Priority int `json:"priority"`
101103
}
102104

105+
// BackupModelPrice represents a model price for backup
106+
type BackupModelPrice struct {
107+
ModelID string `json:"modelId"`
108+
InputPriceMicro uint64 `json:"inputPriceMicro"`
109+
OutputPriceMicro uint64 `json:"outputPriceMicro"`
110+
CacheReadPriceMicro uint64 `json:"cacheReadPriceMicro"`
111+
Cache5mWritePriceMicro uint64 `json:"cache5mWritePriceMicro"`
112+
Cache1hWritePriceMicro uint64 `json:"cache1hWritePriceMicro"`
113+
Has1MContext bool `json:"has1mContext"`
114+
Context1MThreshold uint64 `json:"context1mThreshold"`
115+
InputPremiumNum uint64 `json:"inputPremiumNum"`
116+
InputPremiumDenom uint64 `json:"inputPremiumDenom"`
117+
OutputPremiumNum uint64 `json:"outputPremiumNum"`
118+
OutputPremiumDenom uint64 `json:"outputPremiumDenom"`
119+
}
120+
103121
// ImportOptions defines options for import operation
104122
type ImportOptions struct {
105123
ConflictStrategy string `json:"conflictStrategy"` // "skip", "overwrite", "error"

internal/handler/admin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
102102
// Provider handlers
103103
func (h *AdminHandler) handleProviders(w http.ResponseWriter, r *http.Request, id uint64) {
104104
// Check for special endpoints
105-
path := r.URL.Path
105+
path := strings.TrimSuffix(r.URL.Path, "/")
106106
if strings.HasSuffix(path, "/export") {
107107
h.handleProvidersExport(w, r)
108108
return
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package handler
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/awsl-project/maxx/internal/domain"
11+
"github.com/awsl-project/maxx/internal/service"
12+
)
13+
14+
type adminTestProviderRepo struct {
15+
providers []*domain.Provider
16+
}
17+
18+
func (r *adminTestProviderRepo) Create(provider *domain.Provider) error {
19+
provider.ID = uint64(len(r.providers) + 1)
20+
r.providers = append(r.providers, provider)
21+
return nil
22+
}
23+
24+
func (r *adminTestProviderRepo) Update(provider *domain.Provider) error {
25+
for i, p := range r.providers {
26+
if p.ID == provider.ID {
27+
r.providers[i] = provider
28+
return nil
29+
}
30+
}
31+
return domain.ErrNotFound
32+
}
33+
34+
func (r *adminTestProviderRepo) Delete(id uint64) error {
35+
for i, p := range r.providers {
36+
if p.ID == id {
37+
r.providers = append(r.providers[:i], r.providers[i+1:]...)
38+
return nil
39+
}
40+
}
41+
return domain.ErrNotFound
42+
}
43+
44+
func (r *adminTestProviderRepo) GetByID(id uint64) (*domain.Provider, error) {
45+
for _, p := range r.providers {
46+
if p.ID == id {
47+
return p, nil
48+
}
49+
}
50+
return nil, domain.ErrNotFound
51+
}
52+
53+
func (r *adminTestProviderRepo) List() ([]*domain.Provider, error) {
54+
cloned := make([]*domain.Provider, len(r.providers))
55+
copy(cloned, r.providers)
56+
return cloned, nil
57+
}
58+
59+
func newAdminHandlerForProviderImportExportTests(providerRepo *adminTestProviderRepo) *AdminHandler {
60+
adminSvc := service.NewAdminService(
61+
providerRepo,
62+
nil,
63+
nil,
64+
nil,
65+
nil,
66+
nil,
67+
nil,
68+
nil,
69+
nil,
70+
nil,
71+
nil,
72+
nil,
73+
nil,
74+
nil,
75+
"",
76+
nil,
77+
nil,
78+
nil,
79+
)
80+
81+
return NewAdminHandler(adminSvc, nil, "")
82+
}
83+
84+
func TestAdminHandler_ProvidersImport_WithTrailingSlash(t *testing.T) {
85+
providerRepo := &adminTestProviderRepo{}
86+
h := newAdminHandlerForProviderImportExportTests(providerRepo)
87+
88+
body, err := json.Marshal([]map[string]any{{
89+
"name": "imported-provider",
90+
"type": "custom",
91+
}})
92+
if err != nil {
93+
t.Fatalf("marshal request body: %v", err)
94+
}
95+
96+
req := httptest.NewRequest(http.MethodPost, "/admin/providers/import/", bytes.NewReader(body))
97+
rec := httptest.NewRecorder()
98+
99+
h.ServeHTTP(rec, req)
100+
101+
if rec.Code != http.StatusOK {
102+
t.Fatalf("status = %d, want %d, body = %s", rec.Code, http.StatusOK, rec.Body.String())
103+
}
104+
105+
var result service.ImportResult
106+
if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil {
107+
t.Fatalf("decode response: %v", err)
108+
}
109+
110+
if result.Imported != 1 {
111+
t.Fatalf("imported = %d, want 1", result.Imported)
112+
}
113+
if len(providerRepo.providers) != 1 {
114+
t.Fatalf("provider count = %d, want 1", len(providerRepo.providers))
115+
}
116+
}
117+
118+
func TestAdminHandler_ProvidersExport_WithTrailingSlash(t *testing.T) {
119+
providerRepo := &adminTestProviderRepo{
120+
providers: []*domain.Provider{{
121+
ID: 1,
122+
Name: "exported-provider",
123+
Type: "custom",
124+
}},
125+
}
126+
h := newAdminHandlerForProviderImportExportTests(providerRepo)
127+
128+
req := httptest.NewRequest(http.MethodGet, "/admin/providers/export/", nil)
129+
rec := httptest.NewRecorder()
130+
131+
h.ServeHTTP(rec, req)
132+
133+
if rec.Code != http.StatusOK {
134+
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
135+
}
136+
137+
contentDisposition := rec.Header().Get("Content-Disposition")
138+
if contentDisposition != "attachment; filename=providers.json" {
139+
t.Fatalf("Content-Disposition = %q, want attachment header", contentDisposition)
140+
}
141+
142+
var providers []domain.Provider
143+
if err := json.Unmarshal(rec.Body.Bytes(), &providers); err != nil {
144+
t.Fatalf("decode response: %v", err)
145+
}
146+
147+
if len(providers) != 1 || providers[0].Name != "exported-provider" {
148+
t.Fatalf("providers = %+v, want one exported-provider", providers)
149+
}
150+
}

internal/repository/sqlite/models.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type Provider struct {
6464
SoftDeleteModel
6565
Type string `gorm:"size:64"`
6666
Name string `gorm:"size:255"`
67+
Logo LongText
6768
Config LongText
6869
SupportedClientTypes LongText
6970
SupportModels LongText

internal/repository/sqlite/provider.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ func (r *ProviderRepository) toModel(p *domain.Provider) *Provider {
8282
},
8383
Type: p.Type,
8484
Name: p.Name,
85+
Logo: LongText(p.Logo),
8586
Config: LongText(toJSON(p.Config)),
8687
SupportedClientTypes: LongText(toJSON(p.SupportedClientTypes)),
8788
SupportModels: LongText(toJSON(p.SupportModels)),
@@ -97,6 +98,7 @@ func (r *ProviderRepository) toDomain(m *Provider) *domain.Provider {
9798
DeletedAt: fromTimestampPtr(m.DeletedAt),
9899
Type: m.Type,
99100
Name: m.Name,
101+
Logo: string(m.Logo),
100102
Config: fromJSON[*domain.ProviderConfig](string(m.Config)),
101103
SupportedClientTypes: fromJSON[[]domain.ClientType](string(m.SupportedClientTypes)),
102104
SupportModels: fromJSON[[]string](string(m.SupportModels)),

0 commit comments

Comments
 (0)