Skip to content

Commit 1eef7ee

Browse files
committed
Minor refactor changes
1 parent c3a11a3 commit 1eef7ee

File tree

4 files changed

+130
-124
lines changed

4 files changed

+130
-124
lines changed

ssh/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
log "github.com/inconshreveable/log15"
1313
)
1414

15+
// Config is an interface representing a configuration file
1516
type Config interface {
1617
}
1718

ssh/server/ssh/sessionchannel.go

Lines changed: 119 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -15,149 +15,154 @@ import (
1515
"golang.org/x/crypto/ssh"
1616
)
1717

18-
func handleSession(newChannel ssh.NewChannel) {
18+
func handleSession(perms *ssh.Permissions, newChannel ssh.NewChannel) {
1919
connection, requests, err := newChannel.Accept()
2020
if err != nil {
21-
log.Debug("Could not accept channel (%s)", err)
21+
log.Error("Could not accept channel", "error", err)
2222
return
2323
}
2424

25-
var bashf *os.File
26-
27-
close := func() {
25+
closeConn := func() {
2826
err = connection.Close()
2927
if err != nil {
30-
log.Debug("Could not close connection: %v", err)
28+
log.Error("Could not close connection", "error", err)
29+
return
3130
}
31+
32+
log.Debug("Connection closed")
3233
}
34+
var once sync.Once
35+
36+
var cmdf *os.File
37+
hasRequestedPty := false
38+
var ptyPayload []byte
39+
40+
execCmd := func(name string, arg ...string) error {
41+
cmd := exec.Command(name, arg...)
42+
close := func() {
43+
cmd.Process.Kill()
44+
err := cmd.Wait()
45+
if err != nil {
46+
log.Error("Error waiting for bash to end", "error", err)
47+
}
3348

34-
// Sessions have out-of-band requests such as "shell", "pty-req" and "env"
35-
go func() {
36-
var once sync.Once
37-
defer once.Do(close)
38-
for req := range requests {
39-
switch req.Type {
40-
case "shell":
41-
// We only accept the default shell
42-
// (i.e. no command in the Payload)
43-
if len(req.Payload) == 0 {
44-
req.Reply(true, nil)
45-
} else {
46-
log.Debug("Non-empty shell payload not yet supported!")
47-
}
48-
case "pty-req":
49+
tb := []byte{0, 0, 0, 0}
50+
connection.SendRequest("exit-status", false, tb)
4951

50-
bash := exec.Command("bash")
52+
once.Do(closeConn)
5153

52-
ptyClose := func() {
53-
bash.Process.Kill()
54-
err := bash.Wait()
55-
if err != nil {
56-
log.Debug("Error waiting for bash to end: %v", err)
57-
}
54+
log.Debug("Session closed")
55+
}
5856

59-
tb := []byte{0, 0, 0, 0}
60-
connection.SendRequest("exit-status", false, tb)
57+
var pipesWait sync.WaitGroup
58+
pipesWait.Add(2)
59+
go func() {
60+
pipesWait.Wait()
61+
close()
62+
}()
63+
64+
if hasRequestedPty {
65+
log.Debug("Creating pty...")
66+
cmdf, err = pty.Start(cmd)
67+
if err != nil {
68+
return err
69+
}
6170

62-
close()
71+
var once sync.Once
72+
go func() {
73+
io.Copy(connection, cmdf)
74+
pipesWait.Done()
75+
}()
76+
go func() {
77+
io.Copy(cmdf, connection)
78+
pipesWait.Done()
79+
}()
80+
81+
termLen := ptyPayload[3]
82+
w, h := parseDims(ptyPayload[termLen+4:])
83+
SetWinsize(cmdf.Fd(), w, h)
84+
} else {
85+
stdin, err := cmd.StdinPipe()
86+
if err != nil {
87+
return err
88+
}
6389

64-
log.Debug("Session closed")
65-
}
90+
stdout, err := cmd.StdoutPipe()
91+
if err != nil {
92+
return err
93+
}
6694

67-
// Allocate a terminal for this channel
68-
log.Debug("Creating pty...")
69-
bashf, err = pty.Start(bash)
95+
stderr, err := cmd.StderrPipe()
96+
if err != nil {
97+
return err
98+
}
99+
100+
// we want to wait for stdout and stderr before closing connection, but don't really mind about stdin
101+
go func() {
102+
_, err := io.Copy(stdin, connection)
103+
log.Debug("Stdin copy ended", "error", err)
104+
}()
105+
go func() {
106+
_, err := io.Copy(connection, stdout)
107+
log.Debug("Stdout copy ended", "error", err)
108+
pipesWait.Done()
109+
}()
110+
go func() {
111+
_, err := io.Copy(connection.Stderr(), stderr)
112+
log.Debug("Stderr copy ended", "error", err)
113+
pipesWait.Done()
114+
}()
115+
116+
err = cmd.Start()
117+
if err != nil {
118+
return err
119+
}
120+
}
121+
122+
return nil
123+
}
124+
125+
// Sessions have out-of-band requests such as "shell", "pty-req" and "exec"
126+
go func() {
127+
defer once.Do(closeConn)
128+
for req := range requests {
129+
switch req.Type {
130+
case "shell":
131+
// TODO determine and use default shell, don't force bash
132+
err := execCmd("bash")
70133
if err != nil {
71-
log.Debug("Could not start pty (%s)", err)
72-
return
134+
log.Error("Can't create shell!", "error", err)
73135
}
74136

75-
//pipe session to bash and visa-versa
76-
go func() {
77-
io.Copy(connection, bashf)
78-
once.Do(ptyClose)
79-
}()
80-
go func() {
81-
io.Copy(bashf, connection)
82-
once.Do(ptyClose)
83-
}()
84-
85-
termLen := req.Payload[3]
86-
w, h := parseDims(req.Payload[termLen+4:])
87-
SetWinsize(bashf.Fd(), w, h)
88-
// Responding true (OK) here will let the client
89-
// know we have a pty ready for input
90-
req.Reply(true, nil)
137+
if req.WantReply {
138+
req.Reply(true, nil)
139+
}
140+
case "pty-req":
141+
hasRequestedPty = true
142+
ptyPayload = req.Payload
143+
if req.WantReply {
144+
req.Reply(true, nil)
145+
}
91146
case "window-change":
92-
if bashf == nil {
93-
log.Debug("No pty requested!")
147+
if cmdf == nil {
148+
log.Debug("Tried to change window size but no pty requested!")
94149
} else {
95150
w, h := parseDims(req.Payload)
96-
SetWinsize(bashf.Fd(), w, h)
151+
SetWinsize(cmdf.Fd(), w, h)
152+
if req.WantReply {
153+
req.Reply(true, nil)
154+
}
97155
}
98156
case "exec":
99157
cmdStrLen := binary.BigEndian.Uint32(req.Payload[0:4])
100158
cmdStr := string(req.Payload[4 : cmdStrLen+4])
101-
cmd := exec.Command("bash", "-c", cmdStr)
102-
103-
stdin, err := cmd.StdinPipe()
104-
if err != nil {
105-
log.Debug("Error creating stdin pipe: %v", err)
106-
continue
107-
}
108-
109-
stdout, err := cmd.StdoutPipe()
159+
err := execCmd("bash", "-c", cmdStr)
110160
if err != nil {
111-
log.Debug("Error creating stdout pipe: %v", err)
112-
continue
161+
log.Error("Can't create shell!", "error", err)
113162
}
114163

115-
stderr, err := cmd.StderrPipe()
116-
if err != nil {
117-
log.Debug("Error creating stderr pipe: %v", err)
118-
continue
119-
}
120-
121-
var pipesWait sync.WaitGroup
122-
123-
// we want to wait for stdout and stderr before closing connection, but don't really mind about stdin
124-
pipesWait.Add(2)
125-
go func() {
126-
io.Copy(stdin, connection)
127-
}()
128-
go func() {
129-
io.Copy(connection, stdout)
130-
pipesWait.Done()
131-
}()
132-
go func() {
133-
io.Copy(connection.Stderr(), stderr)
134-
pipesWait.Done()
135-
}()
136-
137-
err = cmd.Start()
138-
if err != nil {
139-
log.Debug("Error running command %s: %v", cmdStr, err)
140-
continue
141-
}
142-
143-
go func() {
144-
err := cmd.Wait()
145-
if err != nil {
146-
log.Debug("Error waiting for command %s to end: %v", cmdStr, err)
147-
}
148-
149-
pipesWait.Wait()
150-
151-
tb := []byte{0, 0, 0, 0}
152-
connection.SendRequest("exit-status", false, tb)
153-
154-
once.Do(close)
155-
}()
156-
157-
err = req.Reply(true, nil)
158-
if err != nil {
159-
log.Debug("Error responding to request: %v", err)
160-
continue
164+
if req.WantReply {
165+
req.Reply(true, nil)
161166
}
162167
default:
163168
log.Debug("Unkown session request type %s", req.Type)

ssh/server/ssh/ssh.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
)
1616

1717
// ChannelHandlerFunction is a type for channel handlers, such as terminal sessions, tunnels, or X11 forwarding.
18-
type ChannelHandlerFunction func(newChannel ssh.NewChannel)
18+
type ChannelHandlerFunction func(perms *ssh.Permissions, newChannel ssh.NewChannel)
1919

2020
// Server is a struct containing information about SSH servers.
2121
type Server struct {
@@ -58,16 +58,16 @@ func Create(config *serverconfig.ServerConfig, version string) (*Server, error)
5858
return server, nil
5959
}
6060

61-
func (s *Server) handleChannels(chans <-chan ssh.NewChannel) {
61+
func (s *Server) handleChannels(perms *ssh.Permissions, chans <-chan ssh.NewChannel) {
6262
// Service the incoming Channel channel in go routine
6363
for newChannel := range chans {
64-
go s.handleChannel(newChannel)
64+
go s.handleChannel(perms, newChannel)
6565
}
6666
}
6767

68-
func (s *Server) handleChannel(newChannel ssh.NewChannel) {
68+
func (s *Server) handleChannel(perms *ssh.Permissions, newChannel ssh.NewChannel) {
6969
if handler, exists := s.channelHandlers[newChannel.ChannelType()]; exists {
70-
handler(newChannel)
70+
handler(perms, newChannel)
7171
} else {
7272
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType()))
7373
return
@@ -79,16 +79,16 @@ func (s *Server) HandleConnection(conn net.Conn) error {
7979
log.Debug("Handling new connection")
8080
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.configuration)
8181
if err != nil {
82-
log.Debug("Failed to create new connection (%s)", err)
82+
log.Error("Failed to create new connection", "error", err)
8383
conn.Close()
8484
return err
8585
}
8686

87-
log.Debug("New SSH connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion())
87+
log.Debug("New SSH connection", "remoteAddress", sshConn.RemoteAddr(), "clientVersion", sshConn.ClientVersion())
8888
// Discard all global out-of-band Requests
8989
go ssh.DiscardRequests(reqs)
9090
// Accept all channels
91-
s.handleChannels(chans)
91+
s.handleChannels(sshConn.Permissions, chans)
9292

9393
return nil
9494
}

ssh/server/ssh/tunnelchannel.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func handleTunnelForRemoteConnection(connection ssh.Channel, remoteConnection ne
3131
}()
3232
}
3333

34-
func handleTCPTunnel(newChannel ssh.NewChannel) {
34+
func handleTCPTunnel(perms *ssh.Permissions, newChannel ssh.NewChannel) {
3535
extraData := newChannel.ExtraData()
3636
addressLen := binary.BigEndian.Uint32(extraData[0:4])
3737
address := string(extraData[4 : addressLen+4])
@@ -54,7 +54,7 @@ func handleTCPTunnel(newChannel ssh.NewChannel) {
5454
handleTunnelForRemoteConnection(connection, remoteConnection)
5555
}
5656

57-
func handleSCIONQUICTunnel(newChannel ssh.NewChannel) {
57+
func handleSCIONQUICTunnel(perms *ssh.Permissions, newChannel ssh.NewChannel) {
5858
extraData := newChannel.ExtraData()
5959
addressLen := binary.BigEndian.Uint32(extraData[0:4])
6060
address := string(extraData[4 : addressLen+4])

0 commit comments

Comments
 (0)