Skip to content

Commit 0d7ef23

Browse files
author
Igor Drozdov
committed
Merge branch 'sshd-forwarded-for' into 'main'
Pass original IP from PROXY requests to internal API calls See merge request gitlab-org/gitlab-shell!665
2 parents 01f4e02 + 9b60ce4 commit 0d7ef23

File tree

6 files changed

+76
-20
lines changed

6 files changed

+76
-20
lines changed

client/client_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ func TestClients(t *testing.T) {
7676
testErrorMessage(t, client)
7777
testAuthenticationHeader(t, client)
7878
testJWTAuthenticationHeader(t, client)
79+
testXForwardedForHeader(t, client)
7980
})
8081
}
8182
}
@@ -221,6 +222,21 @@ func testJWTAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
221222
})
222223
}
223224

225+
func testXForwardedForHeader(t *testing.T, client *GitlabNetClient) {
226+
t.Run("X-Forwarded-For Header inserted if original address in context", func(t *testing.T) {
227+
ctx := context.WithValue(context.Background(), OriginalRemoteIPContextKey{}, "196.7.0.238")
228+
response, err := client.Get(ctx, "/x_forwarded_for")
229+
require.NoError(t, err)
230+
require.NotNil(t, response)
231+
232+
defer response.Body.Close()
233+
234+
responseBody, err := io.ReadAll(response.Body)
235+
require.NoError(t, err)
236+
require.Equal(t, "196.7.0.238", string(responseBody))
237+
})
238+
}
239+
224240
func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestRequestHandler {
225241
requests := []testserver.TestRequestHandler{
226242
{
@@ -256,6 +272,12 @@ func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestReques
256272
fmt.Fprint(w, r.Header.Get(apiSecretHeaderName))
257273
},
258274
},
275+
{
276+
Path: "/api/v4/internal/x_forwarded_for",
277+
Handler: func(w http.ResponseWriter, r *http.Request) {
278+
fmt.Fprint(w, r.Header.Get("X-Forwarded-For"))
279+
},
280+
},
259281
{
260282
Path: "/api/v4/internal/error",
261283
Handler: func(w http.ResponseWriter, r *http.Request) {

client/gitlabnet.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ type ApiError struct {
4141
Msg string
4242
}
4343

44+
// To use as the key in a Context to set an X-Forwarded-For header in a request
45+
type OriginalRemoteIPContextKey struct{}
46+
4447
func (e *ApiError) Error() string {
4548
return e.Msg
4649
}
@@ -150,6 +153,11 @@ func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, da
150153
}
151154
request.Header.Set(apiSecretHeaderName, tokenString)
152155

156+
originalRemoteIP, ok := ctx.Value(OriginalRemoteIPContextKey{}).(string)
157+
if ok {
158+
request.Header.Add("X-Forwarded-For", originalRemoteIP)
159+
}
160+
153161
request.Header.Add("Content-Type", "application/json")
154162
request.Header.Add("User-Agent", c.userAgent)
155163
request.Close = true

internal/gitlabnet/accessverifier/client.go

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package accessverifier
33
import (
44
"context"
55
"fmt"
6-
"net"
76
"net/http"
87

98
pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
@@ -86,7 +85,7 @@ func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action com
8685
request.KeyId = args.GitlabKeyId
8786
}
8887

89-
request.CheckIp = parseIP(args.Env.RemoteAddr)
88+
request.CheckIp = gitlabnet.ParseIP(args.Env.RemoteAddr)
9089

9190
response, err := c.client.Post(ctx, "/allowed", request)
9291
if err != nil {
@@ -117,18 +116,3 @@ func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) {
117116
func (r *Response) IsCustomAction() bool {
118117
return r.StatusCode == http.StatusMultipleChoices
119118
}
120-
121-
func parseIP(remoteAddr string) string {
122-
// The remoteAddr field can be filled by:
123-
// 1. An IP address via the SSH_CONNECTION environment variable
124-
// 2. A host:port combination via the PROXY protocol
125-
ip, _, err := net.SplitHostPort(remoteAddr)
126-
127-
// If we don't have a port or can't parse this address for some reason,
128-
// just return the original string.
129-
if err != nil {
130-
return remoteAddr
131-
}
132-
133-
return ip
134-
}

internal/gitlabnet/client.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package gitlabnet
33
import (
44
"encoding/json"
55
"fmt"
6+
"net"
67
"net/http"
78

89
"gitlab.com/gitlab-org/gitlab-shell/client"
@@ -34,3 +35,18 @@ func ParseJSON(hr *http.Response, response interface{}) error {
3435

3536
return nil
3637
}
38+
39+
func ParseIP(remoteAddr string) string {
40+
// The remoteAddr field can be filled by:
41+
// 1. An IP address via the SSH_CONNECTION environment variable
42+
// 2. A host:port combination via the PROXY protocol
43+
ip, _, err := net.SplitHostPort(remoteAddr)
44+
45+
// If we don't have a port or can't parse this address for some reason,
46+
// just return the original string.
47+
if err != nil {
48+
return remoteAddr
49+
}
50+
51+
return ip
52+
}

internal/sshd/sshd.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import (
1212
"github.com/pires/go-proxyproto"
1313
"golang.org/x/crypto/ssh"
1414

15+
"gitlab.com/gitlab-org/gitlab-shell/client"
1516
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
17+
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
1618
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
1719

1820
"gitlab.com/gitlab-org/labkit/correlation"
@@ -145,13 +147,26 @@ func (s *Server) getStatus() status {
145147
return s.status
146148
}
147149

150+
func contextWithValues(parent context.Context, nconn net.Conn) context.Context {
151+
ctx := correlation.ContextWithCorrelation(parent, correlation.SafeRandomID())
152+
153+
// If we're dealing with a PROXY connection, register the original requester's IP
154+
mconn, ok := nconn.(*proxyproto.Conn)
155+
if ok {
156+
ip := gitlabnet.ParseIP(mconn.Raw().RemoteAddr().String())
157+
ctx = context.WithValue(ctx, client.OriginalRemoteIPContextKey{}, ip)
158+
}
159+
160+
return ctx
161+
}
162+
148163
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
149164
defer s.wg.Done()
150165

151166
metrics.SshdConnectionsInFlight.Inc()
152167
defer metrics.SshdConnectionsInFlight.Dec()
153168

154-
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
169+
ctx, cancel := context.WithCancel(contextWithValues(ctx, nconn))
155170
defer cancel()
156171
go func() {
157172
<-ctx.Done()

internal/sshd/sshd_test.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ const (
2727

2828
var (
2929
correlationId = ""
30+
xForwardedFor = ""
3031
)
3132

3233
func TestListenAndServe(t *testing.T) {
@@ -63,6 +64,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
6364
},
6465
DestinationAddr: target,
6566
}
67+
xForwardedFor = "127.0.0.1"
68+
defer func() {
69+
xForwardedFor = "" // Cleanup for other test cases
70+
}()
6671

6772
testCases := []struct {
6873
desc string
@@ -132,16 +137,20 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
132137
require.NoError(t, err)
133138
}
134139

135-
sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
140+
sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
136141
if sshConn != nil {
137-
sshConn.Close()
142+
defer sshConn.Close()
138143
}
139144

140145
if tc.isRejected {
141146
require.Error(t, err, "Expected plain SSH request to be failed")
142147
require.Regexp(t, "ssh: handshake failed", err.Error())
143148
} else {
144149
require.NoError(t, err)
150+
client := ssh.NewClient(sshConn, sshChans, sshRequs)
151+
defer client.Close()
152+
153+
holdSession(t, client)
145154
}
146155
})
147156
}
@@ -306,13 +315,15 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex
306315
correlationId = r.Header.Get("X-Request-Id")
307316

308317
require.NotEmpty(t, correlationId)
318+
require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
309319

310320
fmt.Fprint(w, `{"id": 1000, "key": "key"}`)
311321
},
312322
}, {
313323
Path: "/api/v4/internal/discover",
314324
Handler: func(w http.ResponseWriter, r *http.Request) {
315325
require.Equal(t, correlationId, r.Header.Get("X-Request-Id"))
326+
require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
316327

317328
fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`)
318329
},

0 commit comments

Comments
 (0)