Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/maxx/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ func main() {
settingRepo,
cachedAPITokenRepo,
cachedModelMappingRepo,
modelPriceRepo,
r, // Router implements ProviderAdapterRefresher interface
)

Expand Down
1 change: 1 addition & 0 deletions internal/core/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ func InitializeServerComponents(
repos.SettingRepo,
repos.CachedAPITokenRepo,
repos.CachedModelMappingRepo,
repos.ModelPriceRepo,
r,
)

Expand Down
18 changes: 18 additions & 0 deletions internal/domain/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type BackupData struct {
RoutingStrategies []BackupRoutingStrategy `json:"routingStrategies,omitempty"`
APITokens []BackupAPIToken `json:"apiTokens,omitempty"`
ModelMappings []BackupModelMapping `json:"modelMappings,omitempty"`
ModelPrices []BackupModelPrice `json:"modelPrices,omitempty"`
}

// BackupSystemSetting represents a system setting for backup
Expand All @@ -35,6 +36,7 @@ type BackupSystemSetting struct {
type BackupProvider struct {
Name string `json:"name"`
Type string `json:"type"`
Logo string `json:"logo,omitempty"`
Config *ProviderConfig `json:"config,omitempty"`
SupportedClientTypes []ClientType `json:"supportedClientTypes,omitempty"`
SupportModels []string `json:"supportModels,omitempty"`
Expand Down Expand Up @@ -100,6 +102,22 @@ type BackupModelMapping struct {
Priority int `json:"priority"`
}

// BackupModelPrice represents a model price for backup
type BackupModelPrice struct {
ModelID string `json:"modelId"`
InputPriceMicro uint64 `json:"inputPriceMicro"`
OutputPriceMicro uint64 `json:"outputPriceMicro"`
CacheReadPriceMicro uint64 `json:"cacheReadPriceMicro"`
Cache5mWritePriceMicro uint64 `json:"cache5mWritePriceMicro"`
Cache1hWritePriceMicro uint64 `json:"cache1hWritePriceMicro"`
Has1MContext bool `json:"has1mContext"`
Context1MThreshold uint64 `json:"context1mThreshold"`
InputPremiumNum uint64 `json:"inputPremiumNum"`
InputPremiumDenom uint64 `json:"inputPremiumDenom"`
OutputPremiumNum uint64 `json:"outputPremiumNum"`
OutputPremiumDenom uint64 `json:"outputPremiumDenom"`
}

// ImportOptions defines options for import operation
type ImportOptions struct {
ConflictStrategy string `json:"conflictStrategy"` // "skip", "overwrite", "error"
Expand Down
2 changes: 1 addition & 1 deletion internal/handler/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Provider handlers
func (h *AdminHandler) handleProviders(w http.ResponseWriter, r *http.Request, id uint64) {
// Check for special endpoints
path := r.URL.Path
path := strings.TrimSuffix(r.URL.Path, "/")
if strings.HasSuffix(path, "/export") {
h.handleProvidersExport(w, r)
return
Expand Down
150 changes: 150 additions & 0 deletions internal/handler/admin_import_export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package handler

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/awsl-project/maxx/internal/domain"
"github.com/awsl-project/maxx/internal/service"
)

type adminTestProviderRepo struct {
providers []*domain.Provider
}

func (r *adminTestProviderRepo) Create(provider *domain.Provider) error {
provider.ID = uint64(len(r.providers) + 1)
r.providers = append(r.providers, provider)
return nil
}

func (r *adminTestProviderRepo) Update(provider *domain.Provider) error {
for i, p := range r.providers {
if p.ID == provider.ID {
r.providers[i] = provider
return nil
}
}
return domain.ErrNotFound
}

func (r *adminTestProviderRepo) Delete(id uint64) error {
for i, p := range r.providers {
if p.ID == id {
r.providers = append(r.providers[:i], r.providers[i+1:]...)
return nil
}
}
return domain.ErrNotFound
}

func (r *adminTestProviderRepo) GetByID(id uint64) (*domain.Provider, error) {
for _, p := range r.providers {
if p.ID == id {
return p, nil
}
}
return nil, domain.ErrNotFound
}

func (r *adminTestProviderRepo) List() ([]*domain.Provider, error) {
cloned := make([]*domain.Provider, len(r.providers))
copy(cloned, r.providers)
return cloned, nil
}

func newAdminHandlerForProviderImportExportTests(providerRepo *adminTestProviderRepo) *AdminHandler {
adminSvc := service.NewAdminService(
providerRepo,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
"",
nil,
nil,
nil,
)

return NewAdminHandler(adminSvc, nil, "")
}

func TestAdminHandler_ProvidersImport_WithTrailingSlash(t *testing.T) {
providerRepo := &adminTestProviderRepo{}
h := newAdminHandlerForProviderImportExportTests(providerRepo)

body, err := json.Marshal([]map[string]any{{
"name": "imported-provider",
"type": "custom",
}})
if err != nil {
t.Fatalf("marshal request body: %v", err)
}

req := httptest.NewRequest(http.MethodPost, "/admin/providers/import/", bytes.NewReader(body))
rec := httptest.NewRecorder()

h.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body = %s", rec.Code, http.StatusOK, rec.Body.String())
}

var result service.ImportResult
if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil {
t.Fatalf("decode response: %v", err)
}

if result.Imported != 1 {
t.Fatalf("imported = %d, want 1", result.Imported)
}
if len(providerRepo.providers) != 1 {
t.Fatalf("provider count = %d, want 1", len(providerRepo.providers))
}
}

func TestAdminHandler_ProvidersExport_WithTrailingSlash(t *testing.T) {
providerRepo := &adminTestProviderRepo{
providers: []*domain.Provider{{
ID: 1,
Name: "exported-provider",
Type: "custom",
}},
}
h := newAdminHandlerForProviderImportExportTests(providerRepo)

req := httptest.NewRequest(http.MethodGet, "/admin/providers/export/", nil)
rec := httptest.NewRecorder()

h.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}

contentDisposition := rec.Header().Get("Content-Disposition")
if contentDisposition != "attachment; filename=providers.json" {
t.Fatalf("Content-Disposition = %q, want attachment header", contentDisposition)
}

var providers []domain.Provider
if err := json.Unmarshal(rec.Body.Bytes(), &providers); err != nil {
t.Fatalf("decode response: %v", err)
}

if len(providers) != 1 || providers[0].Name != "exported-provider" {
t.Fatalf("providers = %+v, want one exported-provider", providers)
}
}
1 change: 1 addition & 0 deletions internal/repository/sqlite/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type Provider struct {
SoftDeleteModel
Type string `gorm:"size:64"`
Name string `gorm:"size:255"`
Logo LongText
Config LongText
SupportedClientTypes LongText
SupportModels LongText
Expand Down
2 changes: 2 additions & 0 deletions internal/repository/sqlite/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (r *ProviderRepository) toModel(p *domain.Provider) *Provider {
},
Type: p.Type,
Name: p.Name,
Logo: LongText(p.Logo),
Config: LongText(toJSON(p.Config)),
SupportedClientTypes: LongText(toJSON(p.SupportedClientTypes)),
SupportModels: LongText(toJSON(p.SupportModels)),
Expand All @@ -97,6 +98,7 @@ func (r *ProviderRepository) toDomain(m *Provider) *domain.Provider {
DeletedAt: fromTimestampPtr(m.DeletedAt),
Type: m.Type,
Name: m.Name,
Logo: string(m.Logo),
Config: fromJSON[*domain.ProviderConfig](string(m.Config)),
SupportedClientTypes: fromJSON[[]domain.ClientType](string(m.SupportedClientTypes)),
SupportModels: fromJSON[[]string](string(m.SupportModels)),
Expand Down
Loading