Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ require (
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/gosimple/slug v1.15.0 // indirect
github.com/gosimple/unidecode v1.0.1 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBY
github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gosimple/slug v1.15.0 h1:wRZHsRrRcs6b0XnxMUBM6WK1U1Vg5B0R7VkIf1Xzobo=
github.com/gosimple/slug v1.15.0/go.mod h1:UiRaFH+GEilHstLUmcBgWcI42viBN7mAb818JrYOeFQ=
github.com/gosimple/unidecode v1.0.1 h1:hZzFTMMqSswvf0LBJZCZgThIZrpDHFXux9KeGmn6T/o=
Expand Down
3 changes: 2 additions & 1 deletion packages/api/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,8 @@ type PAMAccessApprovalRequestResponse struct {
}

type PAMSessionCredentialsResponse struct {
Credentials PAMSessionCredentials `json:"credentials"`
Credentials PAMSessionCredentials `json:"credentials"`
SharedSecret string `json:"sharedSecret,omitempty"`
}

type PAMSessionCredentials struct {
Expand Down
1 change: 1 addition & 0 deletions packages/cmd/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ var relayStartCmd = &cobra.Command{
RelayName: relayName,
SSHPort: "2222",
TLSPort: "8443",
WSPort: "8444",
Host: host,
Type: instanceType,
})
Expand Down
66 changes: 52 additions & 14 deletions packages/gateway-v2/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net"
"strconv"
"strings"
Expand Down Expand Up @@ -550,17 +551,39 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) {

go ssh.DiscardRequests(requests)

// Peek first byte to detect connection type:
// 0x16 = TLS ClientHello → mTLS flow (CLI clients)
// 0x00 = Web ECDH magic byte → web flow (browser clients)
firstByte := make([]byte, 1)
if _, err := io.ReadFull(channel, firstByte); err != nil {
log.Info().Msgf("Failed to read first byte: %v", err)
return
}

switch firstByte[0] {
case 0x16:
g.handleMTLSConnection(channel, firstByte)
case 0x00:
// ECDH magic byte - handle web proxy connection
virtualConn := &prefixedVirtualConnection{channel: channel}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Creating prefixedVirtualConnection without passing the firstByte that was read. This means the magic byte (0x00) won't be available to the PAM web proxy handler, potentially causing protocol issues.

Suggested change
virtualConn := &prefixedVirtualConnection{channel: channel}
virtualConn := newPrefixedVirtualConnection(firstByte, channel)

if err := pam.HandlePAMWebProxy(g.ctx, virtualConn, g.httpClient, g.pamCredentialsManager, g.pamSessionUploader); err != nil {
log.Error().Err(err).Msg("PAM web proxy handler ended with error")
}
default:
log.Warn().Msgf("Unknown protocol byte: 0x%02x, closing channel", firstByte[0])
}
}

