Skip to content

Commit 2a4c0e1

Browse files
authored
[runner] Decouple Server and Executor (#3447)
* Pass Executor to Server as an argument
* Move repo blob-related code from the API handler to a new Executor method
* Fix http.MaxBytesError handling
1 parent dd907bf commit 2a4c0e1

File tree

8 files changed

+99
-76
lines changed

8 files changed

+99
-76
lines changed

runner/cmd/runner/main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/urfave/cli/v3"
1616

1717
"github.com/dstackai/dstack/runner/consts"
18+
"github.com/dstackai/dstack/runner/internal/executor"
1819
"github.com/dstackai/dstack/runner/internal/log"
1920
"github.com/dstackai/dstack/runner/internal/runner/api"
2021
"github.com/dstackai/dstack/runner/internal/ssh"
@@ -162,7 +163,12 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
162163
}
163164
}()
164165

165-
server, err := api.NewServer(ctx, tempDir, homeDir, dstackDir, sshd, fmt.Sprintf(":%d", httpPort), version)
166+
ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd)
167+
if err != nil {
168+
return fmt.Errorf("create executor: %w", err)
169+
}
170+
171+
server, err := api.NewServer(ctx, fmt.Sprintf(":%d", httpPort), version, ex)
166172
if err != nil {
167173
return fmt.Errorf("create server: %w", err)
168174
}

