@@ -12,6 +12,7 @@ import (
1212 "github.com/prometheus/client_golang/prometheus"
1313 "github.com/schollz/progressbar/v3"
1414 "github.com/spf13/afero"
15+ xslices "golang.org/x/exp/slices"
1516 "golang.org/x/xerrors"
1617 "tailscale.com/ipn/store"
1718 "tailscale.com/net/netns"
@@ -20,6 +21,7 @@ import (
2021 cslog "cdr.dev/slog"
2122 csloghuman "cdr.dev/slog/sloggers/sloghuman"
2223 "github.com/coder/coder/v2/agent/agentssh"
24+ "github.com/coder/pretty"
2325 "github.com/coder/serpent"
2426 "github.com/coder/wush/cliui"
2527 "github.com/coder/wush/overlay"
@@ -30,6 +32,8 @@ func serveCmd() *serpent.Command {
3032 var (
3133 overlayType string
3234 verbose bool
35+ enabled = []string {}
36+ disabled = []string {}
3337 )
3438 return & serpent.Command {
3539 Use : "serve" ,
@@ -89,72 +93,64 @@ func serveCmd() *serpent.Command {
8993
9094 fmt .Println (cliui .Timestamp (time .Now ()), "WireGuard is ready" )
9195
92- sshSrv , err := agentssh .NewServer (ctx ,
93- cslog .Make (csloghuman .Sink (logSink )),
94- prometheus .NewRegistry (),
95- fs ,
96- nil ,
97- )
98- if err != nil {
99- return err
100- }
101-
102- sshListener , err := ts .Listen ("tcp" , ":3" )
103- if err != nil {
104- return err
105- }
96+ closers := []io.Closer {}
10697
107- go func () {
108- fmt .Println (cliui .Timestamp (time .Now ()), "SSH server listening" )
109- err := sshSrv .Serve (sshListener )
98+ if xslices .Contains (enabled , "ssh" ) && ! xslices .Contains (disabled , "ssh" ) {
99+ sshSrv , err := agentssh .NewServer (ctx ,
100+ cslog .Make (csloghuman .Sink (logSink )),
101+ prometheus .NewRegistry (),
102+ fs ,
103+ nil ,
104+ )
110105 if err != nil {
111- logger .Info ("ssh server exited" , "err" , err )
112- }
113- }()
114-
115- cpListener , err := ts .Listen ("tcp" , ":4444" )
116- if err != nil {
117- return err
118- }
119-
120- go http .Serve (cpListener , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
121- if r .Method != "POST" {
122- w .WriteHeader (http .StatusOK )
123- w .Write ([]byte ("OK" ))
124- return
106+ return err
125107 }
108+ closers = append (closers , sshSrv )
126109
127- fiName := strings .TrimPrefix (r .URL .Path , "/" )
128- defer r .Body .Close ()
129-
130- fi , err := os .OpenFile (fiName , os .O_CREATE | os .O_RDWR | os .O_TRUNC , 0644 )
110+ sshListener , err := ts .Listen ("tcp" , ":3" )
131111 if err != nil {
132- http .Error (w , err .Error (), http .StatusInternalServerError )
133- return
112+ return err
134113 }
114+ closers = append (closers , sshListener )
115+
116+ fmt .Println (cliui .Timestamp (time .Now ()), "SSH server " + pretty .Sprint (cliui .DefaultStyles .Enabled , "enabled" ))
117+ go func () {
118+ err := sshSrv .Serve (sshListener )
119+ if err != nil {
120+ fmt .Println (cliui .Timestamp (time .Now ()), "SSH server exited: " + err .Error ())
121+ }
122+ }()
123+ } else {
124+ fmt .Println (cliui .Timestamp (time .Now ()), "SSH server " + pretty .Sprint (cliui .DefaultStyles .Disabled , "disabled" ))
125+ }
135126
136- bar := progressbar .DefaultBytes (
137- r .ContentLength ,
138- fmt .Sprintf ("Downloading %q" , fiName ),
139- )
140- _ , err = io .Copy (io .MultiWriter (fi , bar ), r .Body )
127+ if xslices .Contains (enabled , "cp" ) && ! xslices .Contains (disabled , "cp" ) {
128+ cpListener , err := ts .Listen ("tcp" , ":4444" )
141129 if err != nil {
142- http .Error (w , err .Error (), http .StatusInternalServerError )
143- return
130+ return err
144131 }
145- fi .Close ()
146- bar .Close ()
147-
148- w .WriteHeader (http .StatusOK )
149- w .Write ([]byte (fmt .Sprintf ("File %q written" , fiName )))
150- fmt .Printf ("Received file %s from %s\n " , fiName , r .RemoteAddr )
151- }))
132+ closers = append ([]io.Closer {cpListener }, closers ... )
133+
134+ fmt .Println (cliui .Timestamp (time .Now ()), "File transfer server " + pretty .Sprint (cliui .DefaultStyles .Enabled , "enabled" ))
135+ go func () {
136+ err := http .Serve (cpListener , http .HandlerFunc (cpHandler ))
137+ if err != nil {
138+ fmt .Println (cliui .Timestamp (time .Now ()), "File transfer server exited: " + err .Error ())
139+ }
140+ }()
141+ } else {
142+ fmt .Println (cliui .Timestamp (time .Now ()), "File transfer server " + pretty .Sprint (cliui .DefaultStyles .Disabled , "disabled" ))
143+ }
152144
153145 ctx , ctxCancel := inv .SignalNotifyContext (ctx , os .Interrupt )
154146 defer ctxCancel ()
155147
148+ closers = append (closers , ts )
156149 <- ctx .Done ()
157- return sshSrv .Close ()
150+ for _ , closer := range closers {
151+ closer .Close ()
152+ }
153+ return nil
158154 },
159155 Options : []serpent.Option {
160156 {
@@ -169,6 +165,18 @@ func serveCmd() *serpent.Command {
169165 Default : "false" ,
170166 Value : serpent .BoolOf (& verbose ),
171167 },
168+ {
169+ Flag : "enable" ,
170+ Description : "Server options to enable." ,
171+ Default : "ssh,cp" ,
172+ Value : serpent .EnumArrayOf (& enabled , "ssh" , "cp" ),
173+ },
174+ {
175+ Flag : "disable" ,
176+ Description : "Server options to disable." ,
177+ Default : "" ,
178+ Value : serpent .EnumArrayOf (& disabled , "ssh" , "cp" ),
179+ },
172180 },
173181 }
174182}
@@ -198,3 +206,36 @@ func newTSNet(direction string) (*tsnet.Server, error) {
198206
199207 return srv , nil
200208}
209+
210+ func cpHandler (w http.ResponseWriter , r * http.Request ) {
211+ if r .Method != "POST" {
212+ w .WriteHeader (http .StatusOK )
213+ w .Write ([]byte ("OK" ))
214+ return
215+ }
216+
217+ fiName := strings .TrimPrefix (r .URL .Path , "/" )
218+ defer r .Body .Close ()
219+
220+ fi , err := os .OpenFile (fiName , os .O_CREATE | os .O_RDWR | os .O_TRUNC , 0644 )
221+ if err != nil {
222+ http .Error (w , err .Error (), http .StatusInternalServerError )
223+ return
224+ }
225+
226+ bar := progressbar .DefaultBytes (
227+ r .ContentLength ,
228+ fmt .Sprintf ("Downloading %q" , fiName ),
229+ )
230+ _ , err = io .Copy (io .MultiWriter (fi , bar ), r .Body )
231+ if err != nil {
232+ http .Error (w , err .Error (), http .StatusInternalServerError )
233+ return
234+ }
235+ fi .Close ()
236+ bar .Close ()
237+
238+ w .WriteHeader (http .StatusOK )
239+ w .Write ([]byte (fmt .Sprintf ("File %q written" , fiName )))
240+ fmt .Printf ("Received file %s from %s\n " , fiName , r .RemoteAddr )
241+ }
0 commit comments