func (g *Gateway) handleMTLSConnection(channel ssh.Channel, firstByte []byte) {
// Create mTLS server configuration
tlsConfig := g.tlsConfig
if tlsConfig == nil {
log.Info().Msgf("TLS config not initialized, cannot create mTLS server")
return
}

// Create a virtual connection that pipes data between SSH channel and TLS
virtualConn := &virtualConnection{
channel: channel,
}
// Create a virtual connection that prepends the peeked byte
virtualConn := newPrefixedVirtualConnection(firstByte, channel)

// Wrap the virtual connection with TLS
tlsConn := tls.Server(virtualConn, tlsConfig)
Expand Down Expand Up @@ -776,40 +799,55 @@ func (g *Gateway) parseDetailsFromCertificate(tlsConn *tls.Conn, config *Forward
return nil
}

// virtualConnection implements net.Conn to bridge SSH channel and TLS
type virtualConnection struct {
// prefixedVirtualConnection implements net.Conn, prepending buffered bytes before reading from the channel
type prefixedVirtualConnection struct {
channel ssh.Channel
prefix []byte
offset int
}

func (vc *virtualConnection) Read(b []byte) (n int, err error) {
func newPrefixedVirtualConnection(prefix []byte, channel ssh.Channel) *prefixedVirtualConnection {
return &prefixedVirtualConnection{
channel: channel,
prefix: prefix,
offset: 0,
}
}

func (vc *prefixedVirtualConnection) Read(b []byte) (int, error) {
if vc.offset < len(vc.prefix) {
n := copy(b, vc.prefix[vc.offset:])
vc.offset += n
return n, nil
}
return vc.channel.Read(b)
}

func (vc *virtualConnection) Write(b []byte) (n int, err error) {
func (vc *prefixedVirtualConnection) Write(b []byte) (int, error) {
return vc.channel.Write(b)
}

func (vc *virtualConnection) Close() error {
func (vc *prefixedVirtualConnection) Close() error {
return vc.channel.Close()
}

func (vc *virtualConnection) LocalAddr() net.Addr {
func (vc *prefixedVirtualConnection) LocalAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}

func (vc *virtualConnection) RemoteAddr() net.Addr {
func (vc *prefixedVirtualConnection) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}

func (vc *virtualConnection) SetDeadline(t time.Time) error {
func (vc *prefixedVirtualConnection) SetDeadline(t time.Time) error {
return nil
}

func (vc *virtualConnection) SetReadDeadline(t time.Time) error {
func (vc *prefixedVirtualConnection) SetReadDeadline(t time.Time) error {
return nil
}

func (vc *virtualConnection) SetWriteDeadline(t time.Time) error {
func (vc *prefixedVirtualConnection) SetWriteDeadline(t time.Time) error {
return nil
}

Expand Down
140 changes: 140 additions & 0 deletions packages/pam/encrypted_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package pam

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"net"
"sync"
"time"
)

// EncryptedConn wraps a net.Conn with AES-256-GCM encryption.
// Frame format: [4-byte big-endian total-frame-length][12-byte random nonce][ciphertext + 16-byte GCM auth tag]
type EncryptedConn struct {
inner net.Conn
gcm cipher.AEAD

readMu sync.Mutex
readBuf []byte

writeMu sync.Mutex
}

func NewEncryptedConn(inner net.Conn, aesKey []byte) (*EncryptedConn, error) {
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
return &EncryptedConn{
inner: inner,
gcm: gcm,
}, nil
}

func (ec *EncryptedConn) Read(b []byte) (int, error) {
ec.readMu.Lock()
defer ec.readMu.Unlock()

// Serve from buffer if available
if len(ec.readBuf) > 0 {
n := copy(b, ec.readBuf)
ec.readBuf = ec.readBuf[n:]
return n, nil
}

// Read frame length
lengthBuf := make([]byte, 4)
if _, err := io.ReadFull(ec.inner, lengthBuf); err != nil {
return 0, err
}
frameLen := binary.BigEndian.Uint32(lengthBuf)
if frameLen > 1<<24 {
return 0, fmt.Errorf("encrypted frame too large: %d bytes", frameLen)
}

// Read the full frame (nonce + ciphertext)
frame := make([]byte, frameLen)
if _, err := io.ReadFull(ec.inner, frame); err != nil {
return 0, fmt.Errorf("failed to read encrypted frame: %w", err)
}

nonceSize := ec.gcm.NonceSize()
if int(frameLen) < nonceSize {
return 0, fmt.Errorf("frame too short for nonce")
}

nonce := frame[:nonceSize]
ciphertext := frame[nonceSize:]

plaintext, err := ec.gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return 0, fmt.Errorf("GCM decryption failed: %w", err)
}

n := copy(b, plaintext)
if n < len(plaintext) {
ec.readBuf = plaintext[n:]
}
return n, nil
}

func (ec *EncryptedConn) Write(b []byte) (int, error) {
ec.writeMu.Lock()
defer ec.writeMu.Unlock()

nonce := make([]byte, ec.gcm.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return 0, fmt.Errorf("failed to generate nonce: %w", err)
}

ciphertext := ec.gcm.Seal(nil, nonce, b, nil)

// Frame = nonce + ciphertext
frameLen := len(nonce) + len(ciphertext)
lengthBuf := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBuf, uint32(frameLen))

if _, err := ec.inner.Write(lengthBuf); err != nil {
return 0, fmt.Errorf("failed to write frame length: %w", err)
}
if _, err := ec.inner.Write(nonce); err != nil {
return 0, fmt.Errorf("failed to write nonce: %w", err)
}
if _, err := ec.inner.Write(ciphertext); err != nil {
return 0, fmt.Errorf("failed to write ciphertext: %w", err)
}

return len(b), nil
}

func (ec *EncryptedConn) Close() error {
return ec.inner.Close()
}

func (ec *EncryptedConn) LocalAddr() net.Addr {
return ec.inner.LocalAddr()
}

func (ec *EncryptedConn) RemoteAddr() net.Addr {
return ec.inner.RemoteAddr()
}

func (ec *EncryptedConn) SetDeadline(t time.Time) error {
return ec.inner.SetDeadline(t)
}

func (ec *EncryptedConn) SetReadDeadline(t time.Time) error {
return ec.inner.SetReadDeadline(t)
}

func (ec *EncryptedConn) SetWriteDeadline(t time.Time) error {
return ec.inner.SetWriteDeadline(t)
}
Loading