diff --git a/api/handler/repo.go b/api/handler/repo.go index a3064c018..f9261db30 100644 --- a/api/handler/repo.go +++ b/api/handler/repo.go @@ -1013,9 +1013,10 @@ func (h *RepoHandler) SDKListFiles(ctx *gin.Context) { if mappedBranch != "" { ref = mappedBranch } + repoType := common.RepoTypeFromContext(ctx) expand := ctx.Query("expand") if expand == "xetEnabled" { - resp, err := h.c.IsXnetEnabled(ctx.Request.Context(), types.ModelRepo, namespace, name, currentUser) + resp, err := h.c.IsXnetEnabled(ctx.Request.Context(), repoType, namespace, name, currentUser) if err != nil { slog.ErrorContext(ctx.Request.Context(), "failed to check if xnetEnabled", slog.Any("error", err)) httpbase.ServerError(ctx, err) @@ -1025,15 +1026,15 @@ func (h *RepoHandler) SDKListFiles(ctx *gin.Context) { return } - files, err := h.c.SDKListFiles(ctx.Request.Context(), common.RepoTypeFromContext(ctx), namespace, name, ref, currentUser) + files, err := h.c.SDKListFiles(ctx.Request.Context(), repoType, namespace, name, ref, currentUser) if err != nil { if errors.Is(err, errorx.ErrUnauthorized) { - slog.ErrorContext(ctx.Request.Context(), "permission denied when accessing repo", slog.String("repo_type", string(common.RepoTypeFromContext(ctx))), slog.Any("path", fmt.Sprintf("%s/%s", namespace, name))) + slog.ErrorContext(ctx.Request.Context(), "permission denied when accessing repo", slog.String("repo_type", string(repoType)), slog.Any("path", fmt.Sprintf("%s/%s", namespace, name))) httpbase.UnauthorizedError(ctx, err) return } if errors.Is(err, errorx.ErrNotFound) { - slog.ErrorContext(ctx.Request.Context(), "repo not found", slog.String("repo_type", string(common.RepoTypeFromContext(ctx))), slog.Any("path", fmt.Sprintf("%s/%s", namespace, name))) + slog.ErrorContext(ctx.Request.Context(), "repo not found", slog.String("repo_type", string(repoType)), slog.Any("path", fmt.Sprintf("%s/%s", namespace, name))) httpbase.NotFoundError(ctx, err) return } @@ -1120,7 +1121,7 @@ func (h *RepoHandler) HeadSDKDownload(ctx *gin.Context) { slog.Debug("Head download repo file succeed", slog.String("repo_type", string(req.RepoType)), slog.String("name", name), slog.String("path", req.Path), slog.String("ref", req.Ref), slog.Int64("contentLength", file.Size)) if file.Lfs && file.XnetEnabled { ctx.Header("X-Xet-Hash", file.LfsSHA256) - ctx.Header("X-Xet-Refresh-Route", h.xetRefreshRoute(namespace, name, branch)) + ctx.Header("X-Xet-Refresh-Route", h.xetRefreshRoute(req.RepoType, namespace, name, branch)) } ctx.Header("Content-Length", strconv.Itoa(int(file.Size))) ctx.Header("X-Repo-Commit", repoCommit) @@ -1128,8 +1129,8 @@ func (h *RepoHandler) HeadSDKDownload(ctx *gin.Context) { ctx.Status(http.StatusOK) } -func (h *RepoHandler) xetRefreshRoute(namespace, name, ref string) string { - return fmt.Sprintf("%s/hf/%s/%s/xet-write-token/%s", h.config.Model.DownloadEndpoint, namespace, name, ref) +func (h *RepoHandler) xetRefreshRoute(repoType types.RepositoryType, namespace, name, ref string) string { + return fmt.Sprintf("%s/hf/%ss/%s/%s/xet-write-token/%s", h.config.Model.DownloadEndpoint, repoType, namespace, name, ref) } func (h *RepoHandler) handleDownload(ctx *gin.Context, isResolve bool) {