Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion runner/cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/urfave/cli/v3"

"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/internal/executor"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/runner/api"
"github.com/dstackai/dstack/runner/internal/ssh"
Expand Down Expand Up @@ -162,7 +163,12 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
}
}()

server, err := api.NewServer(ctx, tempDir, homeDir, dstackDir, sshd, fmt.Sprintf(":%d", httpPort), version)
ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd)
if err != nil {
return fmt.Errorf("create executor: %w", err)
}

server, err := api.NewServer(ctx, fmt.Sprintf(":%d", httpPort), version, ex)
if err != nil {
return fmt.Errorf("create server: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions runner/internal/executor/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ type Executor interface {
GetJobWsLogsHistory() []schemas.LogEvent
GetRunnerState() string
Run(ctx context.Context) error
SetCodePath(codePath string)
SetJob(job schemas.SubmitBody)
SetJobState(ctx context.Context, state types.JobState)
SetJobStateWithTerminationReason(
Expand All @@ -23,7 +22,8 @@ type Executor interface {
termination_message string,
)
SetRunnerState(state string)
AddFileArchive(id string, src io.Reader) error
WriteFileArchive(id string, src io.Reader) error
WriteRepoBlob(src io.Reader) error
Lock()
RLock()
RUnlock()
Expand Down
37 changes: 17 additions & 20 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ type ConnectionTracker interface {
}

type RunExecutor struct {
tempDir string
homeDir string
dstackDir string
archiveDir string
sshd ssh.SshdManager
tempDir string
homeDir string
dstackDir string
fileArchiveDir string
repoBlobDir string
sshd ssh.SshdManager

currentUid uint32

Expand All @@ -67,7 +68,7 @@ type RunExecutor struct {
secrets map[string]string
repoCredentials *schemas.RepoCredentials
repoDir string
codePath string
repoBlobPath string
jobUid int
jobGid int
jobHomeDir string
Expand Down Expand Up @@ -123,14 +124,15 @@ func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.S
}

return &RunExecutor{
tempDir: tempDir,
homeDir: homeDir,
dstackDir: dstackDir,
archiveDir: filepath.Join(tempDir, "file_archives"),
sshd: sshd,
currentUid: uid,
jobUid: -1,
jobGid: -1,
tempDir: tempDir,
homeDir: homeDir,
dstackDir: dstackDir,
fileArchiveDir: filepath.Join(tempDir, "file_archives"),
repoBlobDir: filepath.Join(tempDir, "repo_blobs"),
sshd: sshd,
currentUid: uid,
jobUid: -1,
jobGid: -1,

mu: mu,
state: WaitSubmit,
Expand All @@ -145,7 +147,7 @@ func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.S
}, nil
}

// Run must be called after SetJob and SetCodePath
// Run must be called after SetJob and WriteRepoBlob
func (ex *RunExecutor) Run(ctx context.Context) (err error) {
runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName))
if err != nil {
Expand Down Expand Up @@ -296,11 +298,6 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) {
ex.state = WaitCode
}

func (ex *RunExecutor) SetCodePath(codePath string) {
ex.codePath = codePath
ex.state = WaitRun
}

// SetJobState forwards to SetJobStateWithTerminationReason, passing ctx and
// state through unchanged and supplying empty strings for the two remaining
// arguments.
func (ex *RunExecutor) SetJobState(ctx context.Context, state types.JobState) {
ex.SetJobStateWithTerminationReason(ctx, state, "", "")
}
Expand Down
26 changes: 13 additions & 13 deletions runner/internal/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestExecutor_HomeDir(t *testing.T) {
func TestExecutor_NonZeroExit(t *testing.T) {
ex := makeTestExecutor(t)
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "exit 100")
makeCodeTar(t, ex.codePath)
makeCodeTar(t, ex)

err := ex.Run(t.Context())
assert.Error(t, err)
Expand Down Expand Up @@ -104,7 +104,7 @@ func TestExecutor_LocalRepo(t *testing.T) {
ex := makeTestExecutor(t)
cmd := fmt.Sprintf("cat %s/foo", *ex.jobSpec.RepoDir)
ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd)
makeCodeTar(t, ex.codePath)
makeCodeTar(t, ex)

err := ex.setupRepo(t.Context())
require.NoError(t, err)
Expand All @@ -117,7 +117,7 @@ func TestExecutor_LocalRepo(t *testing.T) {
func TestExecutor_Recover(t *testing.T) {
ex := makeTestExecutor(t)
ex.jobSpec.Commands = nil // cause a panic
makeCodeTar(t, ex.codePath)
makeCodeTar(t, ex)

err := ex.Run(t.Context())
assert.ErrorContains(t, err, "recovered: ")
Expand All @@ -134,7 +134,7 @@ func TestExecutor_MaxDuration(t *testing.T) {
ex.killDelay = 500 * time.Millisecond
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 1 && sleep 2 && echo 2")
ex.jobSpec.MaxDuration = 1 // seconds
makeCodeTar(t, ex.codePath)
makeCodeTar(t, ex)

err := ex.Run(t.Context())
assert.ErrorContains(t, err, "killed")
Expand All @@ -155,7 +155,7 @@ func TestExecutor_RemoteRepo(t *testing.T) {
RepoConfigEmail: "[email protected]",
}
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "git rev-parse HEAD && git config user.name && git config user.email")
err := os.WriteFile(ex.codePath, []byte{}, 0o600) // empty diff
err := ex.WriteRepoBlob(bytes.NewReader([]byte{})) // empty diff
require.NoError(t, err)

err = ex.setJobWorkingDir(t.Context())
Expand Down Expand Up @@ -210,19 +210,17 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
require.NoError(t, os.Mkdir(homeDir, 0o700))
dstackDir := filepath.Join(baseDir, "dstack")
require.NoError(t, os.Mkdir(dstackDir, 0o755))
ex, _ := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock))
ex, err := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock))
require.NoError(t, err)
ex.SetJob(body)
ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run
ex.setJobWorkingDir(context.Background())
require.NoError(t, ex.setJobWorkingDir(t.Context()))
return ex
}

