Skip to content

Commit 78d26ee

Browse files
authored
[runner] Fix MPI hostfile (#3441)
* Don't set slots on CPU nodes * Move the file to /dstack/mpi and make it world-readable Fixes: #3434 Fixes: #3436
1 parent b78851b commit 78d26ee

File tree

6 files changed

+77
-49
lines changed

6 files changed

+77
-49
lines changed

runner/cmd/runner/main.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"os"
99
"os/signal"
10+
"path"
1011
"path/filepath"
1112
"syscall"
1213

@@ -121,27 +122,32 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
121122
log.DefaultEntry.Logger.SetOutput(io.MultiWriter(os.Stdout, defaultLogFile))
122123
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel))
123124

125+
// NB: The Mkdir/Chown/Chmod code below relies on the fact that RunnerDstackDir path is _not_ nested (/dstack).
126+
// Adjust it if the path is changed to, e.g., /opt/dstack
127+
const dstackDir = consts.RunnerDstackDir
128+
dstackSshDir := path.Join(dstackDir, "ssh")
129+
124130
// To ensure that all components of the authorized_keys path are owned by root and no directories
125131
// are group or world writable, as required by sshd with "StrictModes yes" (the default value),
126132
// we fix `/dstack` ownership and permissions and remove `/dstack/ssh` (it will be (re)created
127133
// in Sshd.Prepare())
128134
// See: https://github.com/openssh/openssh-portable/blob/d01efaa1c9ed84fd9011201dbc3c7cb0a82bcee3/misc.c#L2257-L2272
129-
if err := os.Mkdir("/dstack", 0o755); errors.Is(err, os.ErrExist) {
130-
if err := os.Chown("/dstack", 0, 0); err != nil {
135+
if err := os.Mkdir(dstackDir, 0o755); errors.Is(err, os.ErrExist) {
136+
if err := os.Chown(dstackDir, 0, 0); err != nil {
131137
return fmt.Errorf("chown dstack dir: %w", err)
132138
}
133-
if err := os.Chmod("/dstack", 0o755); err != nil {
139+
if err := os.Chmod(dstackDir, 0o755); err != nil {
134140
return fmt.Errorf("chmod dstack dir: %w", err)
135141
}
136142
} else if err != nil {
137143
return fmt.Errorf("create dstack dir: %w", err)
138144
}
139-
if err := os.RemoveAll("/dstack/ssh"); err != nil {
145+
if err := os.RemoveAll(dstackSshDir); err != nil {
140146
return fmt.Errorf("remove dstack ssh dir: %w", err)
141147
}
142148

143149
sshd := ssh.NewSshd("/usr/sbin/sshd")
144-
if err := sshd.Prepare(ctx, "/dstack/ssh", sshPort, "INFO"); err != nil {
150+
if err := sshd.Prepare(ctx, dstackSshDir, sshPort, "INFO"); err != nil {
145151
return fmt.Errorf("prepare sshd: %w", err)
146152
}
147153
if err := sshd.AddAuthorizedKeys(ctx, sshAuthorizedKeys...); err != nil {
@@ -156,7 +162,7 @@ func start(ctx context.Context, tempDir string, homeDir string, httpPort int, ss
156162
}
157163
}()
158164

159-
server, err := api.NewServer(ctx, tempDir, homeDir, fmt.Sprintf(":%d", httpPort), sshd, version)
165+
server, err := api.NewServer(ctx, tempDir, homeDir, dstackDir, sshd, fmt.Sprintf(":%d", httpPort), version)
160166
if err != nil {
161167
return fmt.Errorf("create server: %w", err)
162168
}

runner/cmd/shim/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func mainInner() int {
5656
Usage: "Set shim's home directory",
5757
Destination: &args.Shim.HomeDir,
5858
TakesFile: true,
59-
DefaultText: path.Join("~", consts.DstackDirPath),
59+
DefaultText: path.Join("~", consts.DstackUserDir),
6060
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
6161
},
6262
&cli.StringFlag{
@@ -187,7 +187,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
187187
if err != nil {
188188
return err
189189
}
190-
shimHomeDir = filepath.Join(home, consts.DstackDirPath)
190+
shimHomeDir = filepath.Join(home, consts.DstackUserDir)
191191
args.Shim.HomeDir = shimHomeDir
192192
}
193193

runner/consts/consts.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package consts
22

3-
const DstackDirPath string = ".dstack"
3+
// A directory inside user's home used for dstack-related files
4+
const DstackUserDir string = ".dstack"
45

56
// Runner's log filenames
67
const (
@@ -29,6 +30,13 @@ const (
2930
// The current user's homedir (as of 2024-12-28, it's always root) should be used
3031
// instead of the hardcoded value
3132
RunnerHomeDir = "/root"
33+
// A directory for:
34+
// 1. Files used by the runner and related components (e.g., sshd stores its config and log inside /dstack/ssh)
35+
// 2. Files shared between users (e.g., sshd authorized_keys, MPI hostfile)
36+
// The inner structure should be considered private and subject to change, the users should not make assumptions
37+
// about its structure.
38+
// The only way to access its content/paths should be via public environment variables such as DSTACK_MPI_HOSTFILE.
39+
RunnerDstackDir = "/dstack"
3240
)
3341

3442
const (

runner/internal/executor/executor.go

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ type ConnectionTracker interface {
5454
type RunExecutor struct {
5555
tempDir string
5656
homeDir string
57+
dstackDir string
5758
archiveDir string
5859
sshd ssh.SshdManager
5960

@@ -91,7 +92,7 @@ func (s *stubConnectionTracker) GetNoConnectionsSecs() int64 { return 0 }
9192
func (s *stubConnectionTracker) Track(ticker <-chan time.Time) {}
9293
func (s *stubConnectionTracker) Stop() {}
9394

94-
func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunExecutor, error) {
95+
func NewRunExecutor(tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager) (*RunExecutor, error) {
9596
mu := &sync.RWMutex{}
9697
timestamp := NewMonotonicTimestamp()
9798
user, err := osuser.Current()
@@ -124,6 +125,7 @@ func NewRunExecutor(tempDir string, homeDir string, sshd ssh.SshdManager) (*RunE
124125
return &RunExecutor{
125126
tempDir: tempDir,
126127
homeDir: homeDir,
128+
dstackDir: dstackDir,
127129
archiveDir: filepath.Join(tempDir, "file_archives"),
128130
sshd: sshd,
129131
currentUid: uid,
@@ -384,12 +386,12 @@ func (ex *RunExecutor) getRepoData() schemas.RepoData {
384386
}
385387

386388
func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error {
387-
node_rank := ex.jobSpec.JobNum
388-
nodes_num := ex.jobSpec.JobsPerReplica
389-
gpus_per_node_num := ex.clusterInfo.GPUSPerJob
390-
gpus_num := nodes_num * gpus_per_node_num
389+
nodeRank := ex.jobSpec.JobNum
390+
nodesNum := ex.jobSpec.JobsPerReplica
391+
gpusPerNodeNum := ex.clusterInfo.GPUSPerJob
392+
gpusNum := nodesNum * gpusPerNodeNum
391393

392-
mpiHostfilePath := filepath.Join(ex.homeDir, ".dstack/mpi/hostfile")
394+
mpiHostfilePath := filepath.Join(ex.dstackDir, "mpi/hostfile")
393395

394396
jobEnvs := map[string]string{
395397
"DSTACK_RUN_ID": ex.run.Id,
@@ -400,10 +402,10 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
400402
"DSTACK_WORKING_DIR": ex.jobWorkingDir,
401403
"DSTACK_NODES_IPS": strings.Join(ex.clusterInfo.JobIPs, "\n"),
402404
"DSTACK_MASTER_NODE_IP": ex.clusterInfo.MasterJobIP,
403-
"DSTACK_NODE_RANK": strconv.Itoa(node_rank),
404-
"DSTACK_NODES_NUM": strconv.Itoa(nodes_num),
405-
"DSTACK_GPUS_PER_NODE": strconv.Itoa(gpus_per_node_num),
406-
"DSTACK_GPUS_NUM": strconv.Itoa(gpus_num),
405+
"DSTACK_NODE_RANK": strconv.Itoa(nodeRank),
406+
"DSTACK_NODES_NUM": strconv.Itoa(nodesNum),
407+
"DSTACK_GPUS_PER_NODE": strconv.Itoa(gpusPerNodeNum),
408+
"DSTACK_GPUS_NUM": strconv.Itoa(gpusNum),
407409
"DSTACK_MPI_HOSTFILE": mpiHostfilePath,
408410
}
409411

@@ -460,7 +462,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
460462
envMap.Update(ex.jobSpec.Env, false)
461463

462464
const profilePath = "/etc/profile"
463-
const dstackProfilePath = "/dstack/profile"
465+
dstackProfilePath := path.Join(ex.dstackDir, "profile")
464466
if err := writeDstackProfile(envMap, dstackProfilePath); err != nil {
465467
log.Warning(ctx, "failed to write dstack_profile", "path", dstackProfilePath, "err", err)
466468
} else if err := includeDstackProfile(profilePath, dstackProfilePath); err != nil {
@@ -508,7 +510,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
508510
}
509511
}
510512

511-
err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpus_per_node_num, mpiHostfilePath)
513+
err = writeMpiHostfile(ctx, ex.clusterInfo.JobIPs, gpusPerNodeNum, mpiHostfilePath)
512514
if err != nil {
513515
return fmt.Errorf("write MPI hostfile: %w", err)
514516
}
@@ -839,7 +841,7 @@ func prepareSSHDir(uid int, gid int, homeDir string) (string, error) {
839841
return sshDir, nil
840842
}
841843

842-
func writeMpiHostfile(ctx context.Context, ips []string, gpus_per_node int, path string) error {
844+
func writeMpiHostfile(ctx context.Context, ips []string, gpusPerNode int, path string) error {
843845
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
844846
return fmt.Errorf("create MPI hostfile directory: %w", err)
845847
}
@@ -855,9 +857,16 @@ func writeMpiHostfile(ctx context.Context, ips []string, gpus_per_node int, path
855857
}
856858
}
857859
if len(nonEmptyIps) == len(ips) {
860+
var template string
861+
if gpusPerNode == 0 {
862+
// CPU node: the number of slots defaults to the number of processor cores on that host
863+
// See: https://docs.open-mpi.org/en/main/launching-apps/scheduling.html#calculating-the-number-of-slots
864+
template = "%s\n"
865+
} else {
866+
template = fmt.Sprintf("%%s slots=%d\n", gpusPerNode)
867+
}
858868
for _, ip := range nonEmptyIps {
859-
line := fmt.Sprintf("%s slots=%d\n", ip, gpus_per_node)
860-
if _, err = file.WriteString(line); err != nil {
869+
if _, err = fmt.Fprintf(file, template, ip); err != nil {
861870
return fmt.Errorf("write MPI hostfile line: %w", err)
862871
}
863872
}

runner/internal/executor/executor_test.go

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ func TestExecutor_WorkingDir_Set(t *testing.T) {
2828

2929
ex.jobSpec.WorkingDir = &workingDir
3030
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "pwd")
31-
err = ex.setJobWorkingDir(context.TODO())
31+
err = ex.setJobWorkingDir(t.Context())
3232
require.NoError(t, err)
3333
require.Equal(t, workingDir, ex.jobWorkingDir)
3434
err = os.MkdirAll(workingDir, 0o755)
3535
require.NoError(t, err)
3636

37-
err = ex.execJob(context.TODO(), io.Writer(&b))
37+
err = ex.execJob(t.Context(), io.Writer(&b))
3838
assert.NoError(t, err)
3939
// Normalize line endings for cross-platform compatibility.
4040
assert.Equal(t, workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
@@ -47,11 +47,11 @@ func TestExecutor_WorkingDir_NotSet(t *testing.T) {
4747
require.NoError(t, err)
4848
ex.jobSpec.WorkingDir = nil
4949
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "pwd")
50-
err = ex.setJobWorkingDir(context.TODO())
50+
err = ex.setJobWorkingDir(t.Context())
5151
require.NoError(t, err)
5252
require.Equal(t, cwd, ex.jobWorkingDir)
5353

54-
err = ex.execJob(context.TODO(), io.Writer(&b))
54+
err = ex.execJob(t.Context(), io.Writer(&b))
5555
assert.NoError(t, err)
5656
assert.Equal(t, cwd+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
5757
}
@@ -61,7 +61,7 @@ func TestExecutor_HomeDir(t *testing.T) {
6161
ex := makeTestExecutor(t)
6262
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo ~")
6363

64-
err := ex.execJob(context.TODO(), io.Writer(&b))
64+
err := ex.execJob(t.Context(), io.Writer(&b))
6565
assert.NoError(t, err)
6666
assert.Equal(t, ex.homeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
6767
}
@@ -71,7 +71,7 @@ func TestExecutor_NonZeroExit(t *testing.T) {
7171
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "exit 100")
7272
makeCodeTar(t, ex.codePath)
7373

74-
err := ex.Run(context.TODO())
74+
err := ex.Run(t.Context())
7575
assert.Error(t, err)
7676
assert.NotEmpty(t, ex.jobStateHistory)
7777
exitStatus := ex.jobStateHistory[len(ex.jobStateHistory)-1].ExitStatus
@@ -90,11 +90,11 @@ func TestExecutor_SSHCredentials(t *testing.T) {
9090
PrivateKey: &key,
9191
}
9292

93-
clean, err := ex.setupCredentials(context.TODO())
93+
clean, err := ex.setupCredentials(t.Context())
9494
defer clean()
9595
require.NoError(t, err)
9696

97-
err = ex.execJob(context.TODO(), io.Writer(&b))
97+
err = ex.execJob(t.Context(), io.Writer(&b))
9898
assert.NoError(t, err)
9999
assert.Equal(t, key, b.String())
100100
}
@@ -106,10 +106,10 @@ func TestExecutor_LocalRepo(t *testing.T) {
106106
ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd)
107107
makeCodeTar(t, ex.codePath)
108108

109-
err := ex.setupRepo(context.TODO())
109+
err := ex.setupRepo(t.Context())
110110
require.NoError(t, err)
111111

112-
err = ex.execJob(context.TODO(), io.Writer(&b))
112+
err = ex.execJob(t.Context(), io.Writer(&b))
113113
assert.NoError(t, err)
114114
assert.Equal(t, "bar\n", strings.ReplaceAll(b.String(), "\r\n", "\n"))
115115
}
@@ -119,7 +119,7 @@ func TestExecutor_Recover(t *testing.T) {
119119
ex.jobSpec.Commands = nil // cause a panic
120120
makeCodeTar(t, ex.codePath)
121121

122-
err := ex.Run(context.TODO())
122+
err := ex.Run(t.Context())
123123
assert.ErrorContains(t, err, "recovered: ")
124124
}
125125

@@ -136,7 +136,7 @@ func TestExecutor_MaxDuration(t *testing.T) {
136136
ex.jobSpec.MaxDuration = 1 // seconds
137137
makeCodeTar(t, ex.codePath)
138138

139-
err := ex.Run(context.TODO())
139+
err := ex.Run(t.Context())
140140
assert.ErrorContains(t, err, "killed")
141141
}
142142

@@ -158,12 +158,12 @@ func TestExecutor_RemoteRepo(t *testing.T) {
158158
err := os.WriteFile(ex.codePath, []byte{}, 0o600) // empty diff
159159
require.NoError(t, err)
160160

161-
err = ex.setJobWorkingDir(context.TODO())
161+
err = ex.setJobWorkingDir(t.Context())
162162
require.NoError(t, err)
163-
err = ex.setupRepo(context.TODO())
163+
err = ex.setupRepo(t.Context())
164164
require.NoError(t, err)
165165

166-
err = ex.execJob(context.TODO(), io.Writer(&b))
166+
err = ex.execJob(t.Context(), io.Writer(&b))
167167
assert.NoError(t, err)
168168
expected := fmt.Sprintf("%s\n%s\n%s\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail)
169169
assert.Equal(t, expected, strings.ReplaceAll(b.String(), "\r\n", "\n"))
@@ -204,11 +204,13 @@ func makeTestExecutor(t *testing.T) *RunExecutor {
204204
},
205205
}
206206

207-
temp := filepath.Join(baseDir, "temp")
208-
_ = os.Mkdir(temp, 0o700)
209-
home := filepath.Join(baseDir, "home")
210-
_ = os.Mkdir(home, 0o700)
211-
ex, _ := NewRunExecutor(temp, home, new(sshdMock))
207+
tempDir := filepath.Join(baseDir, "temp")
208+
require.NoError(t, os.Mkdir(tempDir, 0o700))
209+
homeDir := filepath.Join(baseDir, "home")
210+
require.NoError(t, os.Mkdir(homeDir, 0o700))
211+
dstackDir := filepath.Join(baseDir, "dstack")
212+
require.NoError(t, os.Mkdir(dstackDir, 0o755))
213+
ex, _ := NewRunExecutor(tempDir, homeDir, dstackDir, new(sshdMock))
212214
ex.SetJob(body)
213215
ex.SetCodePath(filepath.Join(baseDir, "code")) // note: create file before run
214216
ex.setJobWorkingDir(context.Background())
@@ -261,7 +263,7 @@ func TestExecutor_Logs(t *testing.T) {
261263
// \033[31m = red text, \033[1;32m = bold green text, \033[0m = reset
262264
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "printf '\\033[31mRed Hello World\\033[0m\\n' && printf '\\033[1;32mBold Green Line 2\\033[0m\\n' && printf 'Line 3\\n'")
263265

264-
err := ex.execJob(context.TODO(), io.Writer(&b))
266+
err := ex.execJob(t.Context(), io.Writer(&b))
265267
assert.NoError(t, err)
266268

267269
logHistory := ex.GetHistory(0).JobLogs
@@ -285,7 +287,7 @@ func TestExecutor_LogsWithErrors(t *testing.T) {
285287
ex := makeTestExecutor(t)
286288
ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 'Success message' && echo 'Error message' >&2 && exit 1")
287289

288-
err := ex.execJob(context.TODO(), io.Writer(&b))
290+
err := ex.execJob(t.Context(), io.Writer(&b))
289291
assert.Error(t, err)
290292

291293
logHistory := ex.GetHistory(0).JobLogs
@@ -309,7 +311,7 @@ func TestExecutor_LogsAnsiCodeHandling(t *testing.T) {
309311

310312
ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd)
311313

312-
err := ex.execJob(context.TODO(), io.Writer(&b))
314+
err := ex.execJob(t.Context(), io.Writer(&b))
313315
assert.NoError(t, err)
314316

315317
// 1. Check WebSocket logs, which should preserve ANSI codes.

runner/internal/runner/api/server.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,12 @@ type Server struct {
3434
version string
3535
}
3636

37-
func NewServer(ctx context.Context, tempDir string, homeDir string, address string, sshd ssh.SshdManager, version string) (*Server, error) {
37+
func NewServer(
38+
ctx context.Context, tempDir string, homeDir string, dstackDir string, sshd ssh.SshdManager,
39+
address string, version string,
40+
) (*Server, error) {
3841
r := api.NewRouter()
39-
ex, err := executor.NewRunExecutor(tempDir, homeDir, sshd)
42+
ex, err := executor.NewRunExecutor(tempDir, homeDir, dstackDir, sshd)
4043
if err != nil {
4144
return nil, err
4245
}

0 commit comments

Comments
 (0)