Skip to content

Commit 29fe9d1

Browse files
committed
implement websockets natively
1 parent c225a3d commit 29fe9d1

File tree

18 files changed

+430
-872
lines changed

18 files changed

+430
-872
lines changed

.pipelines/scripts/e2e_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ rm -f "$temp_file"
9494
# gotestsum configure to only show logs for failed tests, json file for detailed logs
9595
# Run the tests! Yey!
9696
test_exit_code=0
97-
./bin/gotestsum --format testdox --junitfile "${BUILD_SRC_DIR}/e2e/report.xml" --jsonfile "${BUILD_SRC_DIR}/e2e/test-log.json" -- -parallel 100 -timeout 90m || test_exit_code=$?
97+
./bin/gotestsum --format testdox --junitfile "${BUILD_SRC_DIR}/e2e/report.xml" --jsonfile "${BUILD_SRC_DIR}/e2e/test-log.json" -- -parallel 150 -timeout 90m || test_exit_code=$?
9898

9999
# Upload test results as Azure DevOps artifacts
100100
echo "##vso[artifact.upload containerfolder=test-results;artifactname=e2e-test-log]${BUILD_SRC_DIR}/e2e/test-log.json"

e2e/bastionssh.go

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
package e2e
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net"
10+
"net/http"
11+
"net/url"
12+
"strings"
13+
"time"
14+
15+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
16+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
17+
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
18+
"github.com/coder/websocket"
19+
"golang.org/x/crypto/ssh"
20+
)
21+
22+
type Bastion struct {
23+
credential *azidentity.AzureCLICredential
24+
subscriptionID, resourceGroupName, dnsName string
25+
httpClient *http.Client
26+
httpTransport *http.Transport
27+
}
28+
29+
func NewBastion(credential *azidentity.AzureCLICredential, subscriptionID, resourceGroupName, dnsName string) *Bastion {
30+
transport := &http.Transport{
31+
MaxIdleConns: 100,
32+
MaxIdleConnsPerHost: 100,
33+
IdleConnTimeout: 30 * time.Second,
34+
}
35+
36+
return &Bastion{
37+
credential: credential,
38+
subscriptionID: subscriptionID,
39+
resourceGroupName: resourceGroupName,
40+
dnsName: dnsName,
41+
httpTransport: transport,
42+
httpClient: &http.Client{
43+
Transport: transport,
44+
Timeout: 30 * time.Second,
45+
},
46+
}
47+
}
48+
49+
type tunnelSession struct {
50+
bastion *Bastion
51+
ws *websocket.Conn
52+
session *sessionToken
53+
}
54+
55+
func (b *Bastion) NewTunnelSession(targetHost string, port uint16) (*tunnelSession, error) {
56+
session, err := b.newSessionToken(targetHost, port)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
wsUrl := fmt.Sprintf("wss://%v/webtunnelv2/%v?X-Node-Id=%v", b.dnsName, session.WebsocketToken, session.NodeID)
62+
63+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
64+
ws, _, err := websocket.Dial(ctx, wsUrl, &websocket.DialOptions{
65+
CompressionMode: websocket.CompressionDisabled,
66+
})
67+
cancel()
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
ws.SetReadLimit(32 * 1024 * 1024)
73+
74+
return &tunnelSession{
75+
bastion: b,
76+
ws: ws,
77+
session: session,
78+
}, nil
79+
}
80+
81+
type sessionToken struct {
82+
AuthToken string `json:"authToken"`
83+
Username string `json:"username"`
84+
DataSource string `json:"dataSource"`
85+
NodeID string `json:"nodeId"`
86+
AvailableDataSources []string `json:"availableDataSources"`
87+
WebsocketToken string `json:"websocketToken"`
88+
}
89+
90+
func (t *tunnelSession) Close() error {
91+
_ = t.ws.Close(websocket.StatusNormalClosure, "")
92+
93+
req, err := http.NewRequest("DELETE", fmt.Sprintf("https://%v/api/tokens/%v", t.bastion.dnsName, t.session.AuthToken), nil)
94+
if err != nil {
95+
return err
96+
}
97+
98+
req.Header.Add("X-Node-Id", t.session.NodeID)
99+
100+
resp, err := t.bastion.httpClient.Do(req)
101+
if err != nil {
102+
return err
103+
}
104+
defer resp.Body.Close()
105+
106+
if resp.StatusCode == 404 {
107+
return nil
108+
}
109+
110+
if resp.StatusCode != 204 {
111+
return fmt.Errorf("unexpected status code: %v", resp.StatusCode)
112+
}
113+
114+
if t.bastion.httpTransport != nil {
115+
t.bastion.httpTransport.CloseIdleConnections()
116+
}
117+
118+
return nil
119+
}
120+
121+
func (b *Bastion) newSessionToken(targetHost string, port uint16) (*sessionToken, error) {
122+
123+
token, err := b.credential.GetToken(context.Background(), policy.TokenRequestOptions{
124+
Scopes: []string{fmt.Sprintf("%s/.default", cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint)},
125+
})
126+
127+
if err != nil {
128+
return nil, err
129+
}
130+
131+
apiUrl := fmt.Sprintf("https://%v/api/tokens", b.dnsName)
132+
133+
// target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
134+
data := url.Values{}
135+
data.Set("resourceId", fmt.Sprintf("/subscriptions/%v/resourceGroups/%v/providers/Microsoft.Network/bh-hostConnect/%v", b.subscriptionID, b.resourceGroupName, targetHost))
136+
data.Set("protocol", "tcptunnel")
137+
data.Set("workloadHostPort", fmt.Sprintf("%v", port))
138+
data.Set("aztoken", token.Token)
139+
data.Set("hostname", targetHost)
140+
141+
req, err := http.NewRequest("POST", apiUrl, strings.NewReader(data.Encode()))
142+
if err != nil {
143+
return nil, err
144+
}
145+
146+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
147+
resp, err := b.httpClient.Do(req) // TODO client settings
148+
if err != nil {
149+
return nil, err
150+
}
151+
152+
defer resp.Body.Close()
153+
154+
if resp.StatusCode != 200 {
155+
return nil, fmt.Errorf("error creating tunnel: %v", resp.Status)
156+
}
157+
158+
var response sessionToken
159+
160+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
161+
return nil, err
162+
}
163+
164+
return &response, nil
165+
}
166+
167+
func (t *tunnelSession) Pipe(conn net.Conn) error {
168+
169+
defer t.Close()
170+
defer conn.Close()
171+
172+
done := make(chan error, 2)
173+
174+
go func() {
175+
for {
176+
_, data, err := t.ws.Read(context.Background())
177+
if err != nil {
178+
done <- err
179+
return
180+
}
181+
182+
if _, err := io.Copy(conn, bytes.NewReader(data)); err != nil {
183+
done <- err
184+
return
185+
}
186+
}
187+
}()
188+
189+
go func() {
190+
buf := make([]byte, 4096) // 4096 is copy from az cli bastion code
191+
192+
for {
193+
n, err := conn.Read(buf)
194+
if err != nil {
195+
done <- err
196+
return
197+
}
198+
199+
if err := t.ws.Write(context.Background(), websocket.MessageBinary, buf[:n]); err != nil {
200+
done <- err
201+
return
202+
}
203+
}
204+
}()
205+
206+
return <-done
207+
}
208+
209+
func sshClientConfig(user string, privateKey []byte) (*ssh.ClientConfig, error) {
210+
signer, err := ssh.ParsePrivateKey(privateKey)
211+
if err != nil {
212+
return nil, err
213+
}
214+
215+
return &ssh.ClientConfig{
216+
User: user,
217+
Auth: []ssh.AuthMethod{
218+
ssh.PublicKeys(signer),
219+
},
220+
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
221+
return nil
222+
}, // same as StrictHostKeyChecking=no
223+
Timeout: 5 * time.Second,
224+
}, nil
225+
}
226+
227+
func DialSSHOverBastion(
228+
ctx context.Context,
229+
bastion *Bastion,
230+
vmPrivateIP string,
231+
sshPrivateKey []byte,
232+
) (*ssh.Client, error) {
233+
234+
// Create Bastion tunnel session (SSH = port 22)
235+
tunnel, err := bastion.NewTunnelSession(
236+
vmPrivateIP,
237+
22,
238+
)
239+
if err != nil {
240+
return nil, err
241+
}
242+
243+
// Create in-memory connection pair
244+
sshSide, tunnelSide := net.Pipe()
245+
246+
// Start Bastion tunnel piping
247+
go func() {
248+
_ = tunnel.Pipe(tunnelSide)
249+
fmt.Printf("Closed tunnel for VM IP %s\n", vmPrivateIP)
250+
}()
251+
252+
// SSH client configuration
253+
sshConfig, err := sshClientConfig("azureuser", sshPrivateKey)
254+
if err != nil {
255+
return nil, err
256+
}
257+
258+
// Establish SSH over the Bastion tunnel
259+
sshConn, chans, reqs, err := ssh.NewClientConn(
260+
sshSide,
261+
vmPrivateIP,
262+
sshConfig,
263+
)
264+
if err != nil {
265+
sshSide.Close()
266+
return nil, err
267+
}
268+
269+
return ssh.NewClient(sshConn, chans, reqs), nil
270+
}