func makeCodeTar(t *testing.T, path string) {
func makeCodeTar(t *testing.T, ex *RunExecutor) {
t.Helper()
file, err := os.Create(path)
require.NoError(t, err)
defer func() { _ = file.Close() }()
tw := tar.NewWriter(file)
var b bytes.Buffer
tw := tar.NewWriter(&b)

files := []struct{ name, body string }{
{"foo", "bar\n"},
Expand All @@ -235,6 +233,8 @@ func makeCodeTar(t *testing.T, path string) {
require.NoError(t, err)
}
require.NoError(t, tw.Close())

require.NoError(t, ex.WriteRepoBlob(&b))
}

func TestWriteDstackProfile(t *testing.T) {
Expand Down
13 changes: 6 additions & 7 deletions runner/internal/executor/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ import (

var renameRegex = regexp.MustCompile(`^([^/]*)(/|$)`)

func (ex *RunExecutor) AddFileArchive(id string, src io.Reader) error {
if err := os.MkdirAll(ex.archiveDir, 0o755); err != nil {
func (ex *RunExecutor) WriteFileArchive(id string, src io.Reader) error {
if err := os.MkdirAll(ex.fileArchiveDir, 0o755); err != nil {
return fmt.Errorf("create archive directory: %w", err)
}
archivePath := path.Join(ex.archiveDir, id)
archivePath := path.Join(ex.fileArchiveDir, id)
archive, err := os.Create(archivePath)
if err != nil {
return fmt.Errorf("create archive file: %w", err)
Expand All @@ -45,13 +45,13 @@ func (ex *RunExecutor) setupFiles(ctx context.Context) error {
return fmt.Errorf("setup files: working dir must be absolute: %s", ex.jobWorkingDir)
}
for _, fa := range ex.jobSpec.FileArchives {
archivePath := path.Join(ex.archiveDir, fa.Id)
archivePath := path.Join(ex.fileArchiveDir, fa.Id)
if err := extractFileArchive(ctx, archivePath, fa.Path, ex.jobWorkingDir, ex.jobUid, ex.jobGid, ex.jobHomeDir); err != nil {
return fmt.Errorf("extract file archive %s: %w", fa.Id, err)
}
}
if err := os.RemoveAll(ex.archiveDir); err != nil {
log.Warning(ctx, "Failed to remove file archives dir", "path", ex.archiveDir, "err", err)
if err := os.RemoveAll(ex.fileArchiveDir); err != nil {
log.Warning(ctx, "Failed to remove file archives dir", "path", ex.fileArchiveDir, "err", err)
}
return nil
}
Expand Down Expand Up @@ -90,7 +90,6 @@ func extractFileArchive(ctx context.Context, archivePath string, destPath string

if uid != -1 || gid != -1 {
for _, p := range paths {
log.Warning(ctx, "path", "path", p)
if err := os.Chown(path.Join(destBase, p), uid, gid); err != nil {
log.Warning(ctx, "Failed to chown", "path", p, "err", err)
}
Expand Down
29 changes: 26 additions & 3 deletions runner/internal/executor/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"path"
"path/filepath"

"github.com/codeclysm/extract/v4"
Expand All @@ -17,6 +19,23 @@ import (
"github.com/dstackai/dstack/runner/internal/schemas"
)

// WriteRepoBlob streams src into a file inside ex.repoBlobDir, creating the
// directory first if it does not exist. The file is named after
// ex.run.RunSpec.RepoId, so WriteRepoBlob must be called after SetJob.
//
// ex.repoBlobPath is updated only after the blob has been fully written and
// the file closed without error, so it never points at a truncated blob left
// behind by a failed call.
//
// Errors are wrapped with %w, so callers can still detect the underlying
// cause (for example a size-limit error produced by src) with errors.As.
func (ex *RunExecutor) WriteRepoBlob(src io.Reader) error {
	if err := os.MkdirAll(ex.repoBlobDir, 0o755); err != nil {
		return fmt.Errorf("create blob directory: %w", err)
	}
	blobPath := path.Join(ex.repoBlobDir, ex.run.RunSpec.RepoId)
	blob, err := os.Create(blobPath)
	if err != nil {
		return fmt.Errorf("create blob file: %w", err)
	}
	if _, err := io.Copy(blob, src); err != nil {
		// The copy error is the one worth reporting; closing is best-effort here.
		_ = blob.Close()
		return fmt.Errorf("copy blob data: %w", err)
	}
	// For a file opened for writing, Close may be the first place a write
	// failure surfaces, so its error must not be dropped.
	if err := blob.Close(); err != nil {
		return fmt.Errorf("close blob file: %w", err)
	}
	ex.repoBlobPath = blobPath
	return nil
}

// setupRepo must be called from Run
// Must be called after setJobWorkingDir and setJobCredentials
func (ex *RunExecutor) setupRepo(ctx context.Context) error {
Expand Down Expand Up @@ -100,6 +119,10 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error {
return fmt.Errorf("chown repo dir: %w", err)
}

if err := os.RemoveAll(ex.repoBlobDir); err != nil {
log.Warning(ctx, "Failed to remove repo blobs dir", "path", ex.repoBlobDir, "err", err)
}

return err
}

Expand Down Expand Up @@ -143,7 +166,7 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
}

log.Trace(ctx, "Applying diff")
repoDiff, err := os.ReadFile(ex.codePath)
repoDiff, err := os.ReadFile(ex.repoBlobPath)
if err != nil {
return fmt.Errorf("read repo diff: %w", err)
}
Expand All @@ -156,12 +179,12 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
}

func (ex *RunExecutor) prepareArchive(ctx context.Context) error {
file, err := os.Open(ex.codePath)
file, err := os.Open(ex.repoBlobPath)
if err != nil {
return fmt.Errorf("open code archive: %w", err)
}
defer func() { _ = file.Close() }()
log.Trace(ctx, "Extracting code archive", "src", ex.codePath, "dst", ex.repoDir)
log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir)
if err := extract.Tar(ctx, file, ex.repoDir, nil); err != nil {
return fmt.Errorf("extract tar archive: %w", err)
}
Expand Down
44 changes: 26 additions & 18 deletions runner/internal/runner/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"mime"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strconv"

"github.com/dstackai/dstack/runner/internal/api"
Expand All @@ -19,6 +17,9 @@ import (
"github.com/dstackai/dstack/runner/internal/schemas"
)

// maxBodySize is the byte limit handed to http.MaxBytesReader for upload
// request bodies. With math.MaxInt64 the limit is effectively disabled.
// TODO: set some reasonable value; (optional) make configurable
const maxBodySize = math.MaxInt64

func (s *Server) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
return &schemas.HealthcheckResponse{
Service: "dstack-runner",
Expand Down Expand Up @@ -84,13 +85,16 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing boundary"}
}

r.Body = http.MaxBytesReader(w, r.Body, math.MaxInt64)
r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
formReader := multipart.NewReader(r.Body, boundary)
part, err := formReader.NextPart()
if errors.Is(err, io.EOF) {
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"}
}
if err != nil {
if errors.Is(err, io.EOF) {
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"}
}
if isMaxBytesError(err) {
return nil, &api.Error{Status: http.StatusRequestEntityTooLarge}
}
return nil, fmt.Errorf("read multipart form: %w", err)
}
defer func() { _ = part.Close() }()
Expand All @@ -106,8 +110,11 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
if archiveId == "" {
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing file name"}
}
if err := s.executor.AddFileArchive(archiveId, part); err != nil {
return nil, fmt.Errorf("add file archive: %w", err)
if err := s.executor.WriteFileArchive(archiveId, part); err != nil {
if isMaxBytesError(err) {
return nil, &api.Error{Status: http.StatusRequestEntityTooLarge}
}
return nil, fmt.Errorf("write file archive: %w", err)
}
if _, err := formReader.NextPart(); !errors.Is(err, io.EOF) {
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "extra form field(s)"}
Expand All @@ -123,21 +130,17 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (
return nil, &api.Error{Status: http.StatusConflict}
}

r.Body = http.MaxBytesReader(w, r.Body, math.MaxInt64)
codePath := filepath.Join(s.tempDir, "code") // todo random name?
file, err := os.Create(codePath)
if err != nil {
return nil, fmt.Errorf("create code file: %w", err)
}
defer func() { _ = file.Close() }()
if _, err = io.Copy(file, r.Body); err != nil {
if err.Error() == "http: request body too large" {
r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)

if err := s.executor.WriteRepoBlob(r.Body); err != nil {
if isMaxBytesError(err) {
return nil, &api.Error{Status: http.StatusRequestEntityTooLarge}
}
return nil, fmt.Errorf("copy request body: %w", err)
}

s.executor.SetCodePath(codePath)
s.executor.SetRunnerState(executor.WaitRun)

return nil, nil
}

Expand Down Expand Up @@ -181,3 +184,8 @@ func (s *Server) stopPostHandler(w http.ResponseWriter, r *http.Request) (interf
s.stop()
return nil, nil
}

// isMaxBytesError reports whether err, or any error in its wrap chain, is an
// *http.MaxBytesError, i.e. the error an http.MaxBytesReader returns once the
// request body exceeds its limit.
func isMaxBytesError(err error) bool {
	target := new(*http.MaxBytesError)
	return errors.As(err, target)
}
Loading