@@ -3,9 +3,11 @@ package cmd
33import (
44 "context"
55 "errors"
6+ "flag"
67 "fmt"
78 "os"
89 "os/signal"
10+ "os/user"
911 "strings"
1012 "syscall"
1113
@@ -17,43 +19,34 @@ import (
1719)
1820
1921var (
20- port int
21- user = "root"
22- host string
22+ port int
23+ username string
24+ host string
25+ command string
2326)
2427
2528var sshCmd = & cobra.Command {
26- Use : "ssh [user@]host" ,
27- Args : func (cmd * cobra.Command , args []string ) error {
28- if len (args ) < 1 {
29- return errors .New ("requires a host argument" )
30- }
31-
32- split := strings .Split (args [0 ], "@" )
33- if len (split ) == 2 {
34- user = split [0 ]
35- host = split [1 ]
36- } else {
37- host = args [0 ]
38- }
39-
40- return nil
41- },
42- Short : "connect to a remote SSH server" ,
29+ Use : "ssh [user@]host [command]" ,
30+ Short : "Connect to a NetBird peer via SSH" ,
31+ Long : `Connect to a NetBird peer using SSH.
32+
33+ Examples:
34+ netbird ssh peer-hostname
35+ netbird ssh user@peer-hostname
36+ netbird ssh peer-hostname --login myuser
37+ netbird ssh peer-hostname -p 22022
38+ netbird ssh peer-hostname ls -la
39+ netbird ssh peer-hostname whoami` ,
40+ DisableFlagParsing : true ,
41+ Args : validateSSHArgsWithoutFlagParsing ,
4342 RunE : func (cmd * cobra.Command , args []string ) error {
4443 SetFlagsFromEnvVars (rootCmd )
4544 SetFlagsFromEnvVars (cmd )
4645
4746 cmd .SetOut (cmd .OutOrStdout ())
4847
49- err := util .InitLog (logLevel , "console" )
50- if err != nil {
51- return fmt .Errorf ("failed initializing log %v" , err )
52- }
53-
54- if ! util .IsAdmin () {
55- cmd .Printf ("error: you must have Administrator privileges to run this command\n " )
56- return nil
48+ if err := util .InitLog (logLevel , "console" ); err != nil {
49+ return fmt .Errorf ("init log: %w" , err )
5750 }
5851
5952 ctx := internal .CtxInitState (cmd .Context ())
@@ -62,15 +55,14 @@ var sshCmd = &cobra.Command{
6255 ConfigPath : configPath ,
6356 })
6457 if err != nil {
65- return err
58+ return fmt . Errorf ( "update config: %w" , err )
6659 }
6760
6861 sig := make (chan os.Signal , 1 )
6962 signal .Notify (sig , syscall .SIGTERM , syscall .SIGINT )
7063 sshctx , cancel := context .WithCancel (ctx )
7164
7265 go func () {
73- // blocking
7466 if err := runSSH (sshctx , host , []byte (config .SSHKey ), cmd ); err != nil {
7567 cmd .Printf ("Error: %v\n " , err )
7668 os .Exit (1 )
@@ -88,31 +80,124 @@ var sshCmd = &cobra.Command{
8880 },
8981}
9082
83+ func validateSSHArgsWithoutFlagParsing (_ * cobra.Command , args []string ) error {
84+ if len (args ) < 1 {
85+ return errors .New ("host argument required" )
86+ }
87+
88+ // Reset globals to defaults
89+ port = nbssh .DefaultSSHPort
90+ username = ""
91+ host = ""
92+ command = ""
93+
94+ // Create a new FlagSet for parsing SSH-specific flags
95+ fs := flag .NewFlagSet ("ssh-flags" , flag .ContinueOnError )
96+ fs .SetOutput (nil ) // Suppress error output
97+
98+ // Define SSH-specific flags
99+ portFlag := fs .Int ("p" , nbssh .DefaultSSHPort , "SSH port" )
100+ fs .Int ("port" , nbssh .DefaultSSHPort , "SSH port" )
101+ userFlag := fs .String ("u" , "" , "SSH username" )
102+ fs .String ("user" , "" , "SSH username" )
103+ loginFlag := fs .String ("login" , "" , "SSH username (alias for --user)" )
104+
105+ // Parse flags until we hit the hostname (first non-flag argument)
106+ err := fs .Parse (args )
107+ if err != nil {
108+ // If flag parsing fails, treat everything as hostname + command
109+ // This handles cases like `ssh hostname ls -la` where `-la` should be part of the command
110+ return parseHostnameAndCommand (args )
111+ }
112+
113+ // Get the remaining args (hostname and command)
114+ remaining := fs .Args ()
115+ if len (remaining ) < 1 {
116+ return errors .New ("host argument required" )
117+ }
118+
119+ // Set parsed values
120+ port = * portFlag
121+ if * userFlag != "" {
122+ username = * userFlag
123+ } else if * loginFlag != "" {
124+ username = * loginFlag
125+ }
126+
127+ return parseHostnameAndCommand (remaining )
128+ }
129+
130+ func parseHostnameAndCommand (args []string ) error {
131+ if len (args ) < 1 {
132+ return errors .New ("host argument required" )
133+ }
134+
135+ // Parse hostname (possibly with user@host format)
136+ arg := args [0 ]
137+ if strings .Contains (arg , "@" ) {
138+ parts := strings .SplitN (arg , "@" , 2 )
139+ if len (parts ) != 2 || parts [0 ] == "" || parts [1 ] == "" {
140+ return errors .New ("invalid user@host format" )
141+ }
142+ // Only use username from host if not already set by flags
143+ if username == "" {
144+ username = parts [0 ]
145+ }
146+ host = parts [1 ]
147+ } else {
148+ host = arg
149+ }
150+
151+ // Set default username if none provided
152+ if username == "" {
153+ if sudoUser := os .Getenv ("SUDO_USER" ); sudoUser != "" {
154+ username = sudoUser
155+ } else if currentUser , err := user .Current (); err == nil {
156+ username = currentUser .Username
157+ } else {
158+ username = "root"
159+ }
160+ }
161+
162+ // Everything after hostname becomes the command
163+ if len (args ) > 1 {
164+ command = strings .Join (args [1 :], " " )
165+ }
166+
167+ return nil
168+ }
169+
91170func runSSH (ctx context.Context , addr string , pemKey []byte , cmd * cobra.Command ) error {
92- c , err := nbssh .DialWithKey (fmt .Sprintf ("%s:%d" , addr , port ), user , pemKey )
171+ target := fmt .Sprintf ("%s:%d" , addr , port )
172+ c , err := nbssh .DialWithKey (ctx , target , username , pemKey )
93173 if err != nil {
94- cmd .Printf ("Error: %v\n " , err )
95- cmd .Printf ("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
96- "\n You can verify the connection by running:\n \n " +
97- " netbird status\n \n " )
98- return err
174+ cmd .Printf ("Failed to connect to %s@%s\n " , username , target )
175+ cmd .Printf ("\n Troubleshooting steps:\n " )
176+ cmd .Printf (" 1. Check peer connectivity: netbird status\n " )
177+ cmd .Printf (" 2. Verify SSH server is enabled on the peer\n " )
178+ cmd .Printf (" 3. Ensure correct hostname/IP is used\n \n " )
179+ return fmt .Errorf ("dial %s: %w" , target , err )
99180 }
100181 go func () {
101182 <- ctx .Done ()
102- err = c .Close ()
103- if err != nil {
104- return
105- }
183+ _ = c .Close ()
106184 }()
107185
108- err = c .OpenTerminal ()
109- if err != nil {
110- return err
186+ if command != "" {
187+ if err := c .ExecuteCommandWithIO (ctx , command ); err != nil {
188+ return err
189+ }
190+ } else {
191+ if err := c .OpenTerminal (ctx ); err != nil {
192+ return err
193+ }
111194 }
112195
113196 return nil
114197}
115198
116199func init () {
117- sshCmd .PersistentFlags ().IntVarP (& port , "port" , "p" , nbssh .DefaultSSHPort , "Sets remote SSH port. Defaults to " + fmt .Sprint (nbssh .DefaultSSHPort ))
200+ sshCmd .PersistentFlags ().IntVarP (& port , "port" , "p" , nbssh .DefaultSSHPort , "Remote SSH port" )
201+ sshCmd .PersistentFlags ().StringVarP (& username , "user" , "u" , "" , "SSH username" )
202+ sshCmd .PersistentFlags ().StringVar (& username , "login" , "" , "SSH username (alias for --user)" )
118203}
0 commit comments