e2e/cache.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,25 +153,39 @@ func clusterLatestKubernetesVersion(ctx context.Context, request ClusterRequest)
153153
return prepareCluster(ctx, model, false, false)
154154
}
155155

156+
var ClusterKubenetGPU = cachedFunc(clusterKubenetGPU)
157+
158+
// clusterKubenet creates a basic cluster using kubenet networking
159+
func clusterKubenetGPU(ctx context.Context, request ClusterRequest) (*Cluster, error) {
160+
return prepareCluster(ctx, getKubenetClusterModel("abe2e-kubenet-gpu-v4", request.Location, request.K8sSystemPoolSKU), false, false)
161+
}
162+
156163
var ClusterKubenet = cachedFunc(clusterKubenet)
157164

158165
// clusterKubenet creates a basic cluster using kubenet networking
159166
func clusterKubenet(ctx context.Context, request ClusterRequest) (*Cluster, error) {
160167
return prepareCluster(ctx, getKubenetClusterModel("abe2e-kubenet-v4", request.Location, request.K8sSystemPoolSKU), false, false)
161168
}
162169

170+
var ClusterKubenetVHDCaching = cachedFunc(clusterKubenetVHDCaching)
171+
172+
// clusterKubenetVHDCaching creates a basic cluster using kubenet networking
173+
func clusterKubenetVHDCaching(ctx context.Context, request ClusterRequest) (*Cluster, error) {
174+
return prepareCluster(ctx, getKubenetClusterModel("abe2e-kubenet-vhd-caching-v4", request.Location, request.K8sSystemPoolSKU), false, false)
175+
}
176+
163177
var ClusterKubenetAirgap = cachedFunc(clusterKubenetAirgap)
164178

165179
// clusterKubenetAirgap creates an airgapped kubenet cluster (no internet access)
166180
func clusterKubenetAirgap(ctx context.Context, request ClusterRequest) (*Cluster, error) {
167-
return prepareCluster(ctx, getKubenetClusterModel("abe2e-kubenet-airgap-v2", request.Location, request.K8sSystemPoolSKU), true, false)
181+
return prepareCluster(ctx, getKubenetClusterModel("abe2e-kubenet-airgap-v3", request.Location, request.K8sSystemPoolSKU), true, false)
168182
}
169183

170184
var ClusterKubenetAirgapNonAnon = cachedFunc(clusterKubenetAirgapNonAnon)
171185

172186
// clusterKubenetAirgapNonAnon creates an airgapped kubenet cluster with non-anonymous image pulls
173187
func clusterKubenetAirgapNonAnon(ctx context.Context, request ClusterRequest) (*Cluster, error) {
174-
return prepareCluster(ctx, getKubenetClusterModel("abe2e-kubenet-nonanonpull-airgap-v2", request.Location, request.K8sSystemPoolSKU), true, true)
188+
return prepareCluster(ctx, getKubenetClusterModel("abe2e-kubenet-nonanonpull-airgap-v3", request.Location, request.K8sSystemPoolSKU), true, true)
175189
}
176190

177191
var ClusterAzureNetwork = cachedFunc(clusterAzureNetwork)

0 commit comments

Comments
 (0)