Skip to content

Commit f7d8bba

Browse files
committed
Refactor ssh server and client
1 parent f3e657a commit f7d8bba

19 files changed

+3499
-521
lines changed

client/cmd/ssh.go

Lines changed: 129 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package cmd
33
import (
44
"context"
55
"errors"
6+
"flag"
67
"fmt"
78
"os"
89
"os/signal"
10+
"os/user"
911
"strings"
1012
"syscall"
1113

@@ -17,43 +19,34 @@ import (
1719
)
1820

1921
var (
20-
port int
21-
user = "root"
22-
host string
22+
port int
23+
username string
24+
host string
25+
command string
2326
)
2427

2528
var sshCmd = &cobra.Command{
26-
Use: "ssh [user@]host",
27-
Args: func(cmd *cobra.Command, args []string) error {
28-
if len(args) < 1 {
29-
return errors.New("requires a host argument")
30-
}
31-
32-
split := strings.Split(args[0], "@")
33-
if len(split) == 2 {
34-
user = split[0]
35-
host = split[1]
36-
} else {
37-
host = args[0]
38-
}
39-
40-
return nil
41-
},
42-
Short: "connect to a remote SSH server",
29+
Use: "ssh [user@]host [command]",
30+
Short: "Connect to a NetBird peer via SSH",
31+
Long: `Connect to a NetBird peer using SSH.
32+
33+
Examples:
34+
netbird ssh peer-hostname
35+
netbird ssh user@peer-hostname
36+
netbird ssh peer-hostname --login myuser
37+
netbird ssh peer-hostname -p 22022
38+
netbird ssh peer-hostname ls -la
39+
netbird ssh peer-hostname whoami`,
40+
DisableFlagParsing: true,
41+
Args: validateSSHArgsWithoutFlagParsing,
4342
RunE: func(cmd *cobra.Command, args []string) error {
4443
SetFlagsFromEnvVars(rootCmd)
4544
SetFlagsFromEnvVars(cmd)
4645

4746
cmd.SetOut(cmd.OutOrStdout())
4847

49-
err := util.InitLog(logLevel, "console")
50-
if err != nil {
51-
return fmt.Errorf("failed initializing log %v", err)
52-
}
53-
54-
if !util.IsAdmin() {
55-
cmd.Printf("error: you must have Administrator privileges to run this command\n")
56-
return nil
48+
if err := util.InitLog(logLevel, "console"); err != nil {
49+
return fmt.Errorf("init log: %w", err)
5750
}
5851

5952
ctx := internal.CtxInitState(cmd.Context())
@@ -62,15 +55,14 @@ var sshCmd = &cobra.Command{
6255
ConfigPath: configPath,
6356
})
6457
if err != nil {
65-
return err
58+
return fmt.Errorf("update config: %w", err)
6659
}
6760

6861
sig := make(chan os.Signal, 1)
6962
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
7063
sshctx, cancel := context.WithCancel(ctx)
7164

7265
go func() {
73-
// blocking
7466
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
7567
cmd.Printf("Error: %v\n", err)
7668
os.Exit(1)
@@ -88,31 +80,124 @@ var sshCmd = &cobra.Command{
8880
},
8981
}
9082

83+
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
84+
if len(args) < 1 {
85+
return errors.New("host argument required")
86+
}
87+
88+
// Reset globals to defaults
89+
port = nbssh.DefaultSSHPort
90+
username = ""
91+
host = ""
92+
command = ""
93+
94+
// Create a new FlagSet for parsing SSH-specific flags
95+
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
96+
fs.SetOutput(nil) // Suppress error output
97+
98+
// Define SSH-specific flags
99+
portFlag := fs.Int("p", nbssh.DefaultSSHPort, "SSH port")
100+
fs.Int("port", nbssh.DefaultSSHPort, "SSH port")
101+
userFlag := fs.String("u", "", "SSH username")
102+
fs.String("user", "", "SSH username")
103+
loginFlag := fs.String("login", "", "SSH username (alias for --user)")
104+
105+
// Parse flags until we hit the hostname (first non-flag argument)
106+
err := fs.Parse(args)
107+
if err != nil {
108+
// If flag parsing fails, treat everything as hostname + command
109+
// This handles cases like `ssh hostname ls -la` where `-la` should be part of the command
110+
return parseHostnameAndCommand(args)
111+
}
112+
113+
// Get the remaining args (hostname and command)
114+
remaining := fs.Args()
115+
if len(remaining) < 1 {
116+
return errors.New("host argument required")
117+
}
118+
119+
// Set parsed values
120+
port = *portFlag
121+
if *userFlag != "" {
122+
username = *userFlag
123+
} else if *loginFlag != "" {
124+
username = *loginFlag
125+
}
126+
127+
return parseHostnameAndCommand(remaining)
128+
}
129+
130+
func parseHostnameAndCommand(args []string) error {
131+
if len(args) < 1 {
132+
return errors.New("host argument required")
133+
}
134+
135+
// Parse hostname (possibly with user@host format)
136+
arg := args[0]
137+
if strings.Contains(arg, "@") {
138+
parts := strings.SplitN(arg, "@", 2)
139+
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
140+
return errors.New("invalid user@host format")
141+
}
142+
// Only use username from host if not already set by flags
143+
if username == "" {
144+
username = parts[0]
145+
}
146+
host = parts[1]
147+
} else {
148+
host = arg
149+
}
150+
151+
// Set default username if none provided
152+
if username == "" {
153+
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
154+
username = sudoUser
155+
} else if currentUser, err := user.Current(); err == nil {
156+
username = currentUser.Username
157+
} else {
158+
username = "root"
159+
}
160+
}
161+
162+
// Everything after hostname becomes the command
163+
if len(args) > 1 {
164+
command = strings.Join(args[1:], " ")
165+
}
166+
167+
return nil
168+
}
169+
91170
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
92-
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
171+
target := fmt.Sprintf("%s:%d", addr, port)
172+
c, err := nbssh.DialWithKey(ctx, target, username, pemKey)
93173
if err != nil {
94-
cmd.Printf("Error: %v\n", err)
95-
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
96-
"\nYou can verify the connection by running:\n\n" +
97-
" netbird status\n\n")
98-
return err
174+
cmd.Printf("Failed to connect to %s@%s\n", username, target)
175+
cmd.Printf("\nTroubleshooting steps:\n")
176+
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
177+
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
178+
cmd.Printf(" 3. Ensure correct hostname/IP is used\n\n")
179+
return fmt.Errorf("dial %s: %w", target, err)
99180
}
100181
go func() {
101182
<-ctx.Done()
102-
err = c.Close()
103-
if err != nil {
104-
return
105-
}
183+
_ = c.Close()
106184
}()
107185

108-
err = c.OpenTerminal()
109-
if err != nil {
110-
return err
186+
if command != "" {
187+
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
188+
return err
189+
}
190+
} else {
191+
if err := c.OpenTerminal(ctx); err != nil {
192+
return err
193+
}
111194
}
112195

113196
return nil
114197
}
115198

116199
func init() {
117-
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
200+
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Remote SSH port")
201+
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", "SSH username")
202+
sshCmd.PersistentFlags().StringVar(&username, "login", "", "SSH username (alias for --user)")
118203
}

0 commit comments

Comments
 (0)