Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 16 additions & 14 deletions api/middleware/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,31 @@ func RepoMapping(repo_type types.RepositoryType) gin.HandlerFunc {
ctx.Next()
return
}
mirror, err := mirrorStore.FindWithMapping(ctx, repo_type, namespace, name, mapping)
repo, err := mirrorStore.FindWithMapping(ctx, repo_type, namespace, name, mapping)
//if found mirror, that means this is a synced source, otherwise it's may a user-upload repo
if err == nil {
repo_id := strings.Split(mirror.Repository.Path, "/")
namespace, name = repo.NamespaceAndName()
//set the real namespace, the name was unchange
slog.Info("namespace changed: ", "namespace", repo_id[0])
ctx.Set("namespace_mapped", repo_id[0])
ctx.Set("name_mapped", repo_id[1])
// for modelscope, the default branch is master, we should mapp it to real branch
if (branch == "main" || branch == "master") && mirror.Repository.DefaultBranch != branch {
ctx.Set("branch_mapped", mirror.Repository.DefaultBranch)
slog.Info("namespace changed: ", "namespace", namespace)
ctx.Set("namespace_mapped", namespace)
ctx.Set("name_mapped", name)
// for modelscope, the default branch is master, we should map it to real branch
if (branch == "main" || branch == "master") && repo.DefaultBranch != branch {
ctx.Set("branch_mapped", repo.DefaultBranch)
}
ctx.Next()
return
}
ctx.Next()
}
}

func GetMapping(ctx *gin.Context) types.Mapping {
rawRp := ctx.Query("mirror")
if rawRp == "" {
return types.AutoMapping
fullPath := ctx.FullPath()
if strings.HasPrefix(fullPath, "/hf/") {
return types.HFMapping
}
if strings.HasPrefix(fullPath, "/ms/") {
return types.ModelScopeMapping
}
return types.Mapping(rawRp)
//csg
return types.CSGHubMapping
}
12 changes: 8 additions & 4 deletions api/router/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,12 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) {
if err != nil {
return nil, fmt.Errorf("error creating HF dataset handler: %w", err)
}

createHFRoutes(r, hfdsHandler, repoCommonHandler, modelHandler, userHandler)
//create routes for hf
createMappingRoutes(r, "/hf", hfdsHandler, repoCommonHandler, modelHandler, userHandler)
//create routes for ms
createMappingRoutes(r, "/ms", hfdsHandler, repoCommonHandler, modelHandler, userHandler)
//create routes for csg
createMappingRoutes(r, "/csg", hfdsHandler, repoCommonHandler, modelHandler, userHandler)

apiGroup := r.Group("/api/v1")
// TODO:use middleware to handle common response
Expand Down Expand Up @@ -758,9 +762,9 @@ func createAccountRoutes(apiGroup *gin.RouterGroup, needAPIKey gin.HandlerFunc,
}
}

