diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index a080246d4..27e529417 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -15,6 +15,7 @@ import ( "github.com/urfave/cli/v3" "github.com/dstackai/dstack/runner/consts" + "github.com/dstackai/dstack/runner/internal/executor" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/runner/api" "github.com/dstackai/dstack/runner/internal/ssh" @@ -162,7 +163,12 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss } }() - server, err := api.NewServer(ctx, tempDir, homeDir, dstackDir, sshd, fmt.Sprintf(":%d", httpPort), version) + ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd) + if err != nil { + return fmt.Errorf("create executor: %w", err) + } + + server, err := api.NewServer(ctx, fmt.Sprintf(":%d", httpPort), version, ex) if err != nil { return fmt.Errorf("create server: %w", err) } diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go index 554bd7646..4961180e9 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/executor/base.go @@ -13,7 +13,6 @@ type Executor interface { GetJobWsLogsHistory() []schemas.LogEvent GetRunnerState() string Run(ctx context.Context) error - SetCodePath(codePath string) SetJob(job schemas.SubmitBody) SetJobState(ctx context.Context, state types.JobState) SetJobStateWithTerminationReason( @@ -23,7 +22,8 @@ type Executor interface { termination_message string, ) SetRunnerState(state string) - AddFileArchive(id string, src io.Reader) error + WriteFileArchive(id string, src io.Reader) error + WriteRepoBlob(src io.Reader) error Lock() RLock() RUnlock() diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 56a5d1cd9..fc4039cf9 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -52,11 +52,12 @@ type ConnectionTracker interface { } type RunExecutor struct { - tempDir string - homeDir string - dstackDir string - archiveDir string - sshd ssh.SshdManager + tempDir string + homeDir string + dstackDir string + fileArchiveDir string + repoBlobDir string + sshd ssh.SshdManager currentUid uint32 @@ -67,7 +68,7 @@ type RunExecutor struct { secrets map[string]string repoCredentials *schemas.RepoCredentials repoDir string - codePath string + repoBlobPath string jobUid int jobGid int jobHomeDir string @@ -123,14 +124,15 @@ func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.S } return &RunExecutor{ - tempDir: tempDir, - homeDir: homeDir, - dstackDir: dstackDir, - archiveDir: filepath.Join(tempDir, "file_archives"), - sshd: sshd, - currentUid: uid, - jobUid: -1, - jobGid: -1, + tempDir: tempDir, + homeDir: homeDir, + dstackDir: dstackDir, + fileArchiveDir: filepath.Join(tempDir, "file_archives"), + repoBlobDir: filepath.Join(tempDir, "repo_blobs"), + sshd: sshd, + currentUid: uid, + jobUid: -1, + jobGid: -1, mu: mu, state: WaitSubmit, @@ -145,7 +147,7 @@ func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.S }, nil } -// Run must be called after SetJob and SetCodePath +// Run must be called after SetJob and WriteRepoBlob func (ex *RunExecutor) Run(ctx context.Context) (err error) { runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName)) if err != nil { @@ -296,11 +298,6 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) { ex.state = WaitCode } -func (ex *RunExecutor) SetCodePath(codePath string) { - ex.codePath = codePath - ex.state = WaitRun -} - func (ex *RunExecutor) SetJobState(ctx context.Context, state types.JobState) { ex.SetJobStateWithTerminationReason(ctx, state, "", "") } diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index e3661fac0..0d935dd64 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -69,7 +69,7 @@ func TestExecutor_HomeDir(t *testing.T) { func TestExecutor_NonZeroExit(t *testing.T) { ex := makeTestExecutor(t) ex.jobSpec.Commands = append(ex.jobSpec.Commands, "exit 100") - makeCodeTar(t, ex.codePath) + makeCodeTar(t, ex) err := ex.Run(t.Context()) assert.Error(t, err) @@ -104,7 +104,7 @@ func TestExecutor_LocalRepo(t *testing.T) { ex := makeTestExecutor(t) cmd := fmt.Sprintf("cat %s/foo", *ex.jobSpec.RepoDir) ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd) - makeCodeTar(t, ex.codePath) + makeCodeTar(t, ex) err := ex.setupRepo(t.Context()) require.NoError(t, err) @@ -117,7 +117,7 @@ func TestExecutor_LocalRepo(t *testing.T) { func TestExecutor_Recover(t *testing.T) { ex := makeTestExecutor(t) ex.jobSpec.Commands = nil // cause a panic - makeCodeTar(t, ex.codePath) + makeCodeTar(t, ex) err := ex.Run(t.Context()) assert.ErrorContains(t, err, "recovered: ") @@ -134,7 +134,7 @@ func TestExecutor_MaxDuration(t *testing.T) { ex.killDelay = 500 * time.Millisecond ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 1 && sleep 2 && echo 2") ex.jobSpec.MaxDuration = 1 // seconds - makeCodeTar(t, ex.codePath) + makeCodeTar(t, ex) err := ex.Run(t.Context()) assert.ErrorContains(t, err, "killed") @@ -155,7 +155,7 @@ func TestExecutor_RemoteRepo(t *testing.T) { RepoConfigEmail: "developer@dstack.ai", } ex.jobSpec.Commands = append(ex.jobSpec.Commands, "git rev-parse HEAD && git config user.name && git config user.email") - err := os.WriteFile(ex.codePath, []byte{}, 0o600) // empty diff + err := ex.WriteRepoBlob(bytes.NewReader([]byte{})) // empty diff require.NoError(t, err) err = ex.setJobWorkingDir(t.Context()) @@ -210,19 +210,17 @@ func makeTestExecutor(t *testing.T) *RunExecutor { require.NoError(t, os.Mkdir(homeDir, 0o700)) dstackDir := filepath.Join(baseDir, "dstack") require.NoError(t, os.Mkdir(dstackDir, 0o755)) - ex, _ := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock)) + ex, err := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock)) + require.NoError(t, err) ex.SetJob(body) - ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run - ex.setJobWorkingDir(context.Background()) + require.NoError(t, ex.setJobWorkingDir(t.Context())) return ex } -func makeCodeTar(t *testing.T, path string) { +func makeCodeTar(t *testing.T, ex *RunExecutor) { t.Helper() - file, err := os.Create(path) - require.NoError(t, err) - defer func() { _ = file.Close() }() - tw := tar.NewWriter(file) + var b bytes.Buffer + tw := tar.NewWriter(&b) files := []struct{ name, body string }{ {"foo", "bar\n"}, @@ -235,6 +233,8 @@ func makeCodeTar(t *testing.T, path string) { require.NoError(t, err) } require.NoError(t, tw.Close()) + + require.NoError(t, ex.WriteRepoBlob(&b)) } func TestWriteDstackProfile(t *testing.T) { diff --git a/runner/internal/executor/files.go b/runner/internal/executor/files.go index c447006c3..ee1170c41 100644 --- a/runner/internal/executor/files.go +++ b/runner/internal/executor/files.go @@ -18,11 +18,11 @@ import ( var renameRegex = regexp.MustCompile(`^([^/]*)(/|$)`) -func (ex *RunExecutor) AddFileArchive(id string, src io.Reader) error { - if err := os.MkdirAll(ex.archiveDir, 0o755); err != nil { +func (ex *RunExecutor) WriteFileArchive(id string, src io.Reader) error { + if err := os.MkdirAll(ex.fileArchiveDir, 0o755); err != nil { return fmt.Errorf("create archive directory: %w", err) } - archivePath := path.Join(ex.archiveDir, id) + archivePath := path.Join(ex.fileArchiveDir, id) archive, err := os.Create(archivePath) if err != nil { return fmt.Errorf("create archive file: %w", err) @@ -45,13 +45,13 @@ func (ex *RunExecutor) setupFiles(ctx context.Context) error { return fmt.Errorf("setup files: working dir must be absolute: %s", ex.jobWorkingDir) } for _, fa := range ex.jobSpec.FileArchives { - archivePath := path.Join(ex.archiveDir, fa.Id) + archivePath := path.Join(ex.fileArchiveDir, fa.Id) if err := extractFileArchive(ctx, archivePath, fa.Path, ex.jobWorkingDir, ex.jobUid, ex.jobGid, ex.jobHomeDir); err != nil { return fmt.Errorf("extract file archive %s: %w", fa.Id, err) } } - if err := os.RemoveAll(ex.archiveDir); err != nil { - log.Warning(ctx, "Failed to remove file archives dir", "path", ex.archiveDir, "err", err) + if err := os.RemoveAll(ex.fileArchiveDir); err != nil { + log.Warning(ctx, "Failed to remove file archives dir", "path", ex.fileArchiveDir, "err", err) } return nil } @@ -90,7 +90,6 @@ func extractFileArchive(ctx context.Context, archivePath string, destPath string if uid != -1 || gid != -1 { for _, p := range paths { - log.Warning(ctx, "path", "path", p) if err := os.Chown(path.Join(destBase, p), uid, gid); err != nil { log.Warning(ctx, "Failed to chown", "path", p, "err", err) } diff --git a/runner/internal/executor/repo.go b/runner/internal/executor/repo.go index 32f623e70..2f757f63c 100644 --- a/runner/internal/executor/repo.go +++ b/runner/internal/executor/repo.go @@ -4,9 +4,11 @@ import ( "context" "errors" "fmt" + "io" "io/fs" "os" "os/exec" + "path" "path/filepath" "github.com/codeclysm/extract/v4" @@ -17,6 +19,23 @@ import ( "github.com/dstackai/dstack/runner/internal/schemas" ) +// WriteRepoBlob must be called after SetJob +func (ex *RunExecutor) WriteRepoBlob(src io.Reader) error { + if err := os.MkdirAll(ex.repoBlobDir, 0o755); err != nil { + return fmt.Errorf("create blob directory: %w", err) + } + ex.repoBlobPath = path.Join(ex.repoBlobDir, ex.run.RunSpec.RepoId) + blob, err := os.Create(ex.repoBlobPath) + if err != nil { + return fmt.Errorf("create blob file: %w", err) + } + defer func() { _ = blob.Close() }() + if _, err = io.Copy(blob, src); err != nil { + return fmt.Errorf("copy blob data: %w", err) + } + return nil +} + // setupRepo must be called from Run // Must be called after setJobWorkingDir and setJobCredentials func (ex *RunExecutor) setupRepo(ctx context.Context) error { @@ -100,6 +119,10 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { return fmt.Errorf("chown repo dir: %w", err) } + if err := os.RemoveAll(ex.repoBlobDir); err != nil { + log.Warning(ctx, "Failed to remove repo blobs dir", "path", ex.repoBlobDir, "err", err) + } + return err } @@ -143,7 +166,7 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error { } log.Trace(ctx, "Applying diff") - repoDiff, err := os.ReadFile(ex.codePath) + repoDiff, err := os.ReadFile(ex.repoBlobPath) if err != nil { return fmt.Errorf("read repo diff: %w", err) } @@ -156,12 +179,12 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error { } func (ex *RunExecutor) prepareArchive(ctx context.Context) error { - file, err := os.Open(ex.codePath) + file, err := os.Open(ex.repoBlobPath) if err != nil { return fmt.Errorf("open code archive: %w", err) } defer func() { _ = file.Close() }() - log.Trace(ctx, "Extracting code archive", "src", ex.codePath, "dst", ex.repoDir) + log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir) if err := extract.Tar(ctx, file, ex.repoDir, nil); err != nil { return fmt.Errorf("extract tar archive: %w", err) } diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index bbf416efb..87eb96e0a 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -9,8 +9,6 @@ import ( "mime" "mime/multipart" "net/http" - "os" - "path/filepath" "strconv" "github.com/dstackai/dstack/runner/internal/api" @@ -19,6 +17,9 @@ import ( "github.com/dstackai/dstack/runner/internal/schemas" ) +// TODO: set some reasonable value; (optional) make configurable +const maxBodySize = math.MaxInt64 + func (s *Server) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { return &schemas.HealthcheckResponse{ Service: "dstack-runner", @@ -84,13 +85,16 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing boundary"} } - r.Body = http.MaxBytesReader(w, r.Body, math.MaxInt64) + r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) formReader := multipart.NewReader(r.Body, boundary) part, err := formReader.NextPart() - if errors.Is(err, io.EOF) { - return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"} - } if err != nil { + if errors.Is(err, io.EOF) { + return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"} + } + if isMaxBytesError(err) { + return nil, &api.Error{Status: http.StatusRequestEntityTooLarge} + } return nil, fmt.Errorf("read multipart form: %w", err) } defer func() { _ = part.Close() }() @@ -106,8 +110,11 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request if archiveId == "" { return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing file name"} } - if err := s.executor.AddFileArchive(archiveId, part); err != nil { - return nil, fmt.Errorf("add file archive: %w", err) + if err := s.executor.WriteFileArchive(archiveId, part); err != nil { + if isMaxBytesError(err) { + return nil, &api.Error{Status: http.StatusRequestEntityTooLarge} + } + return nil, fmt.Errorf("write file archive: %w", err) } if _, err := formReader.NextPart(); !errors.Is(err, io.EOF) { return nil, &api.Error{Status: http.StatusBadRequest, Msg: "extra form field(s)"} @@ -123,21 +130,17 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) ( return nil, &api.Error{Status: http.StatusConflict} } - r.Body = http.MaxBytesReader(w, r.Body, math.MaxInt64) - codePath := filepath.Join(s.tempDir, "code") // todo random name? - file, err := os.Create(codePath) - if err != nil { - return nil, fmt.Errorf("create code file: %w", err) - } - defer func() { _ = file.Close() }() - if _, err = io.Copy(file, r.Body); err != nil { - if err.Error() == "http: request body too large" { + r.Body = http.MaxBytesReader(w, r.Body, maxBodySize) + + if err := s.executor.WriteRepoBlob(r.Body); err != nil { + if isMaxBytesError(err) { return nil, &api.Error{Status: http.StatusRequestEntityTooLarge} } return nil, fmt.Errorf("copy request body: %w", err) } - s.executor.SetCodePath(codePath) + s.executor.SetRunnerState(executor.WaitRun) + return nil, nil } @@ -181,3 +184,8 @@ func (s *Server) stopPostHandler(w http.ResponseWriter, r *http.Request) (interf s.stop() return nil, nil } + +func isMaxBytesError(err error) bool { + var maxBytesError *http.MaxBytesError + return errors.As(err, &maxBytesError) +} diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 0a0b851a9..ba577d1a5 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -11,12 +11,10 @@ import ( "github.com/dstackai/dstack/runner/internal/executor" "github.com/dstackai/dstack/runner/internal/log" "github.com/dstackai/dstack/runner/internal/metrics" - "github.com/dstackai/dstack/runner/internal/ssh" ) type Server struct { - srv *http.Server - tempDir string + srv *http.Server shutdownCh chan interface{} // server closes this chan on shutdown jobBarrierCh chan interface{} // only server listens on this chan @@ -34,15 +32,8 @@ type Server struct { version string } -func NewServer( - ctx context.Context, tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager, - address string, version string, -) (*Server, error) { +func NewServer(ctx context.Context, address string, version string, ex executor.Executor) (*Server, error) { r := api.NewRouter() - ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd) - if err != nil { - return nil, err - } metricsCollector, err := metrics.NewMetricsCollector(ctx) if err != nil { @@ -54,7 +45,6 @@ func NewServer( Addr: address, Handler: r, }, - tempDir: tempDir, shutdownCh: make(chan interface{}), jobBarrierCh: make(chan interface{}),