Skip to content

Commit fb19ac0

Browse files
refactor(cli): extract shared utilities and fix review issues
- Extract getHostURL and connectSSH into cmd/cli/common package, eliminating duplication across ssh, scp, and get commands (I1) - Use path instead of filepath for remote SFTP paths to ensure correct POSIX separators on all platforms (I2) - Log warnings for skipped files during recursive remote copy (I3) - Use c.IsSet() for flags with defaults to prevent overwriting existing config with default values (I4) Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
1 parent ec65767 commit fb19ac0

File tree

7 files changed

+149
-231
lines changed

7 files changed

+149
-231
lines changed

cmd/cli/common/host.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package common
18+
19+
import (
20+
"fmt"
21+
"os"
22+
"time"
23+
24+
"golang.org/x/crypto/ssh"
25+
26+
"github.com/NVIDIA/holodeck/api/holodeck/v1alpha1"
27+
"github.com/NVIDIA/holodeck/internal/logger"
28+
"github.com/NVIDIA/holodeck/pkg/provider/aws"
29+
)
30+
31+
// GetHostURL resolves the SSH-reachable host URL for an environment.
32+
// If nodeName is set, it looks for that specific node.
33+
// If preferControlPlane is true and no nodeName is set, it prefers a control-plane node.
34+
// Falls back to the first available node, then single-node properties.
35+
func GetHostURL(env *v1alpha1.Environment, nodeName string, preferControlPlane bool) (string, error) {
36+
// For multinode clusters, find the appropriate node
37+
if env.Spec.Cluster != nil && env.Status.Cluster != nil && len(env.Status.Cluster.Nodes) > 0 {
38+
if nodeName != "" {
39+
for _, node := range env.Status.Cluster.Nodes {
40+
if node.Name == nodeName {
41+
return node.PublicIP, nil
42+
}
43+
}
44+
return "", fmt.Errorf("node %q not found in cluster", nodeName)
45+
}
46+
47+
if preferControlPlane {
48+
for _, node := range env.Status.Cluster.Nodes {
49+
if node.Role == "control-plane" {
50+
return node.PublicIP, nil
51+
}
52+
}
53+
}
54+
55+
// Fallback to first node
56+
return env.Status.Cluster.Nodes[0].PublicIP, nil
57+
}
58+
59+
// Single node - get from properties
60+
switch env.Spec.Provider {
61+
case v1alpha1.ProviderAWS:
62+
for _, p := range env.Status.Properties {
63+
if p.Name == aws.PublicDnsName {
64+
return p.Value, nil
65+
}
66+
}
67+
case v1alpha1.ProviderSSH:
68+
return env.Spec.HostUrl, nil
69+
}
70+
71+
return "", fmt.Errorf("unable to determine host URL")
72+
}
73+
74+
// ConnectSSH establishes an SSH connection with retries.
75+
// Holodeck instances are ephemeral with no pre-established host keys,
76+
// so host key verification is intentionally disabled.
77+
func ConnectSSH(log *logger.FunLogger, keyPath, userName, hostUrl string) (*ssh.Client, error) {
78+
key, err := os.ReadFile(keyPath)
79+
if err != nil {
80+
return nil, fmt.Errorf("failed to read key file %s: %v", keyPath, err)
81+
}
82+
83+
signer, err := ssh.ParsePrivateKey(key)
84+
if err != nil {
85+
return nil, fmt.Errorf("failed to parse private key: %v", err)
86+
}
87+
88+
config := &ssh.ClientConfig{
89+
User: userName,
90+
Auth: []ssh.AuthMethod{
91+
ssh.PublicKeys(signer),
92+
},
93+
// Holodeck instances are ephemeral with no pre-established host keys
94+
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec
95+
Timeout: 30 * time.Second,
96+
}
97+
98+
var client *ssh.Client
99+
for i := 0; i < 3; i++ {
100+
client, err = ssh.Dial("tcp", hostUrl+":22", config)
101+
if err == nil {
102+
return client, nil
103+
}
104+
log.Warning("Connection attempt %d failed: %v", i+1, err)
105+
time.Sleep(2 * time.Second)
106+
}
107+
108+
return nil, fmt.Errorf("failed to connect after 3 attempts: %v", err)
109+
}

