Commit 4341282

Merge pull request #2394 from dstackai/issue_2368_multi_node_ssh_connectivity
Configure inter-node SSH on multi-node tasks
2 parents 0a35951 + 0750b82 commit 4341282

15 files changed: +289 −7 lines changed

docs/docs/concepts/tasks.md

Lines changed: 6 additions & 1 deletion

@@ -136,7 +136,7 @@ resources:
 </div>

 Nodes can communicate using their private IP addresses.
-Use `DSTACK_MASTER_NODE_IP`, `$DSTACK_NODE_RANK`, and other
+Use `DSTACK_MASTER_NODE_IP`, `DSTACK_NODES_IPS`, `DSTACK_NODE_RANK`, and other
 [System environment variables](#system-environment-variables)
 to discover IP addresses and other details.

@@ -159,6 +159,11 @@ to discover IP addresses and other details.
   # ... The rest of the commands
   ```

+??? info "SSH"
+    You can log in from any node to any other node via SSH on port 10022 using the `~/.ssh/dstack_job` private key.
+    For convenience, `~/.ssh/config` is preconfigured with these options, so a simple `ssh <node_ip>` is enough.
+    For the list of node IPs, check the `DSTACK_NODES_IPS` environment variable.
+
 !!! info "Fleets"
     Distributed tasks can only run on fleets with
     [cluster placement](fleets.md#cloud-placement).
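As a quick illustration of the documented behavior (not part of this commit), the sketch below checks SSH connectivity from any node to all of its peers. It assumes `DSTACK_NODES_IPS` expands to a whitespace- or newline-separated list of IP addresses, as its use as an mpirun hostfile in the NCCL example further down suggests.

```shell
# Sketch only: verify inter-node SSH from inside a multi-node task.
# Port 10022 and the ~/.ssh/dstack_job identity come from the preconfigured
# ~/.ssh/config, so a plain `ssh <node_ip>` is enough.
for ip in ${DSTACK_NODES_IPS}; do
    ssh "${ip}" hostname
done
```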

docs/examples.md

Lines changed: 10 additions & 0 deletions

@@ -155,4 +155,14 @@ hide:
         Use Docker and Docker Compose inside runs
         </p>
     </a>
+    <a href="/examples/misc/nccl-tests"
+       class="feature-cell sky">
+        <h3>
+            NCCL Tests
+        </h3>
+
+        <p>
+            Run multi-node NCCL Tests with MPI
+        </p>
+    </a>
 </div>

docs/examples/misc/nccl-tests/index.md

Whitespace-only changes.

examples/misc/nccl-tests/.dstack.yml

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
type: task
name: nccl-tests

image: un1def/aws-efa-test
nodes: 2

env:
  - NCCL_DEBUG=INFO

commands:
  - |
    # We use FIFO for inter-node communication
    FIFO=/tmp/dstack_job
    if [ ${DSTACK_NODE_RANK} -eq 0 ]; then
      cd /root/nccl-tests/build
      echo "${DSTACK_NODES_IPS}" > hostfile
      MPIRUN='mpirun --allow-run-as-root --hostfile hostfile'
      # Wait for other nodes
      while true; do
        if ${MPIRUN} -n ${DSTACK_NODES_NUM} -N 1 true >/dev/null 2>&1; then
          break
        fi
        echo 'Waiting for nodes...'
        sleep 5
      done
      # Run NCCL Tests
      ${MPIRUN} \
        -n $((DSTACK_NODES_NUM * DSTACK_GPUS_PER_NODE)) -N ${DSTACK_GPUS_PER_NODE} \
        --mca btl_tcp_if_exclude lo,docker0 \
        --bind-to none \
        ./all_reduce_perf -b 8 -e 8G -f 2 -g 1
      # Notify nodes the job is done
      ${MPIRUN} -n ${DSTACK_NODES_NUM} -N 1 sh -c "echo done > ${FIFO}"
    else
      mkfifo ${FIFO}
      # Wait for a message from the first node
      cat ${FIFO}
    fi

resources:
  gpu: nvidia:4:16GB
  shm_size: 16GB

examples/misc/nccl-tests/README.md

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
# NCCL Tests

This example shows how to run distributed [NCCL Tests :material-arrow-top-right-thin:{ .external }](https://github.com/NVIDIA/nccl-tests){:target="_blank"} with MPI using `dstack`.

??? info "AWS EFA"
    The image is optimized for AWS [EFA :material-arrow-top-right-thin:{ .external }](https://aws.amazon.com/hpc/efa/){:target="_blank"} but works with regular TCP/IP network adapters as well.

## Configuration

This configuration runs the AllReduce test on 2 nodes with 4 GPUs each (8 processes in total). You can adjust both `nodes` and `resources.gpu` without modifying the script.

<div editor-title="examples/misc/nccl-tests/.dstack.yml">

```yaml
type: task
name: nccl-tests

image: un1def/aws-efa-test
nodes: 2

env:
  - NCCL_DEBUG=INFO

commands:
  - |
    # We use FIFO for inter-node communication
    FIFO=/tmp/dstack_job
    if [ ${DSTACK_NODE_RANK} -eq 0 ]; then
      cd /root/nccl-tests/build
      echo "${DSTACK_NODES_IPS}" > hostfile
      MPIRUN='mpirun --allow-run-as-root --hostfile hostfile'
      # Wait for other nodes
      while true; do
        if ${MPIRUN} -n ${DSTACK_NODES_NUM} -N 1 true >/dev/null 2>&1; then
          break
        fi
        echo 'Waiting for nodes...'
        sleep 5
      done
      # Run NCCL Tests
      ${MPIRUN} \
        -n $((DSTACK_NODES_NUM * DSTACK_GPUS_PER_NODE)) -N ${DSTACK_GPUS_PER_NODE} \
        --mca btl_tcp_if_exclude lo,docker0 \
        --bind-to none \
        ./all_reduce_perf -b 8 -e 8G -f 2 -g 1
      # Notify nodes the job is done
      ${MPIRUN} -n ${DSTACK_NODES_NUM} -N 1 sh -c "echo done > ${FIFO}"
    else
      mkfifo ${FIFO}
      # Wait for a message from the first node
      cat ${FIFO}
    fi

resources:
  gpu: nvidia:4:16GB
  shm_size: 16GB

```

</div>

### Running a configuration

To run a configuration, use the [`dstack apply`](https://dstack.ai/docs/reference/cli/dstack/apply/) command.

<div class="termy">

```shell
$ dstack apply -f examples/misc/nccl-tests/.dstack.yml

 #  BACKEND  REGION     INSTANCE       RESOURCES                                    SPOT  PRICE
 1  aws      us-east-1  g4dn.12xlarge  48xCPU, 192GB, 4xT4 (16GB), 100.0GB (disk)   no    $3.912
 2  aws      us-west-2  g4dn.12xlarge  48xCPU, 192GB, 4xT4 (16GB), 100.0GB (disk)   no    $3.912
 3  aws      us-east-2  g4dn.12xlarge  48xCPU, 192GB, 4xT4 (16GB), 100.0GB (disk)   no    $3.912

Submit the run nccl-tests? [y/n]: y
```

</div>

mkdocs.yml

Lines changed: 1 addition & 0 deletions

@@ -273,6 +273,7 @@ nav:
     - Llama 3.2: examples/llms/llama32/index.md
   - Misc:
     - Docker Compose: examples/misc/docker-compose/index.md
+    - NCCL Tests: examples/misc/nccl-tests/index.md
 # - Community: community.md
 - Partners: partners.md
 - Blog:

runner/internal/executor/executor.go

Lines changed: 70 additions & 3 deletions

@@ -1,6 +1,7 @@
 package executor

 import (
+    "bytes"
     "context"
     "errors"
     "fmt"
@@ -30,6 +31,7 @@ type RunExecutor struct {
     tempDir    string
     homeDir    string
     workingDir string
+    sshPort    int
     uid        uint32

     run schemas.RunSpec
@@ -74,6 +76,7 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort i
         tempDir:    tempDir,
         homeDir:    homeDir,
         workingDir: workingDir,
+        sshPort:    sshPort,
         uid:        uid,

         mu: mu,
@@ -322,15 +325,18 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
             log.Warning(ctx, "failed to write SSH environment", "path", ex.homeDir, "err", err)
         }
     }
+    userSSHDir := ""
+    uid := -1
+    gid := -1
     if user != nil && *user.Uid != 0 {
         // non-root user
-        uid := int(*user.Uid)
-        gid := int(*user.Gid)
+        uid = int(*user.Uid)
+        gid = int(*user.Gid)
         homeDir, isHomeDirAccessible := prepareHomeDir(ctx, uid, gid, user.HomeDir)
         envMap["HOME"] = homeDir
         if isHomeDirAccessible {
             log.Trace(ctx, "provisioning homeDir", "path", homeDir)
-            userSSHDir, err := prepareSSHDir(uid, gid, homeDir)
+            userSSHDir, err = prepareSSHDir(uid, gid, homeDir)
             if err != nil {
                 log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err)
             } else {
@@ -354,6 +360,17 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
     } else {
         // root user
         envMap["HOME"] = ex.homeDir
+        userSSHDir = filepath.Join(ex.homeDir, ".ssh")
+    }
+
+    if ex.jobSpec.SSHKey != nil && userSSHDir != "" {
+        err := configureSSH(
+            ex.jobSpec.SSHKey.Private, ex.jobSpec.SSHKey.Public, ex.clusterInfo.JobIPs, ex.sshPort,
+            uid, gid, userSSHDir,
+        )
+        if err != nil {
+            log.Warning(ctx, "failed to configure SSH", "err", err)
+        }
     }

     cmd.Env = envMap.Render()
@@ -712,6 +729,56 @@ func writeSSHEnvironment(env map[string]string, uid int, gid int, envPath string
     return nil
 }

+func configureSSH(private string, public string, ips []string, port int, uid int, gid int, sshDir string) error {
+    privatePath := filepath.Join(sshDir, "dstack_job")
+    privateFile, err := os.OpenFile(privatePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0o600)
+    if err != nil {
+        return err
+    }
+    defer privateFile.Close()
+    if err := os.Chown(privatePath, uid, gid); err != nil {
+        return err
+    }
+    if _, err := privateFile.WriteString(private); err != nil {
+        return err
+    }
+
+    akPath := filepath.Join(sshDir, "authorized_keys")
+    akFile, err := os.OpenFile(akPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
+    if err != nil {
+        return err
+    }
+    defer akFile.Close()
+    if err := os.Chown(akPath, uid, gid); err != nil {
+        return err
+    }
+    if _, err := akFile.WriteString(public); err != nil {
+        return err
+    }
+
+    configPath := filepath.Join(sshDir, "config")
+    configFile, err := os.OpenFile(configPath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0o600)
+    if err != nil {
+        return err
+    }
+    defer configFile.Close()
+    if err := os.Chown(configPath, uid, gid); err != nil {
+        return err
+    }
+    var configBuffer bytes.Buffer
+    for _, ip := range ips {
+        configBuffer.WriteString(fmt.Sprintf("\nHost %s\n", ip))
+        configBuffer.WriteString(fmt.Sprintf(" Port %d\n", port))
+        configBuffer.WriteString(" StrictHostKeyChecking no\n")
+        configBuffer.WriteString(" UserKnownHostsFile /dev/null\n")
+        configBuffer.WriteString(fmt.Sprintf(" IdentityFile %s\n", privatePath))
+    }
+    if _, err := configFile.Write(configBuffer.Bytes()); err != nil {
+        return err
+    }
+    return nil
+}
+
 // A makeshift solution to deliver authorized_keys to a non-root user
 // without modifying the existing API/bootstrap process
 // TODO: implement key delivery properly, i.e. sumbit keys to and write by the runner,
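For reference, here is a rough shell sketch (not part of the commit) of the per-host block that `configureSSH` appends to the user's `~/.ssh/config`. The IPs below are hypothetical; the real values come from `ex.clusterInfo.JobIPs`, `ex.sshPort`, and the resolved user SSH directory.

```shell
# Print the config block configureSSH would generate, for two example IPs.
# 10022 matches the job SSH port mentioned in the docs change above; the
# identity file is the dstack_job key written next to the config.
render_host_block() {
    printf '\nHost %s\n' "$1"
    printf ' Port %d\n' 10022
    printf ' StrictHostKeyChecking no\n'
    printf ' UserKnownHostsFile /dev/null\n'
    printf ' IdentityFile %s\n' "$HOME/.ssh/dstack_job"
}
render_host_block 10.0.0.10
render_host_block 10.0.0.11
```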

runner/internal/schemas/schemas.go

Lines changed: 6 additions & 0 deletions

@@ -54,6 +54,7 @@ type JobSpec struct {
     Env          map[string]string `json:"env"`
     SingleBranch bool              `json:"single_branch"`
     MaxDuration  int               `json:"max_duration"`
+    SSHKey       *SSHKey           `json:"ssh_key"`
     WorkingDir   *string           `json:"working_dir"`
 }

@@ -63,6 +64,11 @@ type ClusterInfo struct {
     GPUSPerJob int `json:"gpus_per_job"`
 }

+type SSHKey struct {
+    Private string `json:"private"`
+    Public  string `json:"public"`
+}
+
 type RepoCredentials struct {
     CloneURL   string  `json:"clone_url"`
     PrivateKey *string `json:"private_key"`

src/dstack/_internal/core/models/runs.py

Lines changed: 6 additions & 0 deletions

@@ -178,6 +178,11 @@ class Gateway(CoreModel):
     options: dict = {}


+class JobSSHKey(CoreModel):
+    private: str
+    public: str
+
+
 class JobSpec(CoreModel):
     replica_num: int = 0  # default value for backward compatibility
     job_num: int
@@ -198,6 +203,7 @@ class JobSpec(CoreModel):
     requirements: Requirements
     retry: Optional[Retry]
     volumes: Optional[List[MountPoint]] = None
+    ssh_key: Optional[JobSSHKey] = None
     # For backward compatibility with 0.18.x when retry_policy was required.
     # TODO: remove in 0.19
     retry_policy: ProfileRetryPolicy = ProfileRetryPolicy(retry=False)

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
     run_model = res.unique().scalar_one()
     repo_model = run_model.repo
     project = run_model.project
-    run = run_model_to_run(run_model)
+    run = run_model_to_run(run_model, include_sensitive=True)
     job_submission = job_model_to_job_submission(job_model)
     job_provisioning_data = job_submission.job_provisioning_data
     if job_provisioning_data is None:
