Skip to content

Commit 298f555

Browse files
committed
Add PAM web proxy
1 parent 9ef9396 commit 298f555

File tree

10 files changed

+804
-15
lines changed

10 files changed

+804
-15
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ require (
111111
github.com/google/s2a-go v0.1.7 // indirect
112112
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
113113
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
114+
github.com/gorilla/websocket v1.5.3 // indirect
114115
github.com/gosimple/slug v1.15.0 // indirect
115116
github.com/gosimple/unidecode v1.0.1 // indirect
116117
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBY
318318
github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E=
319319
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
320320
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
321+
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
322+
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
321323
github.com/gosimple/slug v1.15.0 h1:wRZHsRrRcs6b0XnxMUBM6WK1U1Vg5B0R7VkIf1Xzobo=
322324
github.com/gosimple/slug v1.15.0/go.mod h1:UiRaFH+GEilHstLUmcBgWcI42viBN7mAb818JrYOeFQ=
323325
github.com/gosimple/unidecode v1.0.1 h1:hZzFTMMqSswvf0LBJZCZgThIZrpDHFXux9KeGmn6T/o=

packages/api/model.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,8 @@ type PAMAccessApprovalRequestResponse struct {
825825
}
826826

827827
type PAMSessionCredentialsResponse struct {
828-
Credentials PAMSessionCredentials `json:"credentials"`
828+
Credentials PAMSessionCredentials `json:"credentials"`
829+
SharedSecret string `json:"sharedSecret,omitempty"`
829830
}
830831

831832
type PAMSessionCredentials struct {

packages/cmd/relay.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ var relayStartCmd = &cobra.Command{
5353
RelayName: relayName,
5454
SSHPort: "2222",
5555
TLSPort: "8443",
56+
WSPort: "8444",
5657
Host: host,
5758
Type: instanceType,
5859
})

packages/gateway-v2/gateway.go

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"encoding/json"
1111
"encoding/pem"
1212
"fmt"
13+
"io"
1314
"net"
1415
"strconv"
1516
"strings"
@@ -550,17 +551,39 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) {
550551

551552
go ssh.DiscardRequests(requests)
552553

554+
// Peek first byte to detect connection type:
555+
// 0x16 = TLS ClientHello → mTLS flow (CLI clients)
556+
// 0x00 = Web ECDH magic byte → web flow (browser clients)
557+
firstByte := make([]byte, 1)
558+
if _, err := io.ReadFull(channel, firstByte); err != nil {
559+
log.Info().Msgf("Failed to read first byte: %v", err)
560+
return
561+
}
562+
563+
switch firstByte[0] {
564+
case 0x16:
565+
g.handleMTLSConnection(channel, firstByte)
566+
case 0x00:
567+
// ECDH magic byte - handle web proxy connection
568+
virtualConn := &prefixedVirtualConnection{channel: channel}
569+
if err := pam.HandlePAMWebProxy(g.ctx, virtualConn, g.httpClient, g.pamCredentialsManager, g.pamSessionUploader); err != nil {
570+
log.Error().Err(err).Msg("PAM web proxy handler ended with error")
571+
}
572+
default:
573+
log.Warn().Msgf("Unknown protocol byte: 0x%02x, closing channel", firstByte[0])
574+
}
575+
}
576+
577+
func (g *Gateway) handleMTLSConnection(channel ssh.Channel, firstByte []byte) {
553578
// Create mTLS server configuration
554579
tlsConfig := g.tlsConfig
555580
if tlsConfig == nil {
556581
log.Info().Msgf("TLS config not initialized, cannot create mTLS server")
557582
return
558583
}
559584

560-
// Create a virtual connection that pipes data between SSH channel and TLS
561-
virtualConn := &virtualConnection{
562-
channel: channel,
563-
}
585+
// Create a virtual connection that prepends the peeked byte
586+
virtualConn := newPrefixedVirtualConnection(firstByte, channel)
564587

565588
// Wrap the virtual connection with TLS
566589
tlsConn := tls.Server(virtualConn, tlsConfig)
@@ -776,40 +799,55 @@ func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *Forward
776799
return nil
777800
}
778801

779-
// virtualConnection implements net.Conn to bridge SSH channel and TLS
780-
type virtualConnection struct {
802+
// prefixedVirtualConnection implements net.Conn, prepending buffered bytes before reading from the channel
803+
type prefixedVirtualConnection struct {
781804
channel ssh.Channel
805+
prefix []byte
806+
offset int
782807
}
783808

784-
func (vc *virtualConnection) Read(b []byte) (n int, err error) {
809+
func newPrefixedVirtualConnection(prefix []byte, channel ssh.Channel) *prefixedVirtualConnection {
810+
return &prefixedVirtualConnection{
811+
channel: channel,
812+
prefix: prefix,
813+
offset: 0,
814+
}
815+
}
816+
817+
func (vc *prefixedVirtualConnection) Read(b []byte) (int, error) {
818+
if vc.offset < len(vc.prefix) {
819+
n := copy(b, vc.prefix[vc.offset:])
820+
vc.offset += n
821+
return n, nil
822+
}
785823
return vc.channel.Read(b)
786824
}
787825

788-
func (vc *virtualConnection) Write(b []byte) (n int, err error) {
826+
func (vc *prefixedVirtualConnection) Write(b []byte) (int, error) {
789827
return vc.channel.Write(b)
790828
}
791829

792-
func (vc *virtualConnection) Close() error {
830+
func (vc *prefixedVirtualConnection) Close() error {
793831
return vc.channel.Close()
794832
}
795833

796-
func (vc *virtualConnection) LocalAddr() net.Addr {
834+
func (vc *prefixedVirtualConnection) LocalAddr() net.Addr {
797835
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
798836
}
799837

800-
func (vc *virtualConnection) RemoteAddr() net.Addr {
838+
func (vc *prefixedVirtualConnection) RemoteAddr() net.Addr {
801839
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
802840
}
803841

804-
func (vc *virtualConnection) SetDeadline(t time.Time) error {
842+
func (vc *prefixedVirtualConnection) SetDeadline(t time.Time) error {
805843
return nil
806844
}
807845

808-
func (vc *virtualConnection) SetReadDeadline(t time.Time) error {
846+
func (vc *prefixedVirtualConnection) SetReadDeadline(t time.Time) error {
809847
return nil
810848
}
811849

812-
func (vc *virtualConnection) SetWriteDeadline(t time.Time) error {
850+
func (vc *prefixedVirtualConnection) SetWriteDeadline(t time.Time) error {
813851
return nil
814852
}
815853

packages/pam/encrypted_conn.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package pam
2+
3+
import (
4+
"crypto/aes"
5+
"crypto/cipher"
6+
"crypto/rand"
7+
"encoding/binary"
8+
"fmt"
9+
"io"
10+
"net"
11+
"sync"
12+
"time"
13+
)
14+
15+
// EncryptedConn wraps a net.Conn with AES-256-GCM encryption.
16+
// Frame format: [4-byte big-endian total-frame-length][12-byte random nonce][ciphertext + 16-byte GCM auth tag]
17+
type EncryptedConn struct {
18+
inner net.Conn
19+
gcm cipher.AEAD
20+
21+
readMu sync.Mutex
22+
readBuf []byte
23+
24+
writeMu sync.Mutex
25+
}
26+
27+
func NewEncryptedConn(inner net.Conn, aesKey []byte) (*EncryptedConn, error) {
28+
block, err := aes.NewCipher(aesKey)
29+
if err != nil {
30+
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
31+
}
32+
gcm, err := cipher.NewGCM(block)
33+
if err != nil {
34+
return nil, fmt.Errorf("failed to create GCM: %w", err)
35+
}
36+
return &EncryptedConn{
37+
inner: inner,
38+
gcm: gcm,
39+
}, nil
40+
}
41+
42+
func (ec *EncryptedConn) Read(b []byte) (int, error) {
43+
ec.readMu.Lock()
44+
defer ec.readMu.Unlock()
45+
46+
// Serve from buffer if available
47+
if len(ec.readBuf) > 0 {
48+
n := copy(b, ec.readBuf)
49+
ec.readBuf = ec.readBuf[n:]
50+
return n, nil
51+
}
52+
53+
// Read frame length
54+
lengthBuf := make([]byte, 4)
55+
if _, err := io.ReadFull(ec.inner, lengthBuf); err != nil {
56+
return 0, err
57+
}
58+
frameLen := binary.BigEndian.Uint32(lengthBuf)
59+
if frameLen > 1<<24 {
60+
return 0, fmt.Errorf("encrypted frame too large: %d bytes", frameLen)
61+
}
62+
63+
// Read the full frame (nonce + ciphertext)
64+
frame := make([]byte, frameLen)
65+
if _, err := io.ReadFull(ec.inner, frame); err != nil {
66+
return 0, fmt.Errorf("failed to read encrypted frame: %w", err)
67+
}
68+
69+
nonceSize := ec.gcm.NonceSize()
70+
if int(frameLen) < nonceSize {
71+
return 0, fmt.Errorf("frame too short for nonce")
72+
}
73+
74+
nonce := frame[:nonceSize]
75+
ciphertext := frame[nonceSize:]
76+
77+
plaintext, err := ec.gcm.Open(nil, nonce, ciphertext, nil)
78+
if err != nil {
79+
return 0, fmt.Errorf("GCM decryption failed: %w", err)
80+
}
81+
82+
n := copy(b, plaintext)
83+
if n < len(plaintext) {
84+
ec.readBuf = plaintext[n:]
85+
}
86+
return n, nil
87+
}
88+
89+
func (ec *EncryptedConn) Write(b []byte) (int, error) {
90+
ec.writeMu.Lock()
91+
defer ec.writeMu.Unlock()
92+
93+
nonce := make([]byte, ec.gcm.NonceSize())
94+
if _, err := rand.Read(nonce); err != nil {
95+
return 0, fmt.Errorf("failed to generate nonce: %w", err)
96+
}
97+
98+
ciphertext := ec.gcm.Seal(nil, nonce, b, nil)
99+
100+
// Frame = nonce + ciphertext
101+
frameLen := len(nonce) + len(ciphertext)
102+
lengthBuf := make([]byte, 4)
103+
binary.BigEndian.PutUint32(lengthBuf, uint32(frameLen))
104+
105+
if _, err := ec.inner.Write(lengthBuf); err != nil {
106+
return 0, fmt.Errorf("failed to write frame length: %w", err)
107+
}
108+
if _, err := ec.inner.Write(nonce); err != nil {
109+
return 0, fmt.Errorf("failed to write nonce: %w", err)
110+
}
111+
if _, err := ec.inner.Write(ciphertext); err != nil {
112+
return 0, fmt.Errorf("failed to write ciphertext: %w", err)
113+
}
114+
115+
return len(b), nil
116+
}
117+
118+
func (ec *EncryptedConn) Close() error {
119+
return ec.inner.Close()
120+
}
121+
122+
func (ec *EncryptedConn) LocalAddr() net.Addr {
123+
return ec.inner.LocalAddr()
124+
}
125+
126+
func (ec *EncryptedConn) RemoteAddr() net.Addr {
127+
return ec.inner.RemoteAddr()
128+
}
129+
130+
func (ec *EncryptedConn) SetDeadline(t time.Time) error {
131+
return ec.inner.SetDeadline(t)
132+
}
133+
134+
func (ec *EncryptedConn) SetReadDeadline(t time.Time) error {
135+
return ec.inner.SetReadDeadline(t)
136+
}
137+
138+
func (ec *EncryptedConn) SetWriteDeadline(t time.Time) error {
139+
return ec.inner.SetWriteDeadline(t)
140+
}

0 commit comments

Comments
 (0)