Skip to content

Commit 3e8bb7a

Browse files
committed
Use configurable io.Writer for user facing log messages instead of hardcoding stdout
Signed-off-by: Chance Zibolski <chance.zibolski@gmail.com>
1 parent fd06e15 commit 3e8bb7a

File tree

18 files changed

+82
-49
lines changed

18 files changed

+82
-49
lines changed

src/datachannel/streaming.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
"encoding/json"
2222
"errors"
2323
"fmt"
24+
"io"
2425
"math"
25-
"os"
2626
"reflect"
2727
"sync"
2828
"time"
@@ -115,6 +115,9 @@ type DataChannel struct {
115115

116116
// AgentVersion received during handshake
117117
agentVersion string
118+
119+
// Out is where user ssm plugin logs go
120+
Out io.Writer
118121
}
119122

120123
type ListMessageBuffer struct {
@@ -510,7 +513,7 @@ func (dataChannel *DataChannel) handleHandshakeComplete(log log.T, clientMessage
510513
handshakeComplete.HandshakeTimeToComplete.Seconds())
511514

512515
if handshakeComplete.CustomerMessage != "" {
513-
fmt.Fprintln(os.Stdout, handshakeComplete.CustomerMessage)
516+
fmt.Fprintln(dataChannel.Out, handshakeComplete.CustomerMessage)
514517
}
515518

516519
return err
@@ -783,9 +786,9 @@ func (dataChannel DataChannel) HandleChannelClosedMessage(log log.T, stopHandler
783786

784787
log.Infof("Exiting session with sessionId: %s with output: %s", sessionId, channelClosedMessage.Output)
785788
if channelClosedMessage.Output == "" {
786-
fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", sessionId)
789+
fmt.Fprintf(dataChannel.Out, "\n\nExiting session with sessionId: %s.\n\n", sessionId)
787790
} else {
788-
fmt.Fprintf(os.Stdout, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output)
791+
fmt.Fprintf(dataChannel.Out, "\n\nSessionId: %s : %s\n\n", sessionId, channelClosedMessage.Output)
789792
}
790793

791794
stopHandler()

src/sessionmanagerplugin/session/portsession/basicportforwarding.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package portsession
1616

1717
import (
1818
"fmt"
19+
"io"
1920
"net"
2021
"os"
2122
"os/signal"
@@ -39,6 +40,7 @@ type BasicPortForwarding struct {
3940
sessionId string
4041
portParameters PortParameters
4142
session session.Session
43+
out io.Writer
4244
}
4345

4446
// getNewListener returns a new listener to given address and type like tcp, unix etc.
@@ -132,7 +134,7 @@ func (p *BasicPortForwarding) startLocalConn(log log.T) (err error) {
132134
return err
133135
}
134136
log.Infof("Connection accepted for session %s.", p.sessionId)
135-
fmt.Printf("Connection accepted for session %s.\n", p.sessionId)
137+
fmt.Fprintf(p.out, "Connection accepted for session %s.\n", p.sessionId)
136138

137139
p.listener = &listener
138140
p.stream = &tcpConn
@@ -159,7 +161,7 @@ func (p *BasicPortForwarding) startLocalListener(log log.T, portNumber string) (
159161
}
160162

161163
log.Info(displayMessage)
162-
fmt.Println(displayMessage)
164+
fmt.Fprintln(p.out, displayMessage)
163165
return
164166
}
165167

@@ -169,13 +171,13 @@ func (p *BasicPortForwarding) handleControlSignals(log log.T) {
169171
signal.Notify(c, sessionutil.ControlSignals...)
170172
go func() {
171173
<-c
172-
fmt.Println("Terminate signal received, exiting.")
174+
fmt.Fprintln(p.out, "Terminate signal received, exiting.")
173175

174176
if version.DoesAgentSupportTerminateSessionFlag(log, p.session.DataChannel.GetAgentVersion()) {
175177
if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil {
176178
log.Errorf("Failed to send TerminateSession flag: %v", err)
177179
}
178-
fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId)
180+
fmt.Fprintf(p.out, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId)
179181
p.Stop()
180182
} else {
181183
p.session.TerminateSession(log)

src/sessionmanagerplugin/session/portsession/basicportforwarding_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func TestSetSessionHandlers(t *testing.T) {
4848
Session: mockSession,
4949
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"},
5050
portSessionType: &BasicPortForwarding{
51+
out: os.Stdout,
5152
session: mockSession,
5253
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"},
5354
},
@@ -84,6 +85,7 @@ func TestStartSessionTCPLocalPortFromDocument(t *testing.T) {
8485
Session: getSessionMock(),
8586
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding", LocalPortNumber: "54321"},
8687
portSessionType: &BasicPortForwarding{
88+
out: os.Stdout,
8789
session: getSessionMock(),
8890
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"},
8991
},
@@ -101,6 +103,7 @@ func TestStartSessionTCPAcceptFailed(t *testing.T) {
101103
Session: getSessionMock(),
102104
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"},
103105
portSessionType: &BasicPortForwarding{
106+
out: os.Stdout,
104107
session: getSessionMock(),
105108
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"},
106109
},
@@ -117,6 +120,7 @@ func TestStartSessionTCPConnectFailed(t *testing.T) {
117120
Session: getSessionMock(),
118121
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"},
119122
portSessionType: &BasicPortForwarding{
123+
out: os.Stdout,
120124
session: getSessionMock(),
121125
portParameters: PortParameters{PortNumber: "22", Type: "LocalPortForwarding"},
122126
},

src/sessionmanagerplugin/session/portsession/muxportforwarding.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type MuxPortForwarding struct {
6161
session session.Session
6262
muxClient *MuxClient
6363
mgsConn *MgsConn
64+
out io.Writer
6465
}
6566

6667
func (c *MgsConn) close() {
@@ -131,7 +132,7 @@ func (p *MuxPortForwarding) WriteStream(outputMessage message.ClientMessage) err
131132
binary.Read(buf, binary.BigEndian, &flag)
132133

133134
if message.ConnectToPortError == flag {
134-
fmt.Printf("\nConnection to destination port failed, check SSM Agent logs.\n")
135+
fmt.Fprintf(p.out, "\nConnection to destination port failed, check SSM Agent logs.\n")
135136
}
136137
}
137138
return nil
@@ -190,12 +191,12 @@ func (p *MuxPortForwarding) handleControlSignals(log log.T) {
190191
signal.Notify(c, sessionutil.ControlSignals...)
191192
go func() {
192193
<-c
193-
fmt.Println("Terminate signal received, exiting.")
194+
fmt.Fprintln(p.out, "Terminate signal received, exiting.")
194195

195196
if err := p.session.DataChannel.SendFlag(log, message.TerminateSession); err != nil {
196197
log.Errorf("Failed to send TerminateSession flag: %v", err)
197198
}
198-
fmt.Fprintf(os.Stdout, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId)
199+
fmt.Fprintf(p.out, "\n\nExiting session with sessionId: %s.\n\n", p.sessionId)
199200
p.Stop()
200201
}()
201202
}
@@ -252,10 +253,10 @@ func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Conte
252253
defer listener.Close()
253254

254255
log.Infof(displayMsg)
255-
fmt.Printf(displayMsg)
256+
fmt.Fprintf(p.out, displayMsg)
256257

257258
log.Infof("Waiting for connections...\n")
258-
fmt.Printf("\nWaiting for connections...\n")
259+
fmt.Fprintf(p.out, "\nWaiting for connections...\n")
259260

260261
var once sync.Once
261262
for {
@@ -269,7 +270,7 @@ func (p *MuxPortForwarding) handleClientConnections(log log.T, ctx context.Conte
269270
log.Infof("Connection accepted from %s\n for session [%s]", conn.RemoteAddr(), p.sessionId)
270271

271272
once.Do(func() {
272-
fmt.Printf("\nConnection accepted for session [%s]\n", p.sessionId)
273+
fmt.Fprintf(p.out, "\nConnection accepted for session [%s]\n", p.sessionId)
273274
})
274275

275276
stream, err := p.muxClient.session.OpenStream()

src/sessionmanagerplugin/session/portsession/portsession.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package portsession
1616

1717
import (
18+
"os"
19+
1820
"github.com/aws/session-manager-plugin/src/config"
1921
"github.com/aws/session-manager-plugin/src/jsonutil"
2022
"github.com/aws/session-manager-plugin/src/log"
@@ -70,12 +72,14 @@ func (s *PortSession) Initialize(log log.T, sessionVar *session.Session) {
7072
sessionId: s.SessionId,
7173
portParameters: s.portParameters,
7274
session: s.Session,
75+
out: os.Stdout,
7376
}
7477
} else {
7578
s.portSessionType = &BasicPortForwarding{
7679
sessionId: s.SessionId,
7780
portParameters: s.portParameters,
7881
session: s.Session,
82+
out: os.Stdout,
7983
}
8084
}
8185
} else {

src/sessionmanagerplugin/session/portsession/test_portsession.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package portsession
1616

1717
import (
18+
"os"
19+
1820
"github.com/aws/session-manager-plugin/src/communicator/mocks"
1921
"github.com/aws/session-manager-plugin/src/datachannel"
2022
"github.com/aws/session-manager-plugin/src/log"
@@ -41,11 +43,13 @@ func getSessionMock() session.Session {
4143
}
4244

4345
func getSessionMockWithParams(properties interface{}, agentVersion string) session.Session {
44-
datachannel := &datachannel.DataChannel{}
46+
out := os.Stdout
47+
datachannel := &datachannel.DataChannel{Out: out}
4548
datachannel.SetAgentVersion(agentVersion)
4649

4750
var mockSession = session.Session{
4851
DataChannel: datachannel,
52+
Out: out,
4953
}
5054

5155
mockSession.DataChannel.Initialize(mockLog, "clientId", "sessionId", "targetId", false)

src/sessionmanagerplugin/session/session.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,15 @@ type Session struct {
8585
SessionType string
8686
SessionProperties interface{}
8787
DisplayMode sessionutil.DisplayMode
88+
Out io.Writer
8889
}
8990

90-
//startSession create the datachannel for session
91+
// startSession create the datachannel for session
9192
var startSession = func(session *Session, log log.T) error {
9293
return session.Execute(log)
9394
}
9495

95-
//setSessionHandlersWithSessionType set session handlers based on session subtype
96+
// setSessionHandlersWithSessionType set session handlers based on session subtype
9697
var setSessionHandlersWithSessionType = func(session *Session, log log.T) error {
9798
// SessionType is set inside DataChannel
9899
sessionSubType := SessionRegistry[session.SessionType]
@@ -203,7 +204,8 @@ func ValidateInputAndStartSession(args []string, out io.Writer) {
203204
session.Endpoint = ssmEndpoint
204205
session.ClientId = clientId
205206
session.TargetId = target
206-
session.DataChannel = &datachannel.DataChannel{}
207+
session.DataChannel = &datachannel.DataChannel{Out: out}
208+
session.Out = out
207209

208210
default:
209211
fmt.Fprint(out, "Invalid Operation")
@@ -217,12 +219,12 @@ func ValidateInputAndStartSession(args []string, out io.Writer) {
217219
}
218220
}
219221

220-
//Execute create data channel and start the session
222+
// Execute create data channel and start the session
221223
func (s *Session) Execute(log log.T) (err error) {
222-
fmt.Fprintf(os.Stdout, "\nStarting session with SessionId: %s\n", s.SessionId)
224+
fmt.Fprintf(s.Out, "\nStarting session with SessionId: %s\n", s.SessionId)
223225

224226
// sets the display mode
225-
s.DisplayMode = sessionutil.NewDisplayMode(log)
227+
s.DisplayMode = sessionutil.NewDisplayMode(log, s.Out)
226228

227229
if err = s.OpenDataChannel(log); err != nil {
228230
log.Errorf("Error in Opening data channel: %v", err)

src/sessionmanagerplugin/session/session_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func TestValidateInputAndStartSessionWithWrongEnvVariableName(t *testing.T) {
104104
}
105105

106106
func TestExecute(t *testing.T) {
107-
sessionMock := &Session{}
107+
sessionMock := &Session{Out: os.Stdout}
108108
sessionMock.DataChannel = mockDataChannel
109109
SetupMockActions()
110110
mockDataChannel.On("Open", mock.Anything).Return(nil)
@@ -128,7 +128,7 @@ func TestExecute(t *testing.T) {
128128
}
129129

130130
func TestExecuteAndStreamMessageResendTimesOut(t *testing.T) {
131-
sessionMock := &Session{}
131+
sessionMock := &Session{Out: os.Stdout}
132132
sessionMock.DataChannel = mockDataChannel
133133
SetupMockActions()
134134
mockDataChannel.On("Open", mock.Anything).Return(nil)

src/sessionmanagerplugin/session/sessionhandler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func (s *Session) ResumeSessionHandler(log log.T) (err error) {
129129
return
130130
} else if s.TokenValue == "" {
131131
log.Debugf("Session: %s timed out", s.SessionId)
132-
fmt.Fprintf(os.Stdout, "Session: %s timed out.\n", s.SessionId)
132+
fmt.Fprintf(s.Out, "Session: %s timed out.\n", s.SessionId)
133133
os.Exit(0)
134134
}
135135
s.DataChannel.GetWsChannel().SetChannelToken(s.TokenValue)

src/sessionmanagerplugin/session/sessionhandler_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ package session
1616

1717
import (
1818
"fmt"
19+
"os"
1920
"testing"
2021

2122
wsChannelMock "github.com/aws/session-manager-plugin/src/communicator/mocks"
2223
"github.com/aws/session-manager-plugin/src/config"
2324
"github.com/aws/session-manager-plugin/src/datachannel"
2425
dataChannelMock "github.com/aws/session-manager-plugin/src/datachannel/mocks"
2526
"github.com/aws/session-manager-plugin/src/message"
27+
"github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/sessionutil"
2628
"github.com/stretchr/testify/mock"
2729

2830
"github.com/stretchr/testify/assert"
@@ -38,7 +40,7 @@ func TestOpenDataChannel(t *testing.T) {
3840
mockDataChannel = &dataChannelMock.IDataChannel{}
3941
mockWsChannel = &wsChannelMock.IWebSocketChannel{}
4042

41-
sessionMock := &Session{}
43+
sessionMock := &Session{Out: os.Stdout}
4244
sessionMock.DataChannel = mockDataChannel
4345
SetupMockActions()
4446
mockDataChannel.On("Open", mock.Anything).Return(nil)
@@ -51,7 +53,7 @@ func TestOpenDataChannelWithError(t *testing.T) {
5153
mockDataChannel = &dataChannelMock.IDataChannel{}
5254
mockWsChannel = &wsChannelMock.IWebSocketChannel{}
5355

54-
sessionMock := &Session{}
56+
sessionMock := &Session{Out: os.Stdout}
5557
sessionMock.DataChannel = mockDataChannel
5658
SetupMockActions()
5759

@@ -69,10 +71,12 @@ func TestProcessFirstMessageOutputMessageFirst(t *testing.T) {
6971
Payload: []byte("testing"),
7072
}
7173

72-
dataChannel := &datachannel.DataChannel{}
74+
dataChannel := &datachannel.DataChannel{Out: os.Stdout}
7375
dataChannel.Initialize(logger, clientId, sessionId, instanceId, false)
7476
session := Session{
77+
Out: os.Stdout,
7578
DataChannel: dataChannel,
79+
DisplayMode: sessionutil.NewDisplayMode(logger, os.Stdout),
7680
}
7781

7882
session.ProcessFirstMessage(logger, outputMessage)

0 commit comments

Comments
 (0)