Skip to content

Commit d24136a

Browse files
committed
implement websockets natively
1 parent f9a2246 commit d24136a

File tree

14 files changed

+360
-839
lines changed

14 files changed

+360
-839
lines changed

.pipelines/scripts/e2e_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ rm -f "$temp_file"
9494
# gotestsum configure to only show logs for failed tests, json file for detailed logs
9595
# Run the tests! Yey!
9696
test_exit_code=0
97-
./bin/gotestsum --format testdox --junitfile "${BUILD_SRC_DIR}/e2e/report.xml" --jsonfile "${BUILD_SRC_DIR}/e2e/test-log.json" -- -parallel 100 -timeout 90m || test_exit_code=$?
97+
./bin/gotestsum --format standard-verbose --junitfile "${BUILD_SRC_DIR}/e2e/report.xml" --jsonfile "${BUILD_SRC_DIR}/e2e/test-log.json" -- -parallel 150 -timeout 30m || test_exit_code=$?
9898

9999
# Upload test results as Azure DevOps artifacts
100100
echo "##vso[artifact.upload containerfolder=test-results;artifactname=e2e-test-log]${BUILD_SRC_DIR}/e2e/test-log.json"

e2e/bastionssh.go

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
package e2e
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net"
10+
"net/http"
11+
"net/url"
12+
"strings"
13+
"time"
14+
15+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
16+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
17+
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
18+
"github.com/coder/websocket"
19+
"golang.org/x/crypto/ssh"
20+
)
21+
22+
type Bastion struct {
23+
credential *azidentity.AzureCLICredential
24+
subscriptionID, resourceGroupName, dnsName string
25+
httpClient *http.Client
26+
httpTransport *http.Transport
27+
}
28+
29+
func NewBastion(credential *azidentity.AzureCLICredential, subscriptionID, resourceGroupName, dnsName string) *Bastion {
30+
transport := &http.Transport{
31+
MaxIdleConns: 100,
32+
MaxIdleConnsPerHost: 100,
33+
IdleConnTimeout: 30 * time.Second,
34+
}
35+
36+
return &Bastion{
37+
credential: credential,
38+
subscriptionID: subscriptionID,
39+
resourceGroupName: resourceGroupName,
40+
dnsName: dnsName,
41+
httpTransport: transport,
42+
httpClient: &http.Client{
43+
Transport: transport,
44+
Timeout: 30 * time.Second,
45+
},
46+
}
47+
}
48+
49+
type tunnelSession struct {
50+
bastion *Bastion
51+
ws *websocket.Conn
52+
session *sessionToken
53+
}
54+
55+
func (b *Bastion) NewTunnelSession(targetHost string, port uint16) (*tunnelSession, error) {
56+
session, err := b.newSessionToken(targetHost, port)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
wsUrl := fmt.Sprintf("wss://%v/webtunnelv2/%v?X-Node-Id=%v", b.dnsName, session.WebsocketToken, session.NodeID)
62+
63+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
64+
ws, _, err := websocket.Dial(ctx, wsUrl, &websocket.DialOptions{
65+
CompressionMode: websocket.CompressionDisabled,
66+
})
67+
cancel()
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
ws.SetReadLimit(32 * 1024 * 1024)
73+
74+
return &tunnelSession{
75+
bastion: b,
76+
ws: ws,
77+
session: session,
78+
}, nil
79+
}
80+
81+
// sessionToken is the JSON body decoded from the Bastion token API response
// when a tunnel session is created.
type sessionToken struct {
	// AuthToken identifies the session in the token API; it forms the last
	// path segment of the DELETE request issued on Close.
	AuthToken string `json:"authToken"`
	// Username and DataSource are decoded but not read by the code visible here.
	Username   string `json:"username"`
	DataSource string `json:"dataSource"`
	// NodeID is echoed back to the Bastion as "X-Node-Id" (header on the token
	// API, query parameter on the websocket URL).
	NodeID string `json:"nodeId"`
	// AvailableDataSources is decoded but not read by the code visible here.
	AvailableDataSources []string `json:"availableDataSources"`
	// WebsocketToken is the path segment of the websocket tunnel URL.
	WebsocketToken string `json:"websocketToken"`
}
89+
90+
func (t *tunnelSession) Close() error {
91+
_ = t.ws.Close(websocket.StatusNormalClosure, "")
92+
93+
req, err := http.NewRequest("DELETE", fmt.Sprintf("https://%v/api/tokens/%v", t.bastion.dnsName, t.session.AuthToken), nil)
94+
if err != nil {
95+
return err
96+
}
97+
98+
req.Header.Add("X-Node-Id", t.session.NodeID)
99+
100+
resp, err := t.bastion.httpClient.Do(req)
101+
if err != nil {
102+
return err
103+
}
104+
defer resp.Body.Close()
105+
106+
if resp.StatusCode == 404 {
107+
return nil
108+
}
109+
110+
if resp.StatusCode != 204 {
111+
return fmt.Errorf("unexpected status code: %v", resp.StatusCode)
112+
}
113+
114+
if t.bastion.httpTransport != nil {
115+
t.bastion.httpTransport.CloseIdleConnections()
116+
}
117+
118+
return nil
119+
}
120+
121+
func (b *Bastion) newSessionToken(targetHost string, port uint16) (*sessionToken, error) {
122+
123+
token, err := b.credential.GetToken(context.Background(), policy.TokenRequestOptions{
124+
Scopes: []string{fmt.Sprintf("%s/.default", cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint)},
125+
})
126+
127+
if err != nil {
128+
return nil, err
129+
}
130+
131+
apiUrl := fmt.Sprintf("https://%v/api/tokens", b.dnsName)
132+
133+
// target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
134+
data := url.Values{}
135+
data.Set("resourceId", fmt.Sprintf("/subscriptions/%v/resourceGroups/%v/providers/Microsoft.Network/bh-hostConnect/%v", b.subscriptionID, b.resourceGroupName, targetHost))
136+
data.Set("protocol", "tcptunnel")
137+
data.Set("workloadHostPort", fmt.Sprintf("%v", port))
138+
data.Set("aztoken", token.Token)
139+
data.Set("hostname", targetHost)
140+
141+
req, err := http.NewRequest("POST", apiUrl, strings.NewReader(data.Encode()))
142+
if err != nil {
143+
return nil, err
144+
}
145+
146+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
147+
resp, err := b.httpClient.Do(req) // TODO client settings
148+
if err != nil {
149+
return nil, err
150+
}
151+
152+
defer resp.Body.Close()
153+
154+
if resp.StatusCode != 200 {
155+
return nil, fmt.Errorf("error creating tunnel: %v", resp.Status)
156+
}
157+
158+
var response sessionToken
159+
160+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
161+
return nil, err
162+
}
163+
164+
return &response, nil
165+
}
166+
167+
func (t *tunnelSession) Pipe(conn net.Conn) error {
168+
169+
defer t.Close()
170+
defer conn.Close()
171+
172+
done := make(chan error, 2)
173+
174+
go func() {
175+
for {
176+
_, data, err := t.ws.Read(context.Background())
177+
if err != nil {
178+
done <- err
179+
return
180+
}
181+
182+
if _, err := io.Copy(conn, bytes.NewReader(data)); err != nil {
183+
done <- err
184+
return
185+
}
186+
}
187+
}()
188+
189+
go func() {
190+
buf := make([]byte, 4096) // 4096 is copy from az cli bastion code
191+
192+
for {
193+
n, err := conn.Read(buf)
194+
if err != nil {
195+
done <- err
196+
return
197+
}
198+
199+
if err := t.ws.Write(context.Background(), websocket.MessageBinary, buf[:n]); err != nil {
200+
done <- err
201+
return
202+
}
203+
}
204+
}()
205+
206+
return <-done
207+
}
208+
209+
func sshClientConfig(user string, privateKey []byte) (*ssh.ClientConfig, error) {
210+
signer, err := ssh.ParsePrivateKey(privateKey)
211+
if err != nil {
212+
return nil, err
213+
}
214+
215+
return &ssh.ClientConfig{
216+
User: user,
217+
Auth: []ssh.AuthMethod{
218+
ssh.PublicKeys(signer),
219+
},
220+
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
221+
return nil
222+
}, // same as StrictHostKeyChecking=no
223+
Timeout: 5 * time.Second,
224+
}, nil
225+
}
226+
227+
func DialSSHOverBastion(
228+
ctx context.Context,
229+
bastion *Bastion,
230+
vmPrivateIP string,
231+
sshPrivateKey []byte,
232+
) (*ssh.Client, error) {
233+
234+
// Create Bastion tunnel session (SSH = port 22)
235+
tunnel, err := bastion.NewTunnelSession(
236+
vmPrivateIP,
237+
22,
238+
)
239+
if err != nil {
240+
return nil, err
241+
}
242+
243+
// Create in-memory connection pair
244+
sshSide, tunnelSide := net.Pipe()
245+
246+
// Start Bastion tunnel piping
247+
go func() {
248+
_ = tunnel.Pipe(tunnelSide)
249+
fmt.Printf("Closed tunnel for VM IP %s\n", vmPrivateIP)
250+
}()
251+
252+
// SSH client configuration
253+
sshConfig, err := sshClientConfig("azureuser", sshPrivateKey)
254+
if err != nil {
255+
return nil, err
256+
}
257+
258+
// Establish SSH over the Bastion tunnel
259+
sshConn, chans, reqs, err := ssh.NewClientConn(
260+
sshSide,
261+
vmPrivateIP,
262+
sshConfig,
263+
)
264+
if err != nil {
265+
sshSide.Close()
266+
return nil, err
267+
}
268+
269+
return ssh.NewClient(sshConn, chans, reqs), nil
270+
}