runner/internal/executor/base.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ type Executor interface {
1313
GetJobWsLogsHistory() []schemas.LogEvent
1414
GetRunnerState() string
1515
Run(ctx context.Context) error
16-
SetCodePath(codePath string)
1716
SetJob(job schemas.SubmitBody)
1817
SetJobState(ctx context.Context, state types.JobState)
1918
SetJobStateWithTerminationReason(
@@ -23,7 +22,8 @@ type Executor interface {
2322
termination_message string,
2423
)
2524
SetRunnerState(state string)
26-
AddFileArchive(id string, src io.Reader) error
25+
WriteFileArchive(id string, src io.Reader) error
26+
WriteRepoBlob(src io.Reader) error
2727
Lock()
2828
RLock()
2929
RUnlock()

runner/internal/executor/executor.go

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ type ConnectionTracker interface {
5252
}
5353

5454
type RunExecutor struct {
55-
tempDir string
56-
homeDir string
57-
dstackDir string
58-
archiveDir string
59-
sshd ssh.SshdManager
55+
tempDir string
56+
homeDir string
57+
dstackDir string
58+
fileArchiveDir string
59+
repoBlobDir string
60+
sshd ssh.SshdManager
6061

6162
currentUid uint32
6263

@@ -67,7 +68,7 @@ type RunExecutor struct {
6768
secrets map[string]string
6869
repoCredentials *schemas.RepoCredentials
6970
repoDir string
70-
codePath string
71+
repoBlobPath string
7172
jobUid int
7273
jobGid int
7374
jobHomeDir string
@@ -123,14 +124,15 @@ func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.S
123124
}
124125

125126
return &RunExecutor{
126-
tempDir: tempDir,
127-
homeDir: homeDir,
128-
dstackDir: dstackDir,
129-
archiveDir: filepath.Join(tempDir, "file_archives"),
130-
sshd: sshd,
131-
currentUid: uid,
132-
jobUid: -1,
133-
jobGid: -1,
127+
tempDir: tempDir,
128+
homeDir: homeDir,
129+
dstackDir: dstackDir,
130+
fileArchiveDir: filepath.Join(tempDir, "file_archives"),
131+
repoBlobDir: filepath.Join(tempDir, "repo_blobs"),
132+
sshd: sshd,
133+
currentUid: uid,
134+
jobUid: -1,
135+
jobGid: -1,
134136

135137
mu: mu,
136138
state: WaitSubmit,
@@ -145,7 +147,7 @@ func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.S
145147
}, nil
146148
}
147149

148-
// Run must be called after SetJob and SetCodePath
150+
// Run must be called after SetJob and WriteRepoBlob
149151
func (ex *RunExecutor) Run(ctx context.Context) (err error) {
150152
runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName))
151153
if err != nil {
@@ -296,11 +298,6 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) {
296298
ex.state = WaitCode
297299
}
298300

299-
func (ex *RunExecutor) SetCodePath(codePath string) {
300-
ex.codePath = codePath
301-
ex.state = WaitRun
302-
}
303-
304301
func (ex *RunExecutor) SetJobState(ctx context.Context, state types.JobState) {
305302
ex.SetJobStateWithTerminationReason(ctx, state, "", "")
306303
}

runner/internal/executor/executor_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func TestExecutor_HomeDir(t *testing.T) {
6969
func TestExecutor_NonZeroExit(t *testing.T) {
7070
ex := makeTestExecutor(t)
7171
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "exit 100")
72-
makeCodeTar(t, ex.codePath)
72+
makeCodeTar(t, ex)
7373

7474
err := ex.Run(t.Context())
7575
assert.Error(t, err)
@@ -104,7 +104,7 @@ func TestExecutor_LocalRepo(t *testing.T) {
104104
ex := makeTestExecutor(t)
105105
cmd := fmt.Sprintf("cat %s/foo", *ex.jobSpec.RepoDir)
106106
ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd)
107-
makeCodeTar(t, ex.codePath)
107+
makeCodeTar(t, ex)
108108

109109
err := ex.setupRepo(t.Context())
110110
require.NoError(t, err)
@@ -117,7 +117,7 @@ func TestExecutor_LocalRepo(t *testing.T) {
117117
func TestExecutor_Recover(t *testing.T) {
118118
ex := makeTestExecutor(t)
119119
ex.jobSpec.Commands = nil // cause a panic
120-
makeCodeTar(t, ex.codePath)
120+
makeCodeTar(t, ex)
121121

122122
err := ex.Run(t.Context())
123123
assert.ErrorContains(t, err, "recovered: ")
@@ -134,7 +134,7 @@ func TestExecutor_MaxDuration(t *testing.T) {
134134
ex.killDelay = 500 * time.Millisecond
135135
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 1 && sleep 2 && echo 2")
136136
ex.jobSpec.MaxDuration = 1 // seconds
137-
makeCodeTar(t, ex.codePath)
137+
makeCodeTar(t, ex)
138138

139139
err := ex.Run(t.Context())
140140
assert.ErrorContains(t, err, "killed")
@@ -155,7 +155,7 @@ func TestExecutor_RemoteRepo(t *testing.T) {
155155
RepoConfigEmail: "[email protected]",
156156
}
157157
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "git rev-parse HEAD && git config user.name && git config user.email")
158-
err := os.WriteFile(ex.codePath, []byte{}, 0o600) // empty diff
158+
err := ex.WriteRepoBlob(bytes.NewReader([]byte{})) // empty diff
159159
require.NoError(t, err)
160160

161161
err = ex.setJobWorkingDir(t.Context())
@@ -210,19 +210,17 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
210210
require.NoError(t, os.Mkdir(homeDir, 0o700))
211211
dstackDir := filepath.Join(baseDir, "dstack")
212212
require.NoError(t, os.Mkdir(dstackDir, 0o755))
213-
ex, _ := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock))
213+
ex, err := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock))
214+
require.NoError(t, err)
214215
ex.SetJob(body)
215-
ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run
216-
ex.setJobWorkingDir(context.Background())
216+
require.NoError(t, ex.setJobWorkingDir(t.Context()))
217217
return ex
218218
}
219219