func createHFRoutes(r *gin.Engine, hfdsHandler *handler.HFDatasetHandler, repoCommonHandler *handler.RepoHandler, modelHandler *handler.ModelHandler, userHandler *handler.UserHandler) {
func createMappingRoutes(r *gin.Engine, group string, hfdsHandler *handler.HFDatasetHandler, repoCommonHandler *handler.RepoHandler, modelHandler *handler.ModelHandler, userHandler *handler.UserHandler) {
// Huggingface SDK routes
hfGroup := r.Group("/hf")
hfGroup := r.Group(group)
{
hfGroup.GET("/:namespace/:name/resolve/:branch/*file_path", middleware.RepoMapping(types.ModelRepo), repoCommonHandler.SDKDownload)
hfGroup.HEAD("/:namespace/:name/resolve/:branch/*file_path", middleware.RepoMapping(types.ModelRepo), repoCommonHandler.HeadSDKDownload)
Expand Down
48 changes: 18 additions & 30 deletions builder/store/database/mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type MirrorStore interface {
FindByRepoID(ctx context.Context, repoID int64) (*Mirror, error)
FindByID(ctx context.Context, ID int64) (*Mirror, error)
FindByRepoPath(ctx context.Context, repoType types.RepositoryType, namespace, name string) (*Mirror, error)
FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*Mirror, error)
FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*Repository, error)
Create(ctx context.Context, mirror *Mirror) (*Mirror, error)
WithPagination(ctx context.Context) ([]Mirror, error)
WithPaginationWithRepository(ctx context.Context) ([]Mirror, error)
Expand Down Expand Up @@ -141,37 +141,25 @@ func (s *mirrorStoreImpl) FindByRepoPath(ctx context.Context, repoType types.Rep
return &mirror, nil
}

func (s *mirrorStoreImpl) FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*Mirror, error) {
var mirror Mirror
var err error
if mapping == types.CSGHubMapping {
return s.FindByRepoPath(ctx, repoType, namespace, name)
} else if mapping == types.HFMapping {
err = s.db.Operator.Core.NewSelect().
Model(&mirror).
Relation("Repository").
Where("mirror.source_repo_path=?", fmt.Sprintf("%s/%s", namespace, name)).
Where("repository.repository_type=?", repoType).
Scan(ctx)
func (s *mirrorStoreImpl) FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*Repository, error) {
resRepo := new(Repository)
query := s.db.Operator.Core.
NewSelect().
Model(resRepo)
path := fmt.Sprintf("%s/%s", namespace, name)
query.Where("repository_type = ?", repoType)
if mapping == types.HFMapping {
//compatiebility with old data
//TODO: remove path after sdk 0.4.6
query.Where("hf_path = ? or path = ?", path, path)
} else if mapping == types.ModelScopeMapping {
query.Where("ms_path = ?", path, path)
} else {
// auto mapping
//fix some repo id has mirror but it's not public,for example: https://opencsg.com/models/Qwen/Qwen_Qwen2-7B-Instruct
exist, _ := s.IsRepoExist(ctx, repoType, namespace, name)
if exist {
// no need mapping if repo id already exists in reporitory
return nil, fmt.Errorf("repo already exists, no need mapping")
}
err = s.db.Operator.Core.NewSelect().
Model(&mirror).
Relation("Repository").
Where("mirror.source_repo_path=?", fmt.Sprintf("%s/%s", namespace, name)).
Where("repository.repository_type=?", repoType).
Scan(ctx)
// for csg path
query.Where("path = ?", path)
}
if err != nil {
return nil, err
}
return &mirror, nil
err := query.Limit(1).Scan(ctx)
return resRepo, err
}

func (s *mirrorStoreImpl) Create(ctx context.Context, mirror *Mirror) (*Mirror, error) {
Expand Down
33 changes: 12 additions & 21 deletions builder/store/database/mirror_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package database_test

import (
"context"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -129,42 +128,34 @@ func TestMirrorStore_FindWithMapping(t *testing.T) {
store := database.NewMirrorStoreWithDB(db)

repos := []*database.Repository{
{Name: "repo1", RepositoryType: types.ModelRepo, Path: "models_ns/repo1"},
{Name: "repo2", RepositoryType: types.DatasetRepo, Path: "datasets_ns/repo2"},
{Name: "repo3", RepositoryType: types.PromptRepo, Path: "prompts_ns/repo3"},
{Name: "repo1", RepositoryType: types.ModelRepo, Path: "ns/repo1", HFPath: "hf/repo1"},
{Name: "repo2", RepositoryType: types.DatasetRepo, Path: "ns/repo2", MSPath: "ms/repo2"},
{Name: "repo3", RepositoryType: types.PromptRepo, Path: "ns/repo3"},
}

for _, repo := range repos {
repo.GitPath = repo.Path
err := db.Core.NewInsert().Model(repo).Scan(ctx, repo)
require.Nil(t, err)
sp := strings.Split(repo.Path, "_")
_, err = store.Create(ctx, &database.Mirror{
RepositoryID: repo.ID,
SourceRepoPath: strings.ReplaceAll(sp[1], "ns/", "nsn/"),
Interval: repo.Name,
})
require.Nil(t, err)
}

mi, err := store.FindWithMapping(ctx, types.ModelRepo, "ns", "repo1", types.CSGHubMapping)
require.Nil(t, err)
require.Equal(t, "repo1", mi.Interval)

_, err = store.FindWithMapping(ctx, types.ModelRepo, "ns", "repo1", types.HFMapping)
require.NotNil(t, err)
require.Equal(t, "repo1", mi.Name)

mi, err = store.FindWithMapping(ctx, types.ModelRepo, "nsn", "repo1", types.HFMapping)
_, err = store.FindWithMapping(ctx, types.ModelRepo, "hf", "repo1", types.HFMapping)
require.Nil(t, err)
require.Equal(t, "repo1", mi.Interval)

mi, err = store.FindWithMapping(ctx, types.DatasetRepo, "nsn", "repo2", types.HFMapping)
_, err = store.FindWithMapping(ctx, types.ModelRepo, "aaa", "repo1", types.HFMapping)
require.NotNil(t, err)

mi, err = store.FindWithMapping(ctx, types.DatasetRepo, "ms", "repo2", types.ModelScopeMapping)
require.Nil(t, err)
require.Equal(t, "repo2", mi.Interval)
require.Equal(t, "repo2", mi.Name)

mi, err = store.FindWithMapping(ctx, types.PromptRepo, "nsn", "repo3", types.AutoMapping)
mi, err = store.FindWithMapping(ctx, types.PromptRepo, "ns", "repo3", types.CSGHubMapping)
require.Nil(t, err)
require.Equal(t, "repo3", mi.Interval)
require.Equal(t, "repo3", mi.Name)
}

func TestMirrorStore_ToSync(t *testing.T) {
Expand Down
Loading
Loading