e2e/cluster.go

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type Cluster struct {
4141
ClusterParams *ClusterParams
4242
Maintenance *armcontainerservice.MaintenanceConfiguration
4343
DebugPod *corev1.Pod
44-
Bastion *armnetwork.BastionHost
44+
Bastion *Bastion
4545
}
4646

4747
// Returns true if the cluster is configured with Azure CNI
@@ -479,7 +479,7 @@ func createNewMaintenanceConfiguration(ctx context.Context, cluster *armcontaine
479479
return &maintenance, nil
480480
}
481481

482-
func getOrCreateBastion(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armnetwork.BastionHost, error) {
482+
func getOrCreateBastion(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*Bastion, error) {
483483
nodeRG := *cluster.Properties.NodeResourceGroup
484484
bastionName := fmt.Sprintf("%s-bastion", *cluster.Name)
485485

@@ -491,10 +491,11 @@ func getOrCreateBastion(ctx context.Context, cluster *armcontainerservice.Manage
491491
if err != nil {
492492
return nil, fmt.Errorf("failed to get bastion %q in rg %q: %w", bastionName, nodeRG, err)
493493
}
494-
return &existing.BastionHost, nil
494+
495+
return NewBastion(config.Azure.Credential, config.Config.SubscriptionID, nodeRG, *existing.BastionHost.Properties.DNSName), nil
495496
}
496497

497-
func createNewBastion(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armnetwork.BastionHost, error) {
498+
func createNewBastion(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*Bastion, error) {
498499
nodeRG := *cluster.Properties.NodeResourceGroup
499500
location := *cluster.Location
500501
bastionName := fmt.Sprintf("%s-bastion", *cluster.Name)
@@ -565,7 +566,7 @@ func createNewBastion(ctx context.Context, cluster *armcontainerservice.ManagedC
565566
return nil, fmt.Errorf("bastion public IP response missing ID")
566567
}
567568

568-
bastion := armnetwork.BastionHost{
569+
bastionHost := armnetwork.BastionHost{
569570
Location: to.Ptr(location),
570571
SKU: &armnetwork.SKU{
571572
Name: to.Ptr(armnetwork.BastionHostSKUNameStandard),
@@ -590,7 +591,7 @@ func createNewBastion(ctx context.Context, cluster *armcontainerservice.ManagedC
590591
}
591592

592593
logf(ctx, "creating bastion %s (native client/tunneling enabled) in rg %s", bastionName, nodeRG)
593-
bastionPoller, err := config.Azure.BastionHosts.BeginCreateOrUpdate(ctx, nodeRG, bastionName, bastion, nil)
594+
bastionPoller, err := config.Azure.BastionHosts.BeginCreateOrUpdate(ctx, nodeRG, bastionName, bastionHost, nil)
594595
if err != nil {
595596
return nil, fmt.Errorf("failed to start creating bastion: %w", err)
596597
}
@@ -599,13 +600,15 @@ func createNewBastion(ctx context.Context, cluster *armcontainerservice.ManagedC
599600
return nil, fmt.Errorf("failed to create bastion: %w", err)
600601
}
601602

602-
if err := verifyBastion(ctx, cluster, &resp.BastionHost); err != nil {
603+
bastion := NewBastion(config.Azure.Credential, config.Config.SubscriptionID, nodeRG, *resp.BastionHost.Properties.DNSName)
604+
605+
if err := verifyBastion(ctx, cluster, bastion); err != nil {
603606
return nil, fmt.Errorf("failed to verify bastion: %w", err)
604607
}
605-
return &resp.BastionHost, nil
608+
return bastion, nil
606609
}
607610

608-
func verifyBastion(ctx context.Context, cluster *armcontainerservice.ManagedCluster, bastion *armnetwork.BastionHost) error {
611+
func verifyBastion(ctx context.Context, cluster *armcontainerservice.ManagedCluster, bastion *Bastion) error {
609612
nodeRG := *cluster.Properties.NodeResourceGroup
610613
vmssName, err := getSystemPoolVMSSName(ctx, cluster)
611614
if err != nil {
@@ -624,23 +627,26 @@ func verifyBastion(ctx context.Context, cluster *armcontainerservice.ManagedClus
624627
}
625628
}
626629

630+
vmPrivateIP, err := getPrivateIPFromVMSSVM(ctx, nodeRG, vmssName, *vmssVM.InstanceID)
631+
627632
ctx, cancel := context.WithCancel(ctx)
628633
defer cancel()
629-
localPort, pid, err := startBastionTunnel(ctx, *bastion.Name, nodeRG, *vmssVM.ID)
634+
635+
sshClient, err := DialSSHOverBastion(ctx, bastion, vmPrivateIP, config.SysSSHPrivateKey)
630636
if err != nil {
631637
return err
632638
}
633639

634-
defer cleanupBastionTunnel(localPort, pid)
640+
defer sshClient.Close()
635641

636-
result, err := runSSHCommandWithPrivateKeyFile(ctx, localPort, "uname -a", config.SysSSHPrivateKey)
642+
result, err := runSSHCommandWithPrivateKeyFile(ctx, sshClient, "uname -a")
637643
if err != nil {
638644
return err
639645
}
640-
if strings.Contains(result.stdout, *vmssVM.Name) {
646+
if strings.Contains(result.stdout, vmssName) {
641647
return nil
642648
}
643-
return fmt.Errorf("Executed ssh on wrong VM: %s", result.stdout)
649+
return fmt.Errorf("Executed ssh on wrong VM, Expected %s: %s", vmssName, result.stdout)
644650
}
645651

646652
func getSystemPoolVMSSName(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (string, error) {

e2e/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func mustLoadConfig() *Configuration {
145145
if cfg.SysSSHPrivateKeyB64 == "" {
146146
SysSSHPrivateKeyFileName = VMSSHPrivateKeyFileName
147147
} else {
148-
SysSSHPrivateKey, err := base64.StdEncoding.DecodeString(cfg.SysSSHPrivateKeyB64)
148+
SysSSHPrivateKey, err = base64.StdEncoding.DecodeString(cfg.SysSSHPrivateKeyB64)
149149
if err != nil {
150150
panic(err)
151151
}

e2e/config/vhd.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,7 @@ func GetRandomLinuxAMD64VHD() *Image {
323323
vhds := []*Image{
324324
VHDUbuntu2404Gen2Containerd,
325325
VHDUbuntu2204Gen2Containerd,
326-
VHDAzureLinuxV2Gen2,
327326
VHDAzureLinuxV3Gen2,
328-
VHDCBLMarinerV2Gen2,
329327
}
330328

331329
// Return a random VHD from the list

0 commit comments

Comments
 (0)