cmd/cli/get/get.go

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ import (
2222
"path/filepath"
2323

2424
"github.com/NVIDIA/holodeck/api/holodeck/v1alpha1"
25+
"github.com/NVIDIA/holodeck/cmd/cli/common"
2526
"github.com/NVIDIA/holodeck/internal/instances"
2627
"github.com/NVIDIA/holodeck/internal/logger"
2728
"github.com/NVIDIA/holodeck/pkg/jyaml"
28-
"github.com/NVIDIA/holodeck/pkg/provider/aws"
2929
"github.com/NVIDIA/holodeck/pkg/utils"
3030

3131
cli "github.com/urfave/cli/v2"
@@ -161,7 +161,7 @@ func (m command) runKubeconfig(instanceID string) error {
161161
}
162162

163163
// Determine host URL
164-
hostUrl, err := m.getHostURL(&env, true)
164+
hostUrl, err := common.GetHostURL(&env, m.node, true)
165165
if err != nil {
166166
return fmt.Errorf("failed to get host URL: %v", err)
167167
}
@@ -217,7 +217,7 @@ func (m command) runSSHConfig(instanceID string) error {
217217
}
218218

219219
// Single node
220-
hostUrl, err := m.getHostURL(&env, false)
220+
hostUrl, err := common.GetHostURL(&env, m.node, false)
221221
if err != nil {
222222
return fmt.Errorf("failed to get host URL: %v", err)
223223
}
@@ -255,42 +255,3 @@ func (m command) generateClusterSSHConfig(instanceID string, env *v1alpha1.Envir
255255
return nil
256256
}
257257

258-
func (m command) getHostURL(env *v1alpha1.Environment, controlPlaneOnly bool) (string, error) {
259-
// For multinode clusters, find the appropriate node
260-
if env.Spec.Cluster != nil && env.Status.Cluster != nil && len(env.Status.Cluster.Nodes) > 0 {
261-
// If a specific node is requested, find it
262-
if m.node != "" {
263-
for _, node := range env.Status.Cluster.Nodes {
264-
if node.Name == m.node {
265-
return node.PublicIP, nil
266-
}
267-
}
268-
return "", fmt.Errorf("node %q not found in cluster", m.node)
269-
}
270-
271-
// Default to first control-plane node (required for kubeconfig)
272-
if controlPlaneOnly {
273-
for _, node := range env.Status.Cluster.Nodes {
274-
if node.Role == "control-plane" {
275-
return node.PublicIP, nil
276-
}
277-
}
278-
}
279-
280-
// Fallback to first node
281-
return env.Status.Cluster.Nodes[0].PublicIP, nil
282-
}
283-
284-
// Single node - get from properties
285-
if env.Spec.Provider == v1alpha1.ProviderAWS {
286-
for _, p := range env.Status.Properties {
287-
if p.Name == aws.PublicDnsName {
288-
return p.Value, nil
289-
}
290-
}
291-
} else if env.Spec.Provider == v1alpha1.ProviderSSH {
292-
return env.Spec.HostUrl, nil
293-
}
294-
295-
return "", fmt.Errorf("unable to determine host URL")
296-
}

cmd/cli/get/get_test.go

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ import (
2020
"testing"
2121

2222
"github.com/NVIDIA/holodeck/api/holodeck/v1alpha1"
23+
"github.com/NVIDIA/holodeck/cmd/cli/common"
2324
"github.com/NVIDIA/holodeck/pkg/provider/aws"
2425
)
2526

2627
func TestGetHostURL_AWS_SingleNode(t *testing.T) {
27-
cmd := command{}
2828
env := &v1alpha1.Environment{
2929
Spec: v1alpha1.EnvironmentSpec{
3030
Provider: v1alpha1.ProviderAWS,
@@ -36,7 +36,7 @@ func TestGetHostURL_AWS_SingleNode(t *testing.T) {
3636
},
3737
}
3838

39-
url, err := cmd.getHostURL(env, false)
39+
url, err := common.GetHostURL(env, "", false)
4040
if err != nil {
4141
t.Fatalf("unexpected error: %v", err)
4242
}
@@ -46,7 +46,6 @@ func TestGetHostURL_AWS_SingleNode(t *testing.T) {
4646
}
4747

4848
func TestGetHostURL_SSH(t *testing.T) {
49-
cmd := command{}
5049
env := &v1alpha1.Environment{
5150
Spec: v1alpha1.EnvironmentSpec{
5251
Provider: v1alpha1.ProviderSSH,
@@ -56,7 +55,7 @@ func TestGetHostURL_SSH(t *testing.T) {
5655
},
5756
}
5857

59-
url, err := cmd.getHostURL(env, false)
58+
url, err := common.GetHostURL(env, "", false)
6059
if err != nil {
6160
t.Fatalf("unexpected error: %v", err)
6261
}
@@ -66,7 +65,6 @@ func TestGetHostURL_SSH(t *testing.T) {
6665
}
6766

6867
func TestGetHostURL_NoProperties(t *testing.T) {
69-
cmd := command{}
7068
env := &v1alpha1.Environment{
7169
Spec: v1alpha1.EnvironmentSpec{
7270
Provider: v1alpha1.ProviderAWS,
@@ -76,14 +74,13 @@ func TestGetHostURL_NoProperties(t *testing.T) {
7674
},
7775
}
7876

79-
_, err := cmd.getHostURL(env, false)
77+
_, err := common.GetHostURL(env, "", false)
8078
if err == nil {
8179
t.Error("expected error for missing properties")
8280
}
8381
}
8482

8583
func TestGetHostURL_Cluster_ControlPlaneOnly(t *testing.T) {
86-
cmd := command{}
8784
env := &v1alpha1.Environment{
8885
Spec: v1alpha1.EnvironmentSpec{
8986
Provider: v1alpha1.ProviderAWS,
@@ -99,7 +96,7 @@ func TestGetHostURL_Cluster_ControlPlaneOnly(t *testing.T) {
9996
},
10097
}
10198

102-
url, err := cmd.getHostURL(env, true)
99+
url, err := common.GetHostURL(env, "", true)
103100
if err != nil {
104101
t.Fatalf("unexpected error: %v", err)
105102
}
@@ -109,7 +106,6 @@ func TestGetHostURL_Cluster_ControlPlaneOnly(t *testing.T) {
109106
}
110107

111108
func TestGetHostURL_Cluster_SpecificNode(t *testing.T) {
112-
cmd := command{node: "worker-0"}
113109
env := &v1alpha1.Environment{
114110
Spec: v1alpha1.EnvironmentSpec{
115111
Provider: v1alpha1.ProviderAWS,
@@ -125,7 +121,7 @@ func TestGetHostURL_Cluster_SpecificNode(t *testing.T) {
125121
},
126122
}
127123

128-
url, err := cmd.getHostURL(env, false)
124+
url, err := common.GetHostURL(env, "worker-0", false)
129125
if err != nil {
130126
t.Fatalf("unexpected error: %v", err)
131127
}
@@ -135,7 +131,6 @@ func TestGetHostURL_Cluster_SpecificNode(t *testing.T) {
135131
}
136132

137133
func TestGetHostURL_Cluster_NodeNotFound(t *testing.T) {
138-
cmd := command{node: "nonexistent"}
139134
env := &v1alpha1.Environment{
140135
Spec: v1alpha1.EnvironmentSpec{
141136
Provider: v1alpha1.ProviderAWS,
@@ -150,7 +145,7 @@ func TestGetHostURL_Cluster_NodeNotFound(t *testing.T) {
150145
},
151146
}
152147

153-
_, err := cmd.getHostURL(env, false)
148+
_, err := common.GetHostURL(env, "nonexistent", false)
154149
if err == nil {
155150
t.Error("expected error for nonexistent node")
156151
}

0 commit comments

Comments
 (0)