220-
func makeCodeTar(t *testing.T, path string) {
220+
func makeCodeTar(t *testing.T, ex *RunExecutor) {
221221
t.Helper()
222-
file, err := os.Create(path)
223-
require.NoError(t, err)
224-
defer func() { _ = file.Close() }()
225-
tw := tar.NewWriter(file)
222+
var b bytes.Buffer
223+
tw := tar.NewWriter(&b)
226224

227225
files := []struct{ name, body string }{
228226
{"foo", "bar\n"},
@@ -235,6 +233,8 @@ func makeCodeTar(t *testing.T, path string) {
235233
require.NoError(t, err)
236234
}
237235
require.NoError(t, tw.Close())
236+
237+
require.NoError(t, ex.WriteRepoBlob(&b))
238238
}
239239

240240
func TestWriteDstackProfile(t *testing.T) {

runner/internal/executor/files.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ import (
1818

1919
var renameRegex = regexp.MustCompile(`^([^/]*)(/|$)`)
2020

21-
func (ex *RunExecutor) AddFileArchive(id string, src io.Reader) error {
22-
if err := os.MkdirAll(ex.archiveDir, 0o755); err != nil {
21+
func (ex *RunExecutor) WriteFileArchive(id string, src io.Reader) error {
22+
if err := os.MkdirAll(ex.fileArchiveDir, 0o755); err != nil {
2323
return fmt.Errorf("create archive directory: %w", err)
2424
}
25-
archivePath := path.Join(ex.archiveDir, id)
25+
archivePath := path.Join(ex.fileArchiveDir, id)
2626
archive, err := os.Create(archivePath)
2727
if err != nil {
2828
return fmt.Errorf("create archive file: %w", err)
@@ -45,13 +45,13 @@ func (ex *RunExecutor) setupFiles(ctx context.Context) error {
4545
return fmt.Errorf("setup files: working dir must be absolute: %s", ex.jobWorkingDir)
4646
}
4747
for _, fa := range ex.jobSpec.FileArchives {
48-
archivePath := path.Join(ex.archiveDir, fa.Id)
48+
archivePath := path.Join(ex.fileArchiveDir, fa.Id)
4949
if err := extractFileArchive(ctx, archivePath, fa.Path, ex.jobWorkingDir, ex.jobUid, ex.jobGid, ex.jobHomeDir); err != nil {
5050
return fmt.Errorf("extract file archive %s: %w", fa.Id, err)
5151
}
5252
}
53-
if err := os.RemoveAll(ex.archiveDir); err != nil {
54-
log.Warning(ctx, "Failed to remove file archives dir", "path", ex.archiveDir, "err", err)
53+
if err := os.RemoveAll(ex.fileArchiveDir); err != nil {
54+
log.Warning(ctx, "Failed to remove file archives dir", "path", ex.fileArchiveDir, "err", err)
5555
}
5656
return nil
5757
}
@@ -90,7 +90,6 @@ func extractFileArchive(ctx context.Context, archivePath string, destPath string
9090

9191
if uid != -1 || gid != -1 {
9292
for _, p := range paths {
93-
log.Warning(ctx, "path", "path", p)
9493
if err := os.Chown(path.Join(destBase, p), uid, gid); err != nil {
9594
log.Warning(ctx, "Failed to chown", "path", p, "err", err)
9695
}

runner/internal/executor/repo.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"io"
78
"io/fs"
89
"os"
910
"os/exec"
11+
"path"
1012
"path/filepath"
1113

1214
"github.com/codeclysm/extract/v4"
@@ -17,6 +19,23 @@ import (
1719
"github.com/dstackai/dstack/runner/internal/schemas"
1820
)
1921

22+
// WriteRepoBlob must be called after SetJob
23+
func (ex *RunExecutor) WriteRepoBlob(src io.Reader) error {
24+
if err := os.MkdirAll(ex.repoBlobDir, 0o755); err != nil {
25+
return fmt.Errorf("create blob directory: %w", err)
26+
}
27+
ex.repoBlobPath = path.Join(ex.repoBlobDir, ex.run.RunSpec.RepoId)
28+
blob, err := os.Create(ex.repoBlobPath)
29+
if err != nil {
30+
return fmt.Errorf("create blob file: %w", err)
31+
}
32+
defer func() { _ = blob.Close() }()
33+
if _, err = io.Copy(blob, src); err != nil {
34+
return fmt.Errorf("copy blob data: %w", err)
35+
}
36+
return nil
37+
}
38+
2039
// setupRepo must be called from Run
2140
// Must be called after setJobWorkingDir and setJobCredentials
2241
func (ex *RunExecutor) setupRepo(ctx context.Context) error {
@@ -100,6 +119,10 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error {
100119
return fmt.Errorf("chown repo dir: %w", err)
101120
}
102121

122+
if err := os.RemoveAll(ex.repoBlobDir); err != nil {
123+
log.Warning(ctx, "Failed to remove repo blobs dir", "path", ex.repoBlobDir, "err", err)
124+
}
125+
103126
return err
104127
}
105128

@@ -143,7 +166,7 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
143166
}
144167

145168
log.Trace(ctx, "Applying diff")
146-
repoDiff, err := os.ReadFile(ex.codePath)
169+
repoDiff, err := os.ReadFile(ex.repoBlobPath)
147170
if err != nil {
148171
return fmt.Errorf("read repo diff: %w", err)
149172
}
@@ -156,12 +179,12 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
156179
}
157180

158181
func (ex *RunExecutor) prepareArchive(ctx context.Context) error {
159-
file, err := os.Open(ex.codePath)
182+
file, err := os.Open(ex.repoBlobPath)
160183
if err != nil {
161184
return fmt.Errorf("open code archive: %w", err)
162185
}
163186
defer func() { _ = file.Close() }()
164-
log.Trace(ctx, "Extracting code archive", "src", ex.codePath, "dst", ex.repoDir)
187+
log.Trace(ctx, "Extracting code archive", "src", ex.repoBlobPath, "dst", ex.repoDir)
165188
if err := extract.Tar(ctx, file, ex.repoDir, nil); err != nil {
166189
return fmt.Errorf("extract tar archive: %w", err)
167190
}

runner/internal/runner/api/http.go

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import (
99
"mime"
1010
"mime/multipart"
1111
"net/http"
12-
"os"
13-
"path/filepath"
1412
"strconv"
1513

1614
"github.com/dstackai/dstack/runner/internal/api"
@@ -19,6 +17,9 @@ import (
1917
"github.com/dstackai/dstack/runner/internal/schemas"
2018
)
2119

20+
// TODO: set some reasonable value; (optional) make configurable
21+
const maxBodySize = math.MaxInt64
22+
2223
func (s *Server) healthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
2324
return &schemas.HealthcheckResponse{
2425
Service: "dstack-runner",
@@ -84,13 +85,16 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
8485
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing boundary"}
8586
}
8687

87-
r.Body = http.MaxBytesReader(w, r.Body, math.MaxInt64)
88+
r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
8889
formReader := multipart.NewReader(r.Body, boundary)
8990
part, err := formReader.NextPart()
90-
if errors.Is(err, io.EOF) {
91-
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"}
92-
}
9391
if err != nil {
92+
if errors.Is(err, io.EOF) {
93+
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "empty form"}
94+
}
95+
if isMaxBytesError(err) {
96+
return nil, &api.Error{Status: http.StatusRequestEntityTooLarge}
97+
}
9498
return nil, fmt.Errorf("read multipart form: %w", err)
9599
}
96100
defer func() { _ = part.Close() }()
@@ -106,8 +110,11 @@ func (s *Server) uploadArchivePostHandler(w http.ResponseWriter, r *http.Request
106110
if archiveId == "" {
107111
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing file name"}
108112
}
109-
if err := s.executor.AddFileArchive(archiveId, part); err != nil {
110-
return nil, fmt.Errorf("add file archive: %w", err)
113+
if err := s.executor.WriteFileArchive(archiveId, part); err != nil {
114+
if isMaxBytesError(err) {
115+
return nil, &api.Error{Status: http.StatusRequestEntityTooLarge}
116+
}
117+
return nil, fmt.Errorf("write file archive: %w", err)
111118
}
112119
if _, err := formReader.NextPart(); !errors.Is(err, io.EOF) {
113120
return nil, &api.Error{Status: http.StatusBadRequest, Msg: "extra form field(s)"}
@@ -123,21 +130,17 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) (
123130
return nil, &api.Error{Status: http.StatusConflict}
124131
}
125132

126-
r.Body = http.MaxBytesReader(w, r.Body, math.MaxInt64)
127-
codePath := filepath.Join(s.tempDir, "code") // todo random name?
128-
file, err := os.Create(codePath)
129-
if err != nil {
130-
return nil, fmt.Errorf("create code file: %w", err)
131-
}
132-
defer func() { _ = file.Close() }()
133-
if _, err = io.Copy(file, r.Body); err != nil {
134-
if err.Error() == "http: request body too large" {
133+
r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
134+
135+
if err := s.executor.WriteRepoBlob(r.Body); err != nil {
136+
if isMaxBytesError(err) {
135137
return nil, &api.Error{Status: http.StatusRequestEntityTooLarge}
136138
}
137139
return nil, fmt.Errorf("copy request body: %w", err)
138140
}
139141

140-
s.executor.SetCodePath(codePath)
142+
s.executor.SetRunnerState(executor.WaitRun)
143+
141144
return nil, nil
142145
}
143146

@@ -181,3 +184,8 @@ func (s *Server) stopPostHandler(w http.ResponseWriter, r *http.Request) (interf
181184
s.stop()
182185
return nil, nil
183186
}
187+
188+
func isMaxBytesError(err error) bool {
189+
var maxBytesError *http.MaxBytesError
190+
return errors.As(err, &maxBytesError)
191+
}

0 commit comments

Comments
 (0)