|
17 | 17 | package server
|
18 | 18 |
|
19 | 19 | import (
|
20 |
| - "bytes" |
21 | 20 | "context"
|
22 | 21 | "fmt"
|
23 | 22 | "io"
|
| 23 | + "net" |
| 24 | + "time" |
24 | 25 |
|
25 |
| - "k8s.io/utils/exec" |
| 26 | + "github.com/containerd/log" |
26 | 27 |
|
27 |
| - sandboxstore "github.com/containerd/containerd/v2/internal/cri/store/sandbox" |
28 |
| - cioutil "github.com/containerd/containerd/v2/pkg/ioutil" |
| 28 | + netutils "k8s.io/utils/net" |
29 | 29 | )
|
30 | 30 |
|
31 | 31 | func (c *criService) portForward(ctx context.Context, id string, port int32, stream io.ReadWriter) error {
|
32 |
| - stdout := cioutil.NewNopWriteCloser(stream) |
33 |
| - stderrBuffer := new(bytes.Buffer) |
34 |
| - stderr := cioutil.NewNopWriteCloser(stderrBuffer) |
35 |
| - // localhost is resolved to 127.0.0.1 in ipv4, and ::1 in ipv6. |
36 |
| - // Explicitly using ipv4 IP address in here to avoid flakiness. |
37 |
| - cmd := []string{"wincat.exe", "127.0.0.1", fmt.Sprint(port)} |
38 |
| - err := c.execInSandbox(ctx, id, cmd, stream, stdout, stderr) |
| 32 | + sandbox, err := c.sandboxStore.Get(id) |
39 | 33 | if err != nil {
|
40 |
| - return fmt.Errorf("failed to execute port forward in sandbox: %s: %w", stderrBuffer.String(), err) |
| 34 | + return fmt.Errorf("failed to find sandbox %q in store: %w", id, err) |
41 | 35 | }
|
42 |
| - return nil |
43 |
| -} |
44 | 36 |
|
45 |
| -func (c *criService) execInSandbox(ctx context.Context, sandboxID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser) error { |
46 |
| - // Get sandbox from our sandbox store. |
47 |
| - sb, err := c.sandboxStore.Get(sandboxID) |
48 |
| - if err != nil { |
49 |
| - return fmt.Errorf("failed to find sandbox %q in store: %w", sandboxID, err) |
| 37 | + var podIP string |
| 38 | + if !hostNetwork(sandbox.Config) { |
| 39 | + // get ip address of the sandbox |
| 40 | + podIP, _, err = c.getIPs(sandbox) |
| 41 | + if err != nil { |
| 42 | + return fmt.Errorf("failed to get sandbox ip: %w", err) |
| 43 | + } |
| 44 | + } else { |
| 45 | + // HPCs use the host networking namespace. |
| 46 | + // Therefore, dial to localhost. |
| 47 | + podIP = "127.0.0.1" |
50 | 48 | }
|
51 | 49 |
|
52 |
| - // Check the sandbox state |
53 |
| - state := sb.Status.Get().State |
54 |
| - if state != sandboxstore.StateReady { |
55 |
| - return fmt.Errorf("sandbox is in %s state", fmt.Sprint(state)) |
56 |
| - } |
| 50 | + err = func() error { |
| 51 | + var conn net.Conn |
| 52 | + if netutils.IsIPv4String(podIP) { |
| 53 | + conn, err = net.Dial("tcp4", fmt.Sprintf("%s:%d", podIP, port)) |
| 54 | + if err != nil { |
| 55 | + return fmt.Errorf("failed to connect to %s:%d for pod %q: %v", podIP, port, id, err) |
| 56 | + } |
| 57 | + } else { |
| 58 | + conn, err = net.Dial("tcp6", fmt.Sprintf("%s:%d", podIP, port)) |
| 59 | + if err != nil { |
| 60 | + return fmt.Errorf("failed to connect to %s:%d for pod %q: %v", podIP, port, id, err) |
| 61 | + } |
| 62 | + } |
| 63 | + log.G(ctx).Debugf("Connection to ip %s and port %d was successful", podIP, port) |
| 64 | + |
| 65 | + defer conn.Close() |
| 66 | + |
| 67 | + // copy stream |
| 68 | + errCh := make(chan error, 2) |
| 69 | + // Copy from the namespace port connection to the client stream |
| 70 | + go func() { |
| 71 | + log.G(ctx).Debugf("PortForward copying data from namespace %q port %d to the client stream", id, port) |
| 72 | + _, err := io.Copy(stream, conn) |
| 73 | + errCh <- err |
| 74 | + }() |
| 75 | + |
| 76 | + // Copy from the client stream to the namespace port connection |
| 77 | + go func() { |
| 78 | + log.G(ctx).Debugf("PortForward copying data from client stream to namespace %q port %d", id, port) |
| 79 | + _, err := io.Copy(conn, stream) |
| 80 | + errCh <- err |
| 81 | + }() |
| 82 | + |
| 83 | + // Wait until the first error is returned by one of the connections |
| 84 | + // we use errFwd to store the result of the port forwarding operation |
| 85 | + // if the context is cancelled close everything and return |
| 86 | + var errFwd error |
| 87 | + select { |
| 88 | + case errFwd = <-errCh: |
| 89 | + log.G(ctx).Debugf("PortForward stop forwarding in one direction in network namespace %q port %d: %v", id, port, errFwd) |
| 90 | + case <-ctx.Done(): |
| 91 | + log.G(ctx).Debugf("PortForward cancelled in network namespace %q port %d: %v", id, port, ctx.Err()) |
| 92 | + return ctx.Err() |
| 93 | + } |
| 94 | + // give a chance to terminate gracefully or timeout |
| 95 | + // after 1s |
| 96 | + const timeout = time.Second |
| 97 | + select { |
| 98 | + case e := <-errCh: |
| 99 | + if errFwd == nil { |
| 100 | + errFwd = e |
| 101 | + } |
| 102 | + log.G(ctx).Debugf("PortForward stopped forwarding in both directions in network namespace %q port %d: %v", id, port, e) |
| 103 | + case <-time.After(timeout): |
| 104 | + log.G(ctx).Debugf("PortForward timed out waiting to close the connection in network namespace %q port %d", id, port) |
| 105 | + case <-ctx.Done(): |
| 106 | + log.G(ctx).Debugf("PortForward cancelled in network namespace %q port %d: %v", id, port, ctx.Err()) |
| 107 | + errFwd = ctx.Err() |
| 108 | + } |
| 109 | + |
| 110 | + return errFwd |
| 111 | + }() |
57 | 112 |
|
58 |
| - opts := execOptions{ |
59 |
| - cmd: cmd, |
60 |
| - stdin: stdin, |
61 |
| - stdout: stdout, |
62 |
| - stderr: stderr, |
63 |
| - tty: false, |
64 |
| - resize: nil, |
65 |
| - } |
66 |
| - exitCode, err := c.execInternal(ctx, sb.Container, sandboxID, opts) |
67 | 113 | if err != nil {
|
68 |
| - return fmt.Errorf("failed to exec in sandbox: %w", err) |
69 |
| - } |
70 |
| - if *exitCode == 0 { |
71 |
| - return nil |
72 |
| - } |
73 |
| - return &exec.CodeExitError{ |
74 |
| - Err: fmt.Errorf("error executing command %v, exit code %d", cmd, *exitCode), |
75 |
| - Code: int(*exitCode), |
| 114 | + return fmt.Errorf("failed to execute portforward for podId %v, podIp %v, err: %w", id, podIP, err) |
76 | 115 | }
|
| 116 | + log.G(ctx).Debugf("Finish port forwarding for windows %q port %d", id, port) |
| 117 | + |
| 118 | + return nil |
77 | 119 | }
|
0 commit comments