From c19735566abe4c049bf81704ba02a15917dadfe8 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Fri, 27 Sep 2024 17:08:33 -0700 Subject: [PATCH 01/14] move dhcp code to common file --- dhcp/dhcp.go | 254 +++++++++++++++++++++++++++++++++++++++++++ dhcp/dhcp_linux.go | 242 ----------------------------------------- dhcp/dhcp_windows.go | 12 -- 3 files changed, 254 insertions(+), 254 deletions(-) create mode 100644 dhcp/dhcp.go diff --git a/dhcp/dhcp.go b/dhcp/dhcp.go new file mode 100644 index 0000000000..d7a5302699 --- /dev/null +++ b/dhcp/dhcp.go @@ -0,0 +1,254 @@ +package dhcp + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "io" + "net" + "time" + + "github.com/pkg/errors" + "go.uber.org/zap" + "golang.org/x/net/ipv4" +) + +const ( + dhcpDiscover = 1 + bootRequest = 1 + ethPAll = 0x0003 + MaxUDPReceivedPacketSize = 8192 + dhcpServerPort = 67 + dhcpClientPort = 68 + dhcpOpCodeReply = 2 + bootpMinLen = 300 + bytesInAddress = 4 // bytes in an ip address + macBytes = 6 // bytes in a mac address + udpProtocol = 17 + + opRequest = 1 + htypeEthernet = 1 + hlenEthernet = 6 + hops = 0 + secs = 0 + flags = 0x8000 // Broadcast flag +) + +// TransactionID represents a 4-byte DHCP transaction ID as defined in RFC 951, +// Section 3. +// +// The TransactionID is used to match DHCP replies to their original request. +type TransactionID [4]byte + +var ( + magicCookie = []byte{0x63, 0x82, 0x53, 0x63} // DHCP magic cookie + DefaultReadTimeout = 3 * time.Second + DefaultTimeout = 3 * time.Second +) + +type DHCP struct { + logger *zap.Logger +} + +func New(logger *zap.Logger) *DHCP { + return &DHCP{ + logger: logger, + } +} + +// GenerateTransactionID generates a random 32-bits number suitable for use as TransactionID +func GenerateTransactionID() (TransactionID, error) { + var xid TransactionID + _, err := rand.Read(xid[:]) + if err != nil { + return xid, errors.Errorf("could not get random number: %v", err) + } + return xid, nil +} + +// Build DHCP Discover Packet +func buildDHCPDiscover(mac net.HardwareAddr, txid TransactionID) ([]byte, error) { + if len(mac) != macBytes { + return nil, errors.Errorf("invalid MAC address length") + } + + var packet bytes.Buffer + + // BOOTP header + packet.WriteByte(opRequest) // op: BOOTREQUEST (1) + packet.WriteByte(htypeEthernet) // htype: Ethernet (1) + packet.WriteByte(hlenEthernet) // hlen: MAC address length (6) + packet.WriteByte(hops) // hops: 0 + packet.Write(txid[:]) // xid: Transaction ID (4 bytes) + err := binary.Write(&packet, binary.BigEndian, uint16(secs)) // secs: Seconds elapsed + if err != nil { + return nil, errors.Wrap(err, "failed to write seconds elapsed") + } + err = binary.Write(&packet, binary.BigEndian, uint16(flags)) // flags: Broadcast flag + if err != nil { + return nil, errors.Wrap(err, "failed to write broadcast flag") + } + + // Client IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Your IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Server IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Gateway IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + + // chaddr: Client hardware address (MAC address) + paddingBytes := 10 + packet.Write(mac) // MAC address (6 bytes) + packet.Write(make([]byte, paddingBytes)) // Padding to 16 bytes + + // sname: Server host name (64 bytes) + serverHostNameBytes := 64 + packet.Write(make([]byte, serverHostNameBytes)) + // file: Boot file name (128 bytes) + bootFileNameBytes := 128 + packet.Write(make([]byte, bootFileNameBytes)) + + // Magic cookie (DHCP) + err = binary.Write(&packet, binary.BigEndian, magicCookie) + if err != nil { + return nil, errors.Wrap(err, "failed to write magic cookie") + } + + // DHCP options (minimal required options for DISCOVER) + packet.Write([]byte{ + 53, 1, 1, // Option 53: DHCP Message Type (1 = DHCP Discover) + 55, 3, 1, 3, 6, // Option 55: Parameter Request List (1 = Subnet Mask, 3 = Router, 6 = DNS) + 255, // End option + }) + + // padding length to 300 bytes + var value uint8 // default is zero + if packet.Len() < bootpMinLen { + packet.Write(bytes.Repeat([]byte{value}, bootpMinLen-packet.Len())) + } + + return packet.Bytes(), nil +} + +// MakeRawUDPPacket converts a payload (a serialized packet) into a +// raw UDP packet for the specified serverAddr from the specified clientAddr. +func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { + udpBytes := 8 + udp := make([]byte, udpBytes) + binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) + binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) + totalLen := uint16(udpBytes + len(payload)) + binary.BigEndian.PutUint16(udp[4:6], totalLen) + binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum + + headerVersion := 4 + headerLen := 20 + headerTTL := 64 + + h := ipv4.Header{ + Version: headerVersion, // nolint + Len: headerLen, // nolint + TotalLen: headerLen + len(udp) + len(payload), + TTL: headerTTL, + Protocol: udpProtocol, // UDP + Dst: serverAddr.IP, + Src: clientAddr.IP, + } + ret, err := h.Marshal() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal when making udp packet") + } + ret = append(ret, udp...) + ret = append(ret, payload...) + return ret, nil +} + +// Receive DHCP response packet using reader +func (c *DHCP) receiveDHCPResponse(ctx context.Context, reader io.ReadCloser, xid TransactionID) error { + recvErrors := make(chan error, 1) + // Recvfrom is a blocking call, so if something goes wrong with its timeout it won't return. + + // Additionally, the timeout on the socket (on the Read(...)) call is how long until the socket times out and gives an error, + // but it won't error if we do get some sort of data within the time out period. + + // If we get some data (even if it is not the packet we are looking for, like wrong txid, wrong response opcode etc.) + // then we continue in the for loop. We then call recvfrom again which will reset the timeout period + // Without the secondary timeout at the bottom of the function, we could stay stuck in the for loop as long as we receive packets. + go func(errs chan<- error) { + // loop will only exit if there is an error, context canceled, or we find our reply packet + for { + if ctx.Err() != nil { + errs <- ctx.Err() + return + } + + buf := make([]byte, MaxUDPReceivedPacketSize) + // Blocks until data received or timeout period is reached + n, innerErr := reader.Read(buf) + if innerErr != nil { + errs <- innerErr + return + } + // check header + var iph ipv4.Header + if err := iph.Parse(buf[:n]); err != nil { + // skip non-IP data + continue + } + if iph.Protocol != udpProtocol { + // skip non-UDP packets + continue + } + udph := buf[iph.Len:n] + // source is from dhcp server if receiving + srcPort := int(binary.BigEndian.Uint16(udph[0:2])) + if srcPort != dhcpServerPort { + continue + } + // client is to dhcp client if receiving + dstPort := int(binary.BigEndian.Uint16(udph[2:4])) + if dstPort != dhcpClientPort { + continue + } + // check payload + pLen := int(binary.BigEndian.Uint16(udph[4:6])) + payload := buf[iph.Len+8 : iph.Len+pLen] + + // retrieve opcode from payload + opcode := payload[0] // opcode is first byte + // retrieve txid from payload + txidOffset := 4 // after 4 bytes, the txid starts + // the txid is 4 bytes, so we take four bytes after the offset + txid := payload[txidOffset : txidOffset+4] + + c.logger.Info("Received packet", zap.Int("opCode", int(opcode)), zap.Any("transactionID", TransactionID(txid))) + if opcode != dhcpOpCodeReply { + continue // opcode is not a reply, so continue + } + + if TransactionID(txid) == xid { + break + } + } + // only occurs if we find our reply packet successfully + // a nil error means a reply was found for this txid + recvErrors <- nil + }(recvErrors) + + // sends a message on repeat after timeout, but only the first one matters + ticker := time.NewTicker(DefaultReadTimeout) + defer ticker.Stop() + + select { + case err := <-recvErrors: + if err != nil { + return errors.Wrap(err, "error during receiving") + } + case <-ticker.C: + return errors.New("timed out waiting for replies") + } + return nil +} diff --git a/dhcp/dhcp_linux.go b/dhcp/dhcp_linux.go index 9e7a029c05..5e71cc4264 100644 --- a/dhcp/dhcp_linux.go +++ b/dhcp/dhcp_linux.go @@ -4,9 +4,7 @@ package dhcp import ( - "bytes" "context" - "crypto/rand" "encoding/binary" "io" "net" @@ -14,53 +12,9 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" - "golang.org/x/net/ipv4" "golang.org/x/sys/unix" ) -const ( - dhcpDiscover = 1 - bootRequest = 1 - ethPAll = 0x0003 - MaxUDPReceivedPacketSize = 8192 - dhcpServerPort = 67 - dhcpClientPort = 68 - dhcpOpCodeReply = 2 - bootpMinLen = 300 - bytesInAddress = 4 // bytes in an ip address - macBytes = 6 // bytes in a mac address - udpProtocol = 17 - - opRequest = 1 - htypeEthernet = 1 - hlenEthernet = 6 - hops = 0 - secs = 0 - flags = 0x8000 // Broadcast flag -) - -// TransactionID represents a 4-byte DHCP transaction ID as defined in RFC 951, -// Section 3. -// -// The TransactionID is used to match DHCP replies to their original request. -type TransactionID [4]byte - -var ( - magicCookie = []byte{0x63, 0x82, 0x53, 0x63} // DHCP magic cookie - DefaultReadTimeout = 3 * time.Second - DefaultTimeout = 3 * time.Second -) - -type DHCP struct { - logger *zap.Logger -} - -func New(logger *zap.Logger) *DHCP { - return &DHCP{ - logger: logger, - } -} - type Socket struct { fd int remoteAddr unix.SockaddrInet4 @@ -122,16 +76,6 @@ func (s *Socket) Close() error { return nil } -// GenerateTransactionID generates a random 32-bits number suitable for use as TransactionID -func GenerateTransactionID() (TransactionID, error) { - var xid TransactionID - _, err := rand.Read(xid[:]) - if err != nil { - return xid, errors.Errorf("could not get random number: %v", err) - } - return xid, nil -} - func makeListeningSocket(ifname string, timeout time.Duration) (int, error) { // reference: https://manned.org/packet.7 // starts listening to the specified protocol, or none if zero @@ -204,192 +148,6 @@ func makeRawSocket(ifname string) (int, error) { return fd, nil } -// Build DHCP Discover Packet -func buildDHCPDiscover(mac net.HardwareAddr, txid TransactionID) ([]byte, error) { - if len(mac) != macBytes { - return nil, errors.Errorf("invalid MAC address length") - } - - var packet bytes.Buffer - - // BOOTP header - packet.WriteByte(opRequest) // op: BOOTREQUEST (1) - packet.WriteByte(htypeEthernet) // htype: Ethernet (1) - packet.WriteByte(hlenEthernet) // hlen: MAC address length (6) - packet.WriteByte(hops) // hops: 0 - packet.Write(txid[:]) // xid: Transaction ID (4 bytes) - err := binary.Write(&packet, binary.BigEndian, uint16(secs)) // secs: Seconds elapsed - if err != nil { - return nil, errors.Wrap(err, "failed to write seconds elapsed") - } - err = binary.Write(&packet, binary.BigEndian, uint16(flags)) // flags: Broadcast flag - if err != nil { - return nil, errors.Wrap(err, "failed to write broadcast flag") - } - - // Client IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - // Your IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - // Server IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - // Gateway IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - - // chaddr: Client hardware address (MAC address) - paddingBytes := 10 - packet.Write(mac) // MAC address (6 bytes) - packet.Write(make([]byte, paddingBytes)) // Padding to 16 bytes - - // sname: Server host name (64 bytes) - serverHostNameBytes := 64 - packet.Write(make([]byte, serverHostNameBytes)) - // file: Boot file name (128 bytes) - bootFileNameBytes := 128 - packet.Write(make([]byte, bootFileNameBytes)) - - // Magic cookie (DHCP) - err = binary.Write(&packet, binary.BigEndian, magicCookie) - if err != nil { - return nil, errors.Wrap(err, "failed to write magic cookie") - } - - // DHCP options (minimal required options for DISCOVER) - packet.Write([]byte{ - 53, 1, 1, // Option 53: DHCP Message Type (1 = DHCP Discover) - 55, 3, 1, 3, 6, // Option 55: Parameter Request List (1 = Subnet Mask, 3 = Router, 6 = DNS) - 255, // End option - }) - - // padding length to 300 bytes - var value uint8 // default is zero - if packet.Len() < bootpMinLen { - packet.Write(bytes.Repeat([]byte{value}, bootpMinLen-packet.Len())) - } - - return packet.Bytes(), nil -} - -// MakeRawUDPPacket converts a payload (a serialized packet) into a -// raw UDP packet for the specified serverAddr from the specified clientAddr. -func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { - udpBytes := 8 - udp := make([]byte, udpBytes) - binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) - binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) - totalLen := uint16(udpBytes + len(payload)) - binary.BigEndian.PutUint16(udp[4:6], totalLen) - binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum - - headerVersion := 4 - headerLen := 20 - headerTTL := 64 - - h := ipv4.Header{ - Version: headerVersion, // nolint - Len: headerLen, // nolint - TotalLen: headerLen + len(udp) + len(payload), - TTL: headerTTL, - Protocol: udpProtocol, // UDP - Dst: serverAddr.IP, - Src: clientAddr.IP, - } - ret, err := h.Marshal() - if err != nil { - return nil, errors.Wrap(err, "failed to marshal when making udp packet") - } - ret = append(ret, udp...) - ret = append(ret, payload...) - return ret, nil -} - -// Receive DHCP response packet using reader -func (c *DHCP) receiveDHCPResponse(ctx context.Context, reader io.ReadCloser, xid TransactionID) error { - recvErrors := make(chan error, 1) - // Recvfrom is a blocking call, so if something goes wrong with its timeout it won't return. - - // Additionally, the timeout on the socket (on the Read(...)) call is how long until the socket times out and gives an error, - // but it won't error if we do get some sort of data within the time out period. - - // If we get some data (even if it is not the packet we are looking for, like wrong txid, wrong response opcode etc.) - // then we continue in the for loop. We then call recvfrom again which will reset the timeout period - // Without the secondary timeout at the bottom of the function, we could stay stuck in the for loop as long as we receive packets. - go func(errs chan<- error) { - // loop will only exit if there is an error, context canceled, or we find our reply packet - for { - if ctx.Err() != nil { - errs <- ctx.Err() - return - } - - buf := make([]byte, MaxUDPReceivedPacketSize) - // Blocks until data received or timeout period is reached - n, innerErr := reader.Read(buf) - if innerErr != nil { - errs <- innerErr - return - } - // check header - var iph ipv4.Header - if err := iph.Parse(buf[:n]); err != nil { - // skip non-IP data - continue - } - if iph.Protocol != udpProtocol { - // skip non-UDP packets - continue - } - udph := buf[iph.Len:n] - // source is from dhcp server if receiving - srcPort := int(binary.BigEndian.Uint16(udph[0:2])) - if srcPort != dhcpServerPort { - continue - } - // client is to dhcp client if receiving - dstPort := int(binary.BigEndian.Uint16(udph[2:4])) - if dstPort != dhcpClientPort { - continue - } - // check payload - pLen := int(binary.BigEndian.Uint16(udph[4:6])) - payload := buf[iph.Len+8 : iph.Len+pLen] - - // retrieve opcode from payload - opcode := payload[0] // opcode is first byte - // retrieve txid from payload - txidOffset := 4 // after 4 bytes, the txid starts - // the txid is 4 bytes, so we take four bytes after the offset - txid := payload[txidOffset : txidOffset+4] - - c.logger.Info("Received packet", zap.Int("opCode", int(opcode)), zap.Any("transactionID", TransactionID(txid))) - if opcode != dhcpOpCodeReply { - continue // opcode is not a reply, so continue - } - - if TransactionID(txid) == xid { - break - } - } - // only occurs if we find our reply packet successfully - // a nil error means a reply was found for this txid - recvErrors <- nil - }(recvErrors) - - // sends a message on repeat after timeout, but only the first one matters - ticker := time.NewTicker(DefaultReadTimeout) - defer ticker.Stop() - - select { - case err := <-recvErrors: - if err != nil { - return errors.Wrap(err, "error during receiving") - } - case <-ticker.C: - return errors.New("timed out waiting for replies") - } - return nil -} - // Issues a DHCP Discover packet from the nic specified by mac and name ifname // Returns nil if a reply to the transaction was received, or error if time out // Does not return the DHCP Offer that was received from the DHCP server diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 7b23dbeeff..056ca3058b 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -3,20 +3,8 @@ package dhcp import ( "context" "net" - - "go.uber.org/zap" ) -type DHCP struct { - logger *zap.Logger -} - -func New(logger *zap.Logger) *DHCP { - return &DHCP{ - logger: logger, - } -} - func (c *DHCP) DiscoverRequest(_ context.Context, _ net.HardwareAddr, _ string) error { return nil } From ee8dde1e5a8f2118265dc9acb2f9440d8e8f4d6f Mon Sep 17 00:00:00 2001 From: QxBytes Date: Thu, 3 Oct 2024 09:56:02 -0700 Subject: [PATCH 02/14] create initial windows dhcp client --- cni/network/network.go | 5 +- dhcp/dhcp.go | 12 ++- dhcp/dhcp_windows.go | 163 +++++++++++++++++++++++++++++++++++- network/endpoint_windows.go | 24 +++++- 4 files changed, 197 insertions(+), 7 deletions(-) diff --git a/cni/network/network.go b/cni/network/network.go index 6b0635e1c7..7bc1f35394 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -131,8 +131,9 @@ func NewPlugin(name string, } nl := netlink.NewNetlink() + plc := platform.NewExecClient(logger) // Setup network manager. - nm, err := network.NewNetworkManager(nl, platform.NewExecClient(logger), &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger)) + nm, err := network.NewNetworkManager(nl, plc, &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger, plc)) if err != nil { return nil, err } @@ -1526,7 +1527,7 @@ func (plugin *NetPlugin) validateArgs(args *cniSkel.CmdArgs, nwCfg *cni.NetworkC if !allowedInput.MatchString(args.ContainerID) || !allowedInput.MatchString(args.IfName) { return errors.New("invalid args value") } - if !allowedInput.MatchString(nwCfg.Bridge) { + if !allowedInput.MatchString(nwCfg.Bridge) || !allowedInput.MatchString(nwCfg.Master) { return errors.New("invalid network config value") } diff --git a/dhcp/dhcp.go b/dhcp/dhcp.go index d7a5302699..d86514b60e 100644 --- a/dhcp/dhcp.go +++ b/dhcp/dhcp.go @@ -47,13 +47,19 @@ var ( DefaultTimeout = 3 * time.Second ) +type ExecClient interface { + ExecuteCommand(ctx context.Context, command string, args ...string) (string, error) +} + type DHCP struct { - logger *zap.Logger + logger *zap.Logger + execClient ExecClient } -func New(logger *zap.Logger) *DHCP { +func New(logger *zap.Logger, plc ExecClient) *DHCP { return &DHCP{ - logger: logger, + logger: logger, + execClient: plc, } } diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 056ca3058b..3baf3c8e5a 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -2,9 +2,170 @@ package dhcp import ( "context" + "golang.org/x/sys/windows" "net" + "regexp" + "time" + + "github.com/pkg/errors" + "go.uber.org/zap" +) + +const ( + dummyIPAddressStr = "169.254.128.10" + dummySubnetMask = "255.255.128.0" + addIPAddressTimeout = 10 * time.Second + deleteIPAddressTimeout = 2 * time.Second + + socketTimeout = 1000 ) -func (c *DHCP) DiscoverRequest(_ context.Context, _ net.HardwareAddr, _ string) error { +var ( + dummyIPAddress = net.IPv4(169, 254, 128, 10) + // matches if the string fully consists of zero or more alphanumeric, dots, dashes, parentheses, spaces, or underscores + allowedInput = regexp.MustCompile(`^[a-zA-Z0-9._\-\(\) ]*$`) +) + +type Socket struct { + fd windows.Handle + destAddr windows.SockaddrInet4 +} + +func NewSocket(destAddr windows.SockaddrInet4) (*Socket, error) { + // Create a raw socket using windows.WSASocket + fd, err := windows.WSASocket(windows.AF_INET, windows.SOCK_RAW, windows.IPPROTO_UDP, nil, 0, windows.WSA_FLAG_OVERLAPPED) + ret := &Socket{ + fd: fd, + destAddr: destAddr, + } + if err != nil { + return ret, errors.Wrap(err, "error creating socket") + } + defer windows.Closesocket(fd) + + // Set IP_HDRINCL to indicate that we are including our own IP header + err = windows.SetsockoptInt(fd, windows.IPPROTO_IP, windows.IP_HDRINCL, 1) + if err != nil { + return ret, errors.Wrap(err, "error setting IP_HDRINCL") + } + // Set the SO_BROADCAST option or else we get an error saying that we access a socket in a way forbidden by its access perms + err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_BROADCAST, 1) + if err != nil { + return ret, errors.Wrap(err, "error setting SO_BROADCAST") + } + // Set timeout + if err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVTIMEO, socketTimeout); err != nil { + return ret, errors.Wrap(err, "error setting receive timeout") + } + return ret, nil +} + +func (s *Socket) Write(packetBytes []byte) (int, error) { + err := windows.Sendto(s.fd, packetBytes, 0, &s.destAddr) + if err != nil { + return 0, errors.Wrap(err, "failed windows send to") + } + return len(packetBytes), nil +} +func (s *Socket) Read(p []byte) (n int, err error) { + n, _, innerErr := windows.Recvfrom(s.fd, p, 0) + if innerErr != nil { + return 0, errors.Wrap(err, "failed windows recv from") + } + return n, nil +} + +func (s *Socket) Close() error { + // do not attempt to close invalid fd (happens on socket creation failure) + if s.fd == windows.InvalidHandle { + return nil + } + // Ensure the file descriptor is closed when done + if err := windows.Close(s.fd); err != nil { + return errors.Wrap(err, "error closing dhcp windows socket") + } + return nil +} + +// issues a dhcp discover request on an interface by assigning an ip to that interface +// then, sends a packet with that interface's dummy ip, and then unassigns the dummy ip +func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, ifName string) error { + // validate interface name + if !allowedInput.MatchString(ifName) { + return errors.New("invalid dhcp discover request interface name") + } + // delete dummy ip off the interface if it already exists + ret, err := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) + if err != nil { + c.logger.Info("Could not remove dummy ip", zap.String("output", ret), zap.Error(err)) + } + time.Sleep(deleteIPAddressTimeout) + + // create dummy ip so we can direct the packet to the correct interface + ret, err = c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "add", "address", ifName, dummyIPAddressStr, dummySubnetMask) + if err != nil { + return errors.Wrap(err, "failed to add dummy ip to interface: "+ret) + } + // ensure we always remove the dummy ip we added from the interface + defer func() { + ret, err := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) + if err != nil { + c.logger.Info("Could not remove dummy ip on leaving function", zap.String("output", ret), zap.Error(err)) + } + }() + // it takes time for the address to be assigned + time.Sleep(addIPAddressTimeout) + + // now begin the dhcp request + txid, err := GenerateTransactionID() + if err != nil { + return errors.Wrap(err, "failed to generate transaction id") + } + + // Prepare an IP and UDP header + raddr := &net.UDPAddr{IP: net.IPv4bcast, Port: dhcpServerPort} + laddr := &net.UDPAddr{IP: dummyIPAddress, Port: dhcpClientPort} + + dhcpDiscover, err := buildDHCPDiscover(macAddress, txid) + if err != nil { + return errors.Wrap(err, "failed to build dhcp discover") + } + + // Fill out the headers, add payload, and construct the full packet + bytesToSend, err := MakeRawUDPPacket(dhcpDiscover, *raddr, *laddr) + if err != nil { + return errors.Wrap(err, "failed to make raw udp packet") + } + + destAddr := windows.SockaddrInet4{ + Addr: [4]byte{255, 255, 255, 255}, // Destination IP + Port: 67, // Destination Port + } + // create new socket for writing and reading + sock, err := NewSocket(destAddr) + defer func() { + // always clean up the socket, even if we fail while setting options + closeErr := sock.Close() + if closeErr != nil { + c.logger.Error("Error closing dhcp socket:", zap.Error(closeErr)) + } + }() + if err != nil { + return errors.Wrap(err, "failed to create socket") + } + + _, err = sock.Write(bytesToSend) + if err != nil { + return errors.Wrap(err, "failed to write to dhcp socket") + } + + c.logger.Info("DHCP Discover packet was sent successfully", zap.Any("transactionID", txid)) + + // Wait for DHCP response (Offer) + err = c.receiveDHCPResponse(ctx, sock, txid) + if err != nil { + return errors.Wrap(err, "failed to read from dhcp socket") + } + return nil } diff --git a/network/endpoint_windows.go b/network/endpoint_windows.go index edd52327f2..8ab39964fe 100644 --- a/network/endpoint_windows.go +++ b/network/endpoint_windows.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "strings" + "time" "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/netio" @@ -141,6 +142,21 @@ func (nw *network) getEndpointWithVFDevice(plc platform.ExecClient, epInfo *Endp return ep, nil } +func (nw *network) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { + // issue dhcp discover packet to ensure mapping created for dns via wireserver to work + // we do not use the response for anything + numSecs := 15 // we need to wait for the address to be assigned + timeout := time.Duration(numSecs) * time.Second + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) + defer cancel() + logger.Info("Sending DHCP packet", zap.Any("macAddress", mac), zap.String("ifName", ifName)) + err := client.DiscoverRequest(ctx, mac, ifName) + if err != nil { + return errors.Wrapf(err, "failed to issue dhcp discover packet to create mapping in host") + } + return nil +} + // newEndpointImpl creates a new endpoint in the network. func (nw *network) newEndpointImpl( cli apipaClient, @@ -150,12 +166,18 @@ func (nw *network) newEndpointImpl( _ EndpointClient, _ NamespaceClientInterface, _ ipTablesClient, - _ dhcpClient, + dhcpc dhcpClient, epInfo *EndpointInfo, ) (*endpoint, error) { if epInfo.NICType == cns.BackendNIC { return nw.getEndpointWithVFDevice(plc, epInfo) } + if epInfo.NICType == cns.DelegatedVMNIC { + // use master interface name, interface name, or adapter name? + if err := nw.sendDHCPDiscoverOnSecondary(dhcpc, epInfo.MacAddress, epInfo.MasterIfName); err != nil { + return nil, err + } + } if useHnsV2, err := UseHnsV2(epInfo.NetNsPath); useHnsV2 { if err != nil { From b773feec2d126d23bd84afa88a4054119e28b264 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Thu, 3 Oct 2024 10:50:58 -0700 Subject: [PATCH 03/14] use node network interface frontend nic --- network/endpoint_windows.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network/endpoint_windows.go b/network/endpoint_windows.go index 8ab39964fe..52078bc32c 100644 --- a/network/endpoint_windows.go +++ b/network/endpoint_windows.go @@ -172,7 +172,7 @@ func (nw *network) newEndpointImpl( if epInfo.NICType == cns.BackendNIC { return nw.getEndpointWithVFDevice(plc, epInfo) } - if epInfo.NICType == cns.DelegatedVMNIC { + if epInfo.NICType == cns.NodeNetworkInterfaceFrontendNIC { // use master interface name, interface name, or adapter name? if err := nw.sendDHCPDiscoverOnSecondary(dhcpc, epInfo.MacAddress, epInfo.MasterIfName); err != nil { return nil, err From 4b357e2e03ed64ceea5a252813840adf67aaa407 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Thu, 3 Oct 2024 11:04:51 -0700 Subject: [PATCH 04/14] address linter issues --- dhcp/dhcp_windows.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 3baf3c8e5a..fe71d36495 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -2,13 +2,13 @@ package dhcp import ( "context" - "golang.org/x/sys/windows" "net" "regexp" "time" "github.com/pkg/errors" "go.uber.org/zap" + "golang.org/x/sys/windows" ) const ( @@ -21,7 +21,7 @@ const ( ) var ( - dummyIPAddress = net.IPv4(169, 254, 128, 10) + dummyIPAddress = net.IPv4(169, 254, 128, 10) // nolint // matches if the string fully consists of zero or more alphanumeric, dots, dashes, parentheses, spaces, or underscores allowedInput = regexp.MustCompile(`^[a-zA-Z0-9._\-\(\) ]*$`) ) @@ -41,7 +41,6 @@ func NewSocket(destAddr windows.SockaddrInet4) (*Socket, error) { if err != nil { return ret, errors.Wrap(err, "error creating socket") } - defer windows.Closesocket(fd) // Set IP_HDRINCL to indicate that we are including our own IP header err = windows.SetsockoptInt(fd, windows.IPPROTO_IP, windows.IP_HDRINCL, 1) @@ -67,6 +66,7 @@ func (s *Socket) Write(packetBytes []byte) (int, error) { } return len(packetBytes), nil } + func (s *Socket) Read(p []byte) (n int, err error) { n, _, innerErr := windows.Recvfrom(s.fd, p, 0) if innerErr != nil { @@ -81,7 +81,7 @@ func (s *Socket) Close() error { return nil } // Ensure the file descriptor is closed when done - if err := windows.Close(s.fd); err != nil { + if err := windows.Closesocket(s.fd); err != nil { return errors.Wrap(err, "error closing dhcp windows socket") } return nil @@ -108,8 +108,8 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, } // ensure we always remove the dummy ip we added from the interface defer func() { - ret, err := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) - if err != nil { + ret, cleanupErr := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) + if cleanupErr != nil { c.logger.Info("Could not remove dummy ip on leaving function", zap.String("output", ret), zap.Error(err)) } }() @@ -139,7 +139,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, destAddr := windows.SockaddrInet4{ Addr: [4]byte{255, 255, 255, 255}, // Destination IP - Port: 67, // Destination Port + Port: dhcpServerPort, // Destination Port } // create new socket for writing and reading sock, err := NewSocket(destAddr) From fd36e391f559796d42a993c07e091bf78ec2d758 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Thu, 3 Oct 2024 15:05:44 -0700 Subject: [PATCH 05/14] address feedback, move send request to before hns network creation --- dhcp/dhcp_windows.go | 16 ++++++++-------- network/endpoint_windows.go | 22 ---------------------- network/network_windows.go | 24 ++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index fe71d36495..ecb40b332d 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -12,12 +12,12 @@ import ( ) const ( - dummyIPAddressStr = "169.254.128.10" - dummySubnetMask = "255.255.128.0" - addIPAddressTimeout = 10 * time.Second - deleteIPAddressTimeout = 2 * time.Second + dummyIPAddressStr = "169.254.128.10" + dummySubnetMask = "255.255.128.0" + addIPAddressDelay = 4 * time.Second + deleteIPAddressDelay = 2 * time.Second - socketTimeout = 1000 + socketTimeoutMillis = 1000 ) var ( @@ -53,7 +53,7 @@ func NewSocket(destAddr windows.SockaddrInet4) (*Socket, error) { return ret, errors.Wrap(err, "error setting SO_BROADCAST") } // Set timeout - if err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVTIMEO, socketTimeout); err != nil { + if err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVTIMEO, socketTimeoutMillis); err != nil { return ret, errors.Wrap(err, "error setting receive timeout") } return ret, nil @@ -99,7 +99,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, if err != nil { c.logger.Info("Could not remove dummy ip", zap.String("output", ret), zap.Error(err)) } - time.Sleep(deleteIPAddressTimeout) + time.Sleep(deleteIPAddressDelay) // create dummy ip so we can direct the packet to the correct interface ret, err = c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "add", "address", ifName, dummyIPAddressStr, dummySubnetMask) @@ -114,7 +114,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, } }() // it takes time for the address to be assigned - time.Sleep(addIPAddressTimeout) + time.Sleep(addIPAddressDelay) // now begin the dhcp request txid, err := GenerateTransactionID() diff --git a/network/endpoint_windows.go b/network/endpoint_windows.go index 52078bc32c..4220247fd1 100644 --- a/network/endpoint_windows.go +++ b/network/endpoint_windows.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "strings" - "time" "github.com/Azure/azure-container-networking/cns" "github.com/Azure/azure-container-networking/netio" @@ -142,21 +141,6 @@ func (nw *network) getEndpointWithVFDevice(plc platform.ExecClient, epInfo *Endp return ep, nil } -func (nw *network) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { - // issue dhcp discover packet to ensure mapping created for dns via wireserver to work - // we do not use the response for anything - numSecs := 15 // we need to wait for the address to be assigned - timeout := time.Duration(numSecs) * time.Second - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) - defer cancel() - logger.Info("Sending DHCP packet", zap.Any("macAddress", mac), zap.String("ifName", ifName)) - err := client.DiscoverRequest(ctx, mac, ifName) - if err != nil { - return errors.Wrapf(err, "failed to issue dhcp discover packet to create mapping in host") - } - return nil -} - // newEndpointImpl creates a new endpoint in the network. func (nw *network) newEndpointImpl( cli apipaClient, @@ -172,12 +156,6 @@ func (nw *network) newEndpointImpl( if epInfo.NICType == cns.BackendNIC { return nw.getEndpointWithVFDevice(plc, epInfo) } - if epInfo.NICType == cns.NodeNetworkInterfaceFrontendNIC { - // use master interface name, interface name, or adapter name? - if err := nw.sendDHCPDiscoverOnSecondary(dhcpc, epInfo.MacAddress, epInfo.MasterIfName); err != nil { - return nil, err - } - } if useHnsV2, err := UseHnsV2(epInfo.NetNsPath); useHnsV2 { if err != nil { diff --git a/network/network_windows.go b/network/network_windows.go index 4d870827ca..7c1f2c6020 100644 --- a/network/network_windows.go +++ b/network/network_windows.go @@ -4,8 +4,10 @@ package network import ( + "context" "encoding/json" "fmt" + "net" "strconv" "strings" "time" @@ -433,8 +435,30 @@ func (nm *networkManager) newNetworkImplHnsV2(nwInfo *EndpointInfo, extIf *exter return nw, nil } +func (nw *networkManager) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { + // issue dhcp discover packet to ensure mapping created for dns via wireserver to work + // we do not use the response for anything + numSecs := 15 // we need to wait for the address to be assigned + timeout := time.Duration(numSecs) * time.Second + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) + defer cancel() + logger.Info("Sending DHCP packet", zap.Any("macAddress", mac), zap.String("ifName", ifName)) + err := client.DiscoverRequest(ctx, mac, ifName) + if err != nil { + return errors.Wrapf(err, "failed to issue dhcp discover packet to create mapping in host") + } + return nil +} + // NewNetworkImpl creates a new container network. func (nm *networkManager) newNetworkImpl(nwInfo *EndpointInfo, extIf *externalInterface) (*network, error) { + if nwInfo.NICType == cns.NodeNetworkInterfaceFrontendNIC { + // use master interface name, interface name, or adapter name? + if err := nm.sendDHCPDiscoverOnSecondary(nm.dhcpClient, nwInfo.MacAddress, nwInfo.MasterIfName); err != nil { + return nil, err + } + } + if useHnsV2, err := UseHnsV2(nwInfo.NetNs); useHnsV2 { if err != nil { return nil, err From 5990ef713dedbb4fd469d4ba5b4780084b46d368 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Thu, 3 Oct 2024 17:05:39 -0700 Subject: [PATCH 06/14] add delay before returning from windows dhcp otherwise net adapter not found --- dhcp/dhcp.go | 2 +- dhcp/dhcp_windows.go | 7 +++++-- network/network_windows.go | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dhcp/dhcp.go b/dhcp/dhcp.go index d86514b60e..595ce3c11f 100644 --- a/dhcp/dhcp.go +++ b/dhcp/dhcp.go @@ -234,7 +234,7 @@ func (c *DHCP) receiveDHCPResponse(ctx context.Context, reader io.ReadCloser, xi if opcode != dhcpOpCodeReply { continue // opcode is not a reply, so continue } - + c.logger.Info("Received DHCP reply packet", zap.Int("opCode", int(opcode)), zap.Any("transactionID", TransactionID(txid))) if TransactionID(txid) == xid { break } diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index ecb40b332d..01aa15eede 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -16,6 +16,7 @@ const ( dummySubnetMask = "255.255.128.0" addIPAddressDelay = 4 * time.Second deleteIPAddressDelay = 2 * time.Second + returnDelay = 8 * time.Second // time to wait before returning from DiscoverRequest socketTimeoutMillis = 1000 ) @@ -97,7 +98,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, // delete dummy ip off the interface if it already exists ret, err := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) if err != nil { - c.logger.Info("Could not remove dummy ip", zap.String("output", ret), zap.Error(err)) + c.logger.Info("Could not remove dummy ip, likely because it doesn't exist", zap.String("output", ret), zap.Error(err)) } time.Sleep(deleteIPAddressDelay) @@ -110,8 +111,10 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, defer func() { ret, cleanupErr := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) if cleanupErr != nil { - c.logger.Info("Could not remove dummy ip on leaving function", zap.String("output", ret), zap.Error(err)) + c.logger.Info("Failed to remove dummy ip on leaving function", zap.String("output", ret), zap.Error(err)) } + // wait for nic to retrieve autoconfiguration ip + time.Sleep(returnDelay) }() // it takes time for the address to be assigned time.Sleep(addIPAddressDelay) diff --git a/network/network_windows.go b/network/network_windows.go index 7c1f2c6020..c459992e48 100644 --- a/network/network_windows.go +++ b/network/network_windows.go @@ -447,6 +447,7 @@ func (nw *networkManager) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net if err != nil { return errors.Wrapf(err, "failed to issue dhcp discover packet to create mapping in host") } + logger.Info("Successfully received DHCP reply packet") return nil } From 03efe8eb4d1281c5af5708f7f90a7f4a3918c6a3 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Thu, 3 Oct 2024 17:44:19 -0700 Subject: [PATCH 07/14] silence non-error and always run command to remove dummy ip --- dhcp/dhcp_windows.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 01aa15eede..a39b1739fb 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -6,6 +6,7 @@ import ( "regexp" "time" + "github.com/Azure/azure-container-networking/cni/log" "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/sys/windows" @@ -98,7 +99,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, // delete dummy ip off the interface if it already exists ret, err := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) if err != nil { - c.logger.Info("Could not remove dummy ip, likely because it doesn't exist", zap.String("output", ret), zap.Error(err)) + c.logger.Info("Could not remove dummy ip, likely because it doesn't exist", zap.String("output", ret), zap.Error(log.NewErrorWithoutStackTrace(err))) } time.Sleep(deleteIPAddressDelay) @@ -109,7 +110,9 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, } // ensure we always remove the dummy ip we added from the interface defer func() { - ret, cleanupErr := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) + // we always want to try to remove the dummy ip, even if the deadline was reached + // so we have context.Background() + ret, cleanupErr := c.execClient.ExecuteCommand(context.Background(), "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) if cleanupErr != nil { c.logger.Info("Failed to remove dummy ip on leaving function", zap.String("output", ret), zap.Error(err)) } From e1edebebae8eda6c93b239a5c3f317d93246390f Mon Sep 17 00:00:00 2001 From: QxBytes Date: Fri, 4 Oct 2024 09:34:37 -0700 Subject: [PATCH 08/14] address linter issues --- network/endpoint_windows.go | 2 +- network/network_windows.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/network/endpoint_windows.go b/network/endpoint_windows.go index 4220247fd1..edd52327f2 100644 --- a/network/endpoint_windows.go +++ b/network/endpoint_windows.go @@ -150,7 +150,7 @@ func (nw *network) newEndpointImpl( _ EndpointClient, _ NamespaceClientInterface, _ ipTablesClient, - dhcpc dhcpClient, + _ dhcpClient, epInfo *EndpointInfo, ) (*endpoint, error) { if epInfo.NICType == cns.BackendNIC { diff --git a/network/network_windows.go b/network/network_windows.go index c459992e48..d62f0afc2f 100644 --- a/network/network_windows.go +++ b/network/network_windows.go @@ -435,7 +435,7 @@ func (nm *networkManager) newNetworkImplHnsV2(nwInfo *EndpointInfo, extIf *exter return nw, nil } -func (nw *networkManager) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { +func (nm *networkManager) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { // issue dhcp discover packet to ensure mapping created for dns via wireserver to work // we do not use the response for anything numSecs := 15 // we need to wait for the address to be assigned From 84acbc9c1ad8315f384851047cb369ee90d44ca9 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Fri, 4 Oct 2024 10:18:17 -0700 Subject: [PATCH 09/14] move retrier to separate package and reuse code --- dhcp/dhcp_windows.go | 12 +++- .../transparent_vlan_endpointclient_linux.go | 25 ++------ ...nsparent_vlan_endpointclient_linux_test.go | 55 ---------------- retry/retry.go | 15 +++++ retry/retry_test.go | 63 +++++++++++++++++++ 5 files changed, 93 insertions(+), 77 deletions(-) create mode 100644 retry/retry.go create mode 100644 retry/retry_test.go diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index a39b1739fb..1f289213fc 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -7,6 +7,7 @@ import ( "time" "github.com/Azure/azure-container-networking/cni/log" + "github.com/Azure/azure-container-networking/retry" "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/sys/windows" @@ -18,8 +19,9 @@ const ( addIPAddressDelay = 4 * time.Second deleteIPAddressDelay = 2 * time.Second returnDelay = 8 * time.Second // time to wait before returning from DiscoverRequest - - socketTimeoutMillis = 1000 + retryCount = 5 + retryDelayMillis = 500 + socketTimeoutMillis = 1000 ) var ( @@ -160,7 +162,11 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, return errors.Wrap(err, "failed to create socket") } - _, err = sock.Write(bytesToSend) + // retry sending the packet until it succeeds + err = retry.Do(func() error { + _, sockErr := sock.Write(bytesToSend) + return sockErr + }, retryCount, retryDelayMillis) if err != nil { return errors.Wrap(err, "failed to write to dhcp socket") } diff --git a/network/transparent_vlan_endpointclient_linux.go b/network/transparent_vlan_endpointclient_linux.go index 731353c231..193529cef5 100644 --- a/network/transparent_vlan_endpointclient_linux.go +++ b/network/transparent_vlan_endpointclient_linux.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "strings" - "time" "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/netio" @@ -13,6 +12,7 @@ import ( "github.com/Azure/azure-container-networking/network/networkutils" "github.com/Azure/azure-container-networking/network/snat" "github.com/Azure/azure-container-networking/platform" + "github.com/Azure/azure-container-networking/retry" "github.com/pkg/errors" vishnetlink "github.com/vishvananda/netlink" "go.uber.org/zap" @@ -192,7 +192,7 @@ func (client *TransparentVlanEndpointClient) setLinkNetNSAndConfirm(name string, } // confirm veth was moved successfully - err = RunWithRetries(func() error { + err = retry.Do(func() error { // retry checking in the namespace if the interface is not detected return ExecuteInNS(client.nsClient, client.vnetNSName, func() error { _, ifDetectedErr := client.netioshim.GetNetworkInterfaceByName(client.vlanIfName) @@ -220,7 +220,7 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er // We assume the only possible error is that the namespace doesn't exist logger.Info("No existing NS detected. Creating the vnet namespace and switching to it", zap.String("message", existingErr.Error())) - err = RunWithRetries(func() error { + err = retry.Do(func() error { return client.createNetworkNamespace(vmNS) }, numRetries, sleepInMs) if err != nil { @@ -279,7 +279,7 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er }() // sometimes there is slight delay in interface creation. check if it exists - err = RunWithRetries(func() error { + err = retry.Do(func() error { _, err = client.netioshim.GetNetworkInterfaceByName(client.vlanIfName) return errors.Wrap(err, "failed to get vlan veth") }, numRetries, sleepInMs) @@ -316,7 +316,7 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er } // Ensure vnet veth is created, as there may be a slight delay - err = RunWithRetries(func() error { + err = retry.Do(func() error { _, getErr := client.netioshim.GetNetworkInterfaceByName(client.vnetVethName) return errors.Wrap(getErr, "failed to get vnet veth") }, numRetries, sleepInMs) @@ -326,7 +326,7 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er // Ensure container veth is created, as there may be a slight delay var containerIf *net.Interface - err = RunWithRetries(func() error { + err = retry.Do(func() error { var getErr error containerIf, getErr = client.netioshim.GetNetworkInterfaceByName(client.containerVethName) return errors.Wrap(getErr, "failed to get container veth") @@ -712,16 +712,3 @@ func ExecuteInNS(nsc NamespaceClientInterface, nsName string, f func() error) er }() return f() } - -func RunWithRetries(f func() error, maxRuns, sleepMs int) error { - var err error - for i := 0; i < maxRuns; i++ { - err = f() - if err == nil { - break - } - logger.Info("Retrying after delay...", zap.String("error", err.Error()), zap.Int("retry", i), zap.Int("sleepMs", sleepMs)) - time.Sleep(time.Duration(sleepMs) * time.Millisecond) - } - return err -} diff --git a/network/transparent_vlan_endpointclient_linux_test.go b/network/transparent_vlan_endpointclient_linux_test.go index be64142bc5..eb9300fc15 100644 --- a/network/transparent_vlan_endpointclient_linux_test.go +++ b/network/transparent_vlan_endpointclient_linux_test.go @@ -908,58 +908,3 @@ func TestTransparentVlanConfigureContainerInterfacesAndRoutes(t *testing.T) { }) } } - -func createFunctionWithFailurePattern(errorPattern []error) func() error { - s := 0 - return func() error { - if s >= len(errorPattern) { - return nil - } - result := errorPattern[s] - s++ - return result - } -} - -func TestRunWithRetries(t *testing.T) { - errMock := errors.New("mock error") - runs := 4 - - tests := []struct { - name string - wantErr bool - f func() error - }{ - { - name: "Succeed on first try", - f: createFunctionWithFailurePattern([]error{}), - wantErr: false, - }, - { - name: "Succeed on first try do not check again", - f: createFunctionWithFailurePattern([]error{nil, errMock, errMock, errMock}), - wantErr: false, - }, - { - name: "Succeed on last try", - f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, nil, errMock}), - wantErr: false, - }, - { - name: "Fail after too many attempts", - f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, errMock, nil, nil}), - wantErr: true, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - err := RunWithRetries(tt.f, runs, 100) - if tt.wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } -} diff --git a/retry/retry.go b/retry/retry.go new file mode 100644 index 0000000000..b679bbfd04 --- /dev/null +++ b/retry/retry.go @@ -0,0 +1,15 @@ +package retry + +import "time" + +func Do(f func() error, maxRuns, sleepMs int) error { + var err error + for i := 0; i < maxRuns; i++ { + err = f() + if err == nil { + break + } + time.Sleep(time.Duration(sleepMs) * time.Millisecond) + } + return err +} diff --git a/retry/retry_test.go b/retry/retry_test.go new file mode 100644 index 0000000000..4cfab01e68 --- /dev/null +++ b/retry/retry_test.go @@ -0,0 +1,63 @@ +package retry + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func createFunctionWithFailurePattern(errorPattern []error) func() error { + s := 0 + return func() error { + if s >= len(errorPattern) { + return nil + } + result := errorPattern[s] + s++ + return result + } +} + +func TestRunWithRetries(t *testing.T) { + errMock := errors.New("mock error") + runs := 4 + + tests := []struct { + name string + wantErr bool + f func() error + }{ + { + name: "Succeed on first try", + f: createFunctionWithFailurePattern([]error{}), + wantErr: false, + }, + { + name: "Succeed on first try do not check again", + f: createFunctionWithFailurePattern([]error{nil, errMock, errMock, errMock}), + wantErr: false, + }, + { + name: "Succeed on last try", + f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, nil, errMock}), + wantErr: false, + }, + { + name: "Fail after too many attempts", + f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, errMock, nil, nil}), + wantErr: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + err := Do(tt.f, runs, 100) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} From 8a56efbd50e15cd7c738f36bfbde160fc682dead Mon Sep 17 00:00:00 2001 From: QxBytes Date: Mon, 7 Oct 2024 17:04:25 -0700 Subject: [PATCH 10/14] leverage autoconfig ipv4 address for dhcp (tested) --- cni/network/network.go | 5 ++- dhcp/dhcp.go | 15 ++++--- dhcp/dhcp_windows.go | 94 ++++++++++++++++++++++-------------------- 3 files changed, 63 insertions(+), 51 deletions(-) diff --git a/cni/network/network.go b/cni/network/network.go index 7bc1f35394..47980f2d24 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -132,8 +132,9 @@ func NewPlugin(name string, nl := netlink.NewNetlink() plc := platform.NewExecClient(logger) + netio := &netio.NetIO{} // Setup network manager. - nm, err := network.NewNetworkManager(nl, plc, &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger, plc)) + nm, err := network.NewNetworkManager(nl, plc, netio, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger, netio)) if err != nil { return nil, err } @@ -145,7 +146,7 @@ func NewPlugin(name string, nm: nm, nnsClient: client, multitenancyClient: multitenancyClient, - netClient: &netio.NetIO{}, + netClient: netio, }, nil } diff --git a/dhcp/dhcp.go b/dhcp/dhcp.go index 595ce3c11f..e74dc1c77c 100644 --- a/dhcp/dhcp.go +++ b/dhcp/dhcp.go @@ -51,15 +51,20 @@ type ExecClient interface { ExecuteCommand(ctx context.Context, command string, args ...string) (string, error) } +type NetIOClient interface { + GetNetworkInterfaceByName(name string) (*net.Interface, error) + GetNetworkInterfaceAddrs(iface *net.Interface) ([]net.Addr, error) +} + type DHCP struct { - logger *zap.Logger - execClient ExecClient + logger *zap.Logger + netioClient NetIOClient } -func New(logger *zap.Logger, plc ExecClient) *DHCP { +func New(logger *zap.Logger, netio NetIOClient) *DHCP { return &DHCP{ - logger: logger, - execClient: plc, + logger: logger, + netioClient: netio, } } diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 1f289213fc..4e8b8a1701 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -3,10 +3,7 @@ package dhcp import ( "context" "net" - "regexp" - "time" - "github.com/Azure/azure-container-networking/cni/log" "github.com/Azure/azure-container-networking/retry" "github.com/pkg/errors" "go.uber.org/zap" @@ -14,20 +11,10 @@ import ( ) const ( - dummyIPAddressStr = "169.254.128.10" - dummySubnetMask = "255.255.128.0" - addIPAddressDelay = 4 * time.Second - deleteIPAddressDelay = 2 * time.Second - returnDelay = 8 * time.Second // time to wait before returning from DiscoverRequest - retryCount = 5 - retryDelayMillis = 500 - socketTimeoutMillis = 1000 -) - -var ( - dummyIPAddress = net.IPv4(169, 254, 128, 10) // nolint - // matches if the string fully consists of zero or more alphanumeric, dots, dashes, parentheses, spaces, or underscores - allowedInput = regexp.MustCompile(`^[a-zA-Z0-9._\-\(\) ]*$`) + retryCount = 5 + retryDelayMillis = 500 + ipAssignRetryDelayMillis = 2000 + socketTimeoutMillis = 1000 ) type Socket struct { @@ -91,38 +78,57 @@ func (s *Socket) Close() error { return nil } -// issues a dhcp discover request on an interface by assigning an ip to that interface -// then, sends a packet with that interface's dummy ip, and then unassigns the dummy ip -func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, ifName string) error { - // validate interface name - if !allowedInput.MatchString(ifName) { - return errors.New("invalid dhcp discover request interface name") +func (c *DHCP) getIPv4InterfaceAddresses(ifName string) ([]net.IP, error) { + nic, err := c.netioClient.GetNetworkInterfaceByName(ifName) + if err != nil { + return []net.IP{}, err } - // delete dummy ip off the interface if it already exists - ret, err := c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) + addresses, err := c.netioClient.GetNetworkInterfaceAddrs(nic) if err != nil { - c.logger.Info("Could not remove dummy ip, likely because it doesn't exist", zap.String("output", ret), zap.Error(log.NewErrorWithoutStackTrace(err))) + return []net.IP{}, err + } + ret := []net.IP{} + for _, address := range addresses { + // check if the ip is ipv4 and parse it + ip, _, err := net.ParseCIDR(address.String()) + if err != nil || ip.To4() == nil { + continue + } + ret = append(ret, ip) } - time.Sleep(deleteIPAddressDelay) - // create dummy ip so we can direct the packet to the correct interface - ret, err = c.execClient.ExecuteCommand(ctx, "netsh", "interface", "ipv4", "add", "address", ifName, dummyIPAddressStr, dummySubnetMask) + c.logger.Info("Interface addresses found", zap.Any("foundIPs", addresses), zap.Any("selectedIPs", ret)) + return ret, err +} + +func (c *DHCP) verifyIPv4InterfaceAddressCount(ifName string, count, maxRuns, sleepMs int) error { + addressCountErr := retry.Do(func() error { + addresses, err := c.getIPv4InterfaceAddresses(ifName) + if err != nil || len(addresses) != count { + return errors.New("address count found not equal to expected") + } + return nil + }, maxRuns, sleepMs) + return addressCountErr +} + +// issues a dhcp discover request on an interface by finding the secondary's ip and sending on its ip +func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, ifName string) error { + // Find the ipv4 address of the secondary interface (we're betting that this gets autoconfigured) + err := c.verifyIPv4InterfaceAddressCount(ifName, 1, retryCount, ipAssignRetryDelayMillis) if err != nil { - return errors.Wrap(err, "failed to add dummy ip to interface: "+ret) + return errors.Wrap(err, "failed to get auto ip config assigned in apipa range in time") } - // ensure we always remove the dummy ip we added from the interface - defer func() { - // we always want to try to remove the dummy ip, even if the deadline was reached - // so we have context.Background() - ret, cleanupErr := c.execClient.ExecuteCommand(context.Background(), "netsh", "interface", "ipv4", "delete", "address", ifName, dummyIPAddressStr) - if cleanupErr != nil { - c.logger.Info("Failed to remove dummy ip on leaving function", zap.String("output", ret), zap.Error(err)) - } - // wait for nic to retrieve autoconfiguration ip - time.Sleep(returnDelay) - }() - // it takes time for the address to be assigned - time.Sleep(addIPAddressDelay) + ipv4Addresses, err := c.getIPv4InterfaceAddresses(ifName) + if err != nil || len(ipv4Addresses) == 0 { + return errors.Wrap(err, "failed to get ipv4 addresses on interface") + } + uniqueAddress := ipv4Addresses[0].To4() + if uniqueAddress == nil { + return errors.New("invalid ipv4 address") + } + uniqueAddressStr := uniqueAddress.String() + c.logger.Info("Retrieved automatic ip configuration: ", zap.Any("ip", uniqueAddress), zap.String("ipStr", uniqueAddressStr)) // now begin the dhcp request txid, err := GenerateTransactionID() @@ -132,7 +138,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, // Prepare an IP and UDP header raddr := &net.UDPAddr{IP: net.IPv4bcast, Port: dhcpServerPort} - laddr := &net.UDPAddr{IP: dummyIPAddress, Port: dhcpClientPort} + laddr := &net.UDPAddr{IP: uniqueAddress, Port: dhcpClientPort} dhcpDiscover, err := buildDHCPDiscover(macAddress, txid) if err != nil { From ac4f51c2fc5c3ef476d39b1b416b20b2240ebe04 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Tue, 8 Oct 2024 10:13:44 -0700 Subject: [PATCH 11/14] address feedback --- cni/network/network.go | 6 +++--- dhcp/dhcp_windows.go | 9 +++++++-- network/network_windows.go | 3 +-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/cni/network/network.go b/cni/network/network.go index 47980f2d24..0106af68da 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -132,9 +132,9 @@ func NewPlugin(name string, nl := netlink.NewNetlink() plc := platform.NewExecClient(logger) - netio := &netio.NetIO{} + nio := &netio.NetIO{} // Setup network manager. - nm, err := network.NewNetworkManager(nl, plc, netio, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger, netio)) + nm, err := network.NewNetworkManager(nl, plc, nio, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger, nio)) if err != nil { return nil, err } @@ -146,7 +146,7 @@ func NewPlugin(name string, nm: nm, nnsClient: client, multitenancyClient: multitenancyClient, - netClient: netio, + netClient: nio, }, nil } diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 4e8b8a1701..1847c893bc 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -17,6 +17,11 @@ const ( socketTimeoutMillis = 1000 ) +var ( + errInvalidIPv4Address = errors.New("invalid ipv4 address") + errIncorrectAddressCount = errors.New("address count found not equal to expected") +) + type Socket struct { fd windows.Handle destAddr windows.SockaddrInet4 @@ -105,7 +110,7 @@ func (c *DHCP) verifyIPv4InterfaceAddressCount(ifName string, count, maxRuns, sl addressCountErr := retry.Do(func() error { addresses, err := c.getIPv4InterfaceAddresses(ifName) if err != nil || len(addresses) != count { - return errors.New("address count found not equal to expected") + return errIncorrectAddressCount } return nil }, maxRuns, sleepMs) @@ -125,7 +130,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, } uniqueAddress := ipv4Addresses[0].To4() if uniqueAddress == nil { - return errors.New("invalid ipv4 address") + return errInvalidIPv4Address } uniqueAddressStr := uniqueAddress.String() c.logger.Info("Retrieved automatic ip configuration: ", zap.Any("ip", uniqueAddress), zap.String("ipStr", uniqueAddressStr)) diff --git a/network/network_windows.go b/network/network_windows.go index d62f0afc2f..b87a00bc2b 100644 --- a/network/network_windows.go +++ b/network/network_windows.go @@ -438,8 +438,7 @@ func (nm *networkManager) newNetworkImplHnsV2(nwInfo *EndpointInfo, extIf *exter func (nm *networkManager) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { // issue dhcp discover packet to ensure mapping created for dns via wireserver to work // we do not use the response for anything - numSecs := 15 // we need to wait for the address to be assigned - timeout := time.Duration(numSecs) * time.Second + timeout := 15 * time.Second ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) defer cancel() logger.Info("Sending DHCP packet", zap.Any("macAddress", mac), zap.String("ifName", ifName)) From 0e0414ee7c589dec381c4abc8ea4d0fc2bc64712 Mon Sep 17 00:00:00 2001 From: QxBytes Date: Tue, 8 Oct 2024 12:33:14 -0700 Subject: [PATCH 12/14] promote nmagent retry to package, include final retry error, and adapt existing cni code include last error cause when cooldown function errors (discuss) for compatability use durations instead of millis correct usage of number retries vs number of runs create helper function for retrying in transparent vlan mode create wrapper method that makes all passed in errors temporary (retriable) errors --- dhcp/dhcp_windows.go | 27 ++- .../transparent_vlan_endpointclient_linux.go | 37 ++-- ...nsparent_vlan_endpointclient_linux_test.go | 16 +- nmagent/client.go | 5 +- nmagent/client_helpers_test.go | 5 +- nmagent/internal/retry.go | 125 ------------ nmagent/internal/retry_test.go | 164 --------------- retry/retry.go | 164 ++++++++++++++- .../internal => retry}/retry_example_test.go | 2 +- retry/retry_test.go | 187 +++++++++++++++++- 10 files changed, 395 insertions(+), 337 deletions(-) delete mode 100644 nmagent/internal/retry.go delete mode 100644 nmagent/internal/retry_test.go rename {nmagent/internal => retry}/retry_example_test.go (98%) diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 1847c893bc..da092ce3bc 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -3,6 +3,7 @@ package dhcp import ( "context" "net" + "time" "github.com/Azure/azure-container-networking/retry" "github.com/pkg/errors" @@ -11,10 +12,10 @@ import ( ) const ( - retryCount = 5 - retryDelayMillis = 500 - ipAssignRetryDelayMillis = 2000 - socketTimeoutMillis = 1000 + retryCount = 4 + retryDelay = 500 * time.Millisecond + ipAssignRetryDelay = 2000 * time.Millisecond + socketTimeoutMillis = 1000 ) var ( @@ -106,21 +107,24 @@ func (c *DHCP) getIPv4InterfaceAddresses(ifName string) ([]net.IP, error) { return ret, err } -func (c *DHCP) verifyIPv4InterfaceAddressCount(ifName string, count, maxRuns, sleepMs int) error { - addressCountErr := retry.Do(func() error { +func (c *DHCP) verifyIPv4InterfaceAddressCount(ifName string, count, maxRuns int, sleep time.Duration) error { + retrier := retry.Retrier{ + Cooldown: retry.Max(maxRuns, retry.Fixed(sleep)), + } + addressCountErr := retrier.Do(context.Background(), func() error { addresses, err := c.getIPv4InterfaceAddresses(ifName) if err != nil || len(addresses) != count { return errIncorrectAddressCount } return nil - }, maxRuns, sleepMs) + }) return addressCountErr } // issues a dhcp discover request on an interface by finding the secondary's ip and sending on its ip func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, ifName string) error { // Find the ipv4 address of the secondary interface (we're betting that this gets autoconfigured) - err := c.verifyIPv4InterfaceAddressCount(ifName, 1, retryCount, ipAssignRetryDelayMillis) + err := c.verifyIPv4InterfaceAddressCount(ifName, 1, retryCount, ipAssignRetryDelay) if err != nil { return errors.Wrap(err, "failed to get auto ip config assigned in apipa range in time") } @@ -173,11 +177,14 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, return errors.Wrap(err, "failed to create socket") } + retrier := retry.Retrier{ + Cooldown: retry.Max(retryCount, retry.Fixed(retryDelay)), + } // retry sending the packet until it succeeds - err = retry.Do(func() error { + err = retrier.Do(context.Background(), func() error { _, sockErr := sock.Write(bytesToSend) return sockErr - }, retryCount, retryDelayMillis) + }) if err != nil { return errors.Wrap(err, "failed to write to dhcp socket") } diff --git a/network/transparent_vlan_endpointclient_linux.go b/network/transparent_vlan_endpointclient_linux.go index 193529cef5..e9151ec5c0 100644 --- a/network/transparent_vlan_endpointclient_linux.go +++ b/network/transparent_vlan_endpointclient_linux.go @@ -1,9 +1,11 @@ package network import ( + "context" "fmt" "net" "strings" + "time" "github.com/Azure/azure-container-networking/iptables" "github.com/Azure/azure-container-networking/netio" @@ -26,8 +28,8 @@ const ( tunnelingTable = 2 // Packets not entering on the vlan interface go to this routing table tunnelingMark = 333 // The packets that are to tunnel will be marked with this number DisableRPFilterCmd = "sysctl -w net.ipv4.conf.all.rp_filter=0" // Command to disable the rp filter for tunneling - numRetries = 5 - sleepInMs = 100 + numRetries = 4 + sleepDelay = 100 * time.Millisecond ) var errNamespaceCreation = fmt.Errorf("network namespace creation error") @@ -192,13 +194,13 @@ func (client *TransparentVlanEndpointClient) setLinkNetNSAndConfirm(name string, } // confirm veth was moved successfully - err = retry.Do(func() error { + err = client.Retry(func() error { // retry checking in the namespace if the interface is not detected return ExecuteInNS(client.nsClient, client.vnetNSName, func() error { _, ifDetectedErr := client.netioshim.GetNetworkInterfaceByName(client.vlanIfName) return errors.Wrap(ifDetectedErr, "failed to get vlan veth in namespace") }) - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrapf(err, "failed to detect %v inside namespace %v", name, fd) } @@ -220,9 +222,9 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er // We assume the only possible error is that the namespace doesn't exist logger.Info("No existing NS detected. Creating the vnet namespace and switching to it", zap.String("message", existingErr.Error())) - err = retry.Do(func() error { + err = client.Retry(func() error { return client.createNetworkNamespace(vmNS) - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrap(err, "failed to create network namespace") } @@ -279,10 +281,10 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er }() // sometimes there is slight delay in interface creation. check if it exists - err = retry.Do(func() error { + err = client.Retry(func() error { _, err = client.netioshim.GetNetworkInterfaceByName(client.vlanIfName) return errors.Wrap(err, "failed to get vlan veth") - }, numRetries, sleepInMs) + }) if err != nil { deleteNSIfNotNilErr = errors.Wrapf(err, "failed to get vlan veth interface:%s", client.vlanIfName) @@ -316,21 +318,21 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er } // Ensure vnet veth is created, as there may be a slight delay - err = retry.Do(func() error { + err = client.Retry(func() error { _, getErr := client.netioshim.GetNetworkInterfaceByName(client.vnetVethName) return errors.Wrap(getErr, "failed to get vnet veth") - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrap(err, "vnet veth does not exist") } // Ensure container veth is created, as there may be a slight delay var containerIf *net.Interface - err = retry.Do(func() error { + err = client.Retry(func() error { var getErr error containerIf, getErr = client.netioshim.GetNetworkInterfaceByName(client.containerVethName) return errors.Wrap(getErr, "failed to get container veth") - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrap(err, "container veth does not exist") } @@ -673,6 +675,17 @@ func (client *TransparentVlanEndpointClient) DeleteEndpointsImpl(ep *endpoint, _ return nil } +// Creates a new retrier with a fixed delay, and treats all errors as retriable +func (client *TransparentVlanEndpointClient) Retry(f func() error) error { + retrier := retry.Retrier{ + Cooldown: retry.Max(numRetries, retry.Fixed(sleepDelay)), + } + return retrier.Do(context.Background(), func() error { + // we always want to retry, so all errors are temporary errors + return retry.WrapTemporaryError(f()) + }) +} + // Helper function that allows executing a function in a VM namespace // Does not work for process namespaces func ExecuteInNS(nsc NamespaceClientInterface, nsName string, f func() error) error { diff --git a/network/transparent_vlan_endpointclient_linux_test.go b/network/transparent_vlan_endpointclient_linux_test.go index eb9300fc15..fcf0d81a21 100644 --- a/network/transparent_vlan_endpointclient_linux_test.go +++ b/network/transparent_vlan_endpointclient_linux_test.go @@ -190,7 +190,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.setLinkNetNSAndConfirm(tt.client.vlanIfName, 1) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -287,7 +287,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.ensureCleanPopulateVM() if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -429,7 +429,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { }, epInfo: &EndpointInfo{}, wantErr: true, - wantErrMsg: "container veth does not exist: failed to get container veth: B1veth0: " + errMockNetIOFail.Error() + "", + wantErrMsg: "failed to get container veth: B1veth0: " + errMockNetIOFail.Error() + "", }, { name: "Add endpoints NetNS Get fail", @@ -489,7 +489,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.PopulateVM(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -579,7 +579,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.PopulateVnet(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -695,7 +695,7 @@ func TestTransparentVlanDeleteEndpoints(t *testing.T) { err := tt.client.DeleteEndpointsImpl(tt.ep, tt.routesLeft) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -830,7 +830,7 @@ func TestTransparentVlanConfigureContainerInterfacesAndRoutes(t *testing.T) { err := tt.client.ConfigureContainerInterfacesAndRoutesImpl(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -901,7 +901,7 @@ func TestTransparentVlanConfigureContainerInterfacesAndRoutes(t *testing.T) { err := tt.client.ConfigureVnetInterfacesAndRoutesImpl(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } diff --git a/nmagent/client.go b/nmagent/client.go index 71a0810978..805e4597bb 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -12,6 +12,7 @@ import ( "time" "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/Azure/azure-container-networking/retry" "github.com/pkg/errors" ) @@ -30,9 +31,9 @@ func NewClient(c Config) (*Client, error) { host: c.Host, port: c.Port, enableTLS: c.UseTLS, - retrier: internal.Retrier{ + retrier: retry.Retrier{ // nolint:gomnd // the base parameter is explained in the function - Cooldown: internal.Exponential(1*time.Second, 2), + Cooldown: retry.Exponential(1*time.Second, 2), }, } diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go index 4f98959742..c659ea6f75 100644 --- a/nmagent/client_helpers_test.go +++ b/nmagent/client_helpers_test.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/Azure/azure-container-networking/retry" ) // NewTestClient is a factory function available in tests only for creating @@ -17,8 +18,8 @@ func NewTestClient(transport http.RoundTripper) *Client { }, host: "localhost", port: 12345, - retrier: internal.Retrier{ - Cooldown: internal.AsFastAsPossible(), + retrier: retry.Retrier{ + Cooldown: retry.AsFastAsPossible(), }, } } diff --git a/nmagent/internal/retry.go b/nmagent/internal/retry.go deleted file mode 100644 index 9491aea0e4..0000000000 --- a/nmagent/internal/retry.go +++ /dev/null @@ -1,125 +0,0 @@ -package internal - -import ( - "context" - "errors" - "math" - "time" - - pkgerrors "github.com/pkg/errors" -) - -const ( - noDelay = 0 * time.Nanosecond -) - -const ( - ErrMaxAttempts = Error("maximum attempts reached") -) - -// TemporaryError is an error that can indicate whether it may be resolved with -// another attempt. -type TemporaryError interface { - error - Temporary() bool -} - -// Retrier is a construct for attempting some operation multiple times with a -// configurable backoff strategy. -type Retrier struct { - Cooldown CooldownFactory -} - -// Do repeatedly invokes the provided run function while the context remains -// active. It waits in between invocations of the provided functions by -// delegating to the provided Cooldown function. -func (r Retrier) Do(ctx context.Context, run func() error) error { - cooldown := r.Cooldown() - - for { - if err := ctx.Err(); err != nil { - // nolint:wrapcheck // no meaningful information can be added to this error - return err - } - - err := run() - if err != nil { - // check to see if it's temporary. - var tempErr TemporaryError - if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { - delay, err := cooldown() // nolint:govet // the shadow is intentional - if err != nil { - return pkgerrors.Wrap(err, "sleeping during retry") - } - time.Sleep(delay) - continue - } - - // since it's not temporary, it can't be retried, so... - return err - } - return nil - } -} - -// CooldownFunc is a function that will block when called. It is intended for -// use with retry logic. -type CooldownFunc func() (time.Duration, error) - -// CooldownFactory is a function that returns CooldownFuncs. It helps -// CooldownFuncs dispose of any accumulated state so that they function -// correctly upon successive uses. -type CooldownFactory func() CooldownFunc - -// Max provides a fixed limit for the number of times a subordinate cooldown -// function can be invoked. -func Max(limit int, factory CooldownFactory) CooldownFactory { - return func() CooldownFunc { - cooldown := factory() - count := 0 - return func() (time.Duration, error) { - if count >= limit { - return noDelay, ErrMaxAttempts - } - - delay, err := cooldown() - if err != nil { - return noDelay, err - } - count++ - return delay, nil - } - } -} - -// AsFastAsPossible is a Cooldown strategy that does not block, allowing retry -// logic to proceed as fast as possible. This is particularly useful in tests. -func AsFastAsPossible() CooldownFactory { - return func() CooldownFunc { - return func() (time.Duration, error) { - return noDelay, nil - } - } -} - -// Exponential provides an exponential increase the the base interval provided. -func Exponential(interval time.Duration, base int) CooldownFactory { - return func() CooldownFunc { - count := 0 - return func() (time.Duration, error) { - increment := math.Pow(float64(base), float64(count)) - delay := interval.Nanoseconds() * int64(increment) - count++ - return time.Duration(delay), nil - } - } -} - -// Fixed produced the same delay value upon each invocation. -func Fixed(delay time.Duration) CooldownFactory { - return func() CooldownFunc { - return func() (time.Duration, error) { - return delay, nil - } - } -} diff --git a/nmagent/internal/retry_test.go b/nmagent/internal/retry_test.go deleted file mode 100644 index 55824de38b..0000000000 --- a/nmagent/internal/retry_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package internal - -import ( - "context" - "errors" - "testing" - "time" -) - -type TestError struct{} - -func (t TestError) Error() string { - return "oh no!" -} - -func (t TestError) Temporary() bool { - return true -} - -func TestBackoffRetry(t *testing.T) { - got := 0 - exp := 10 - - ctx := context.Background() - - rt := Retrier{ - Cooldown: AsFastAsPossible(), - } - - err := rt.Do(ctx, func() error { - if got < exp { - got++ - return TestError{} - } - return nil - }) - if err != nil { - t.Fatal("unexpected error: err:", err) - } - - if got < exp { - t.Error("unexpected number of invocations: got:", got, "exp:", exp) - } -} - -func TestBackoffRetryWithCancel(t *testing.T) { - got := 0 - exp := 5 - total := 10 - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - rt := Retrier{ - Cooldown: AsFastAsPossible(), - } - - err := rt.Do(ctx, func() error { - got++ - if got >= exp { - cancel() - } - - if got < total { - return TestError{} - } - return nil - }) - - if err == nil { - t.Error("expected context cancellation error, but received none") - } - - if got != exp { - t.Error("unexpected number of iterations: exp:", exp, "got:", got) - } -} - -func TestBackoffRetryUnretriableError(t *testing.T) { - rt := Retrier{ - Cooldown: AsFastAsPossible(), - } - - err := rt.Do(context.Background(), func() error { - return errors.New("boom") // nolint:goerr113 // it's just a test - }) - - if err == nil { - t.Fatal("expected an error, but none was returned") - } -} - -func TestFixed(t *testing.T) { - exp := 20 * time.Millisecond - - cooldown := Fixed(exp)() - - got, err := cooldown() - if err != nil { - t.Fatal("unexpected error invoking cooldown: err:", err) - } - - if got != exp { - t.Fatal("unexpected sleep duration: exp:", exp, "got:", got) - } -} - -func TestExp(t *testing.T) { - exp := 10 * time.Millisecond - base := 2 - - cooldown := Exponential(exp, base)() - - first, err := cooldown() - if err != nil { - t.Fatal("unexpected error invoking cooldown: err:", err) - } - - if first != exp { - t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", first) - } - - // ensure that the sleep increases - second, err := cooldown() - if err != nil { - t.Fatal("unexpected error on second invocation of cooldown: err:", err) - } - - if second < first { - t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", second) - } -} - -func TestMax(t *testing.T) { - exp := 10 - got := 0 - - // create a test sleep function - fn := func() CooldownFunc { - return func() (time.Duration, error) { - got++ - return 0 * time.Nanosecond, nil - } - } - - cooldown := Max(10, fn)() - - for i := 0; i < exp; i++ { - _, err := cooldown() - if err != nil { - t.Fatal("unexpected error from cooldown: err:", err) - } - } - - if exp != got { - t.Error("unexpected number of cooldown invocations: exp:", exp, "got:", got) - } - - // attempt one more, we expect an error - _, err := cooldown() - if err == nil { - t.Errorf("expected an error after %d invocations but received none", exp+1) - } -} diff --git a/retry/retry.go b/retry/retry.go index b679bbfd04..e37f86073e 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -1,15 +1,161 @@ package retry -import "time" +import ( + "context" + "errors" + "math" + "time" -func Do(f func() error, maxRuns, sleepMs int) error { - var err error - for i := 0; i < maxRuns; i++ { - err = f() - if err == nil { - break + pkgerrors "github.com/pkg/errors" +) + +const ( + noDelay = 0 * time.Nanosecond +) + +const ( + ErrMaxAttempts = Error("maximum attempts reached") +) + +// Error represents an internal sentinal error which can be defined as a +// constant. +type Error string + +func (e Error) Error() string { + return string(e) +} + +// RetriableError is an implementation of TemporaryError that is always retriable +type RetriableError struct { + err error +} + +func (r RetriableError) Error() string { + if r.err == nil { + return "" + } + return r.err.Error() +} +func (r RetriableError) Unwrap() error { + return r.err +} +func (r RetriableError) Temporary() bool { + return true +} + +// Forces the error to be retriable, returns nil if error is nil +func WrapTemporaryError(err error) error { + if err == nil { + return nil + } + return RetriableError{err: err} +} + +// TemporaryError is an error that can indicate whether it may be resolved with +// another attempt. +type TemporaryError interface { + error + Temporary() bool +} + +// Retrier is a construct for attempting some operation multiple times with a +// configurable backoff strategy. +type Retrier struct { + Cooldown CooldownFactory +} + +// Do repeatedly invokes the provided run function while the context remains +// active. It waits in between invocations of the provided functions by +// delegating to the provided Cooldown function. +func (r Retrier) Do(ctx context.Context, run func() error) error { + cooldown := r.Cooldown() + + for { + if err := ctx.Err(); err != nil { + // nolint:wrapcheck // no meaningful information can be added to this error + return err + } + + err := run() + if err != nil { + // check to see if it's temporary. + var tempErr TemporaryError + if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { + delay, cooldownErr := cooldown() + if cooldownErr != nil { + return pkgerrors.Wrap(cooldownErr, "sleeping during retry, last error:"+err.Error()) + } + time.Sleep(delay) + continue + } + + // since it's not temporary, it can't be retried, so... + return err + } + return nil + } +} + +// CooldownFunc is a function that will block when called. It is intended for +// use with retry logic. +type CooldownFunc func() (time.Duration, error) + +// CooldownFactory is a function that returns CooldownFuncs. It helps +// CooldownFuncs dispose of any accumulated state so that they function +// correctly upon successive uses. +type CooldownFactory func() CooldownFunc + +// Max provides a fixed limit for the number of times a subordinate cooldown +// function can be invoked. Note if we set the limit to 0, we still +// invoke the target method once in the retrier, but do not retry +// Read this as the max number of *retries* +func Max(limit int, factory CooldownFactory) CooldownFactory { + return func() CooldownFunc { + cooldown := factory() + count := 0 + return func() (time.Duration, error) { + if count >= limit { + return noDelay, ErrMaxAttempts + } + + delay, err := cooldown() + if err != nil { + return noDelay, err + } + count++ + return delay, nil + } + } +} + +// AsFastAsPossible is a Cooldown strategy that does not block, allowing retry +// logic to proceed as fast as possible. This is particularly useful in tests. +func AsFastAsPossible() CooldownFactory { + return func() CooldownFunc { + return func() (time.Duration, error) { + return noDelay, nil + } + } +} + +// Exponential provides an exponential increase the the base interval provided. +func Exponential(interval time.Duration, base int) CooldownFactory { + return func() CooldownFunc { + count := 0 + return func() (time.Duration, error) { + increment := math.Pow(float64(base), float64(count)) + delay := interval.Nanoseconds() * int64(increment) + count++ + return time.Duration(delay), nil + } + } +} + +// Fixed produced the same delay value upon each invocation. +func Fixed(delay time.Duration) CooldownFactory { + return func() CooldownFunc { + return func() (time.Duration, error) { + return delay, nil } - time.Sleep(time.Duration(sleepMs) * time.Millisecond) } - return err } diff --git a/nmagent/internal/retry_example_test.go b/retry/retry_example_test.go similarity index 98% rename from nmagent/internal/retry_example_test.go rename to retry/retry_example_test.go index c66bc194fb..919e35dcf8 100644 --- a/nmagent/internal/retry_example_test.go +++ b/retry/retry_example_test.go @@ -1,4 +1,4 @@ -package internal +package retry import ( "fmt" diff --git a/retry/retry_test.go b/retry/retry_test.go index 4cfab01e68..68f7dc97cc 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -1,12 +1,186 @@ package retry import ( + "context" + "errors" "testing" + "time" - "github.com/pkg/errors" + pkgerrors "github.com/pkg/errors" "github.com/stretchr/testify/require" ) +type TestError struct{} + +func (t TestError) Error() string { + return "oh no!" +} + +func (t TestError) Temporary() bool { + return true +} + +func TestBackoffRetry(t *testing.T) { + got := 0 + exp := 10 + + ctx := context.Background() + + rt := Retrier{ + Cooldown: AsFastAsPossible(), + } + + err := rt.Do(ctx, func() error { + if got < exp { + got++ + return TestError{} + } + return nil + }) + if err != nil { + t.Fatal("unexpected error: err:", err) + } + + if got < exp { + t.Error("unexpected number of invocations: got:", got, "exp:", exp) + } +} + +func TestBackoffRetryWithCancel(t *testing.T) { + got := 0 + exp := 5 + total := 10 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rt := Retrier{ + Cooldown: AsFastAsPossible(), + } + + err := rt.Do(ctx, func() error { + got++ + if got >= exp { + cancel() + } + + if got < total { + return TestError{} + } + return nil + }) + + if err == nil { + t.Error("expected context cancellation error, but received none") + } + + if got != exp { + t.Error("unexpected number of iterations: exp:", exp, "got:", got) + } +} + +func TestBackoffRetryUnretriableError(t *testing.T) { + rt := Retrier{ + Cooldown: AsFastAsPossible(), + } + + err := rt.Do(context.Background(), func() error { + return errors.New("boom") // nolint:goerr113 // it's just a test + }) + + if err == nil { + t.Fatal("expected an error, but none was returned") + } +} + +func TestFixed(t *testing.T) { + exp := 20 * time.Millisecond + + cooldown := Fixed(exp)() + + got, err := cooldown() + if err != nil { + t.Fatal("unexpected error invoking cooldown: err:", err) + } + + if got != exp { + t.Fatal("unexpected sleep duration: exp:", exp, "got:", got) + } +} + +func TestExp(t *testing.T) { + exp := 10 * time.Millisecond + base := 2 + + cooldown := Exponential(exp, base)() + + first, err := cooldown() + if err != nil { + t.Fatal("unexpected error invoking cooldown: err:", err) + } + + if first != exp { + t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", first) + } + + // ensure that the sleep increases + second, err := cooldown() + if err != nil { + t.Fatal("unexpected error on second invocation of cooldown: err:", err) + } + + if second < first { + t.Fatal("unexpected sleep during first cooldown: exp:", exp, "got:", second) + } +} + +func TestMax(t *testing.T) { + exp := 10 + got := 0 + + // create a test sleep function + fn := func() CooldownFunc { + return func() (time.Duration, error) { + got++ + return 0 * time.Nanosecond, nil + } + } + + cooldown := Max(10, fn)() + + for i := 0; i < exp; i++ { + _, err := cooldown() + if err != nil { + t.Fatal("unexpected error from cooldown: err:", err) + } + } + + if exp != got { + t.Error("unexpected number of cooldown invocations: exp:", exp, "got:", got) + } + + // attempt one more, we expect an error + _, err := cooldown() + if err == nil { + t.Errorf("expected an error after %d invocations but received none", exp+1) + } +} + +func TestRetriableError(t *testing.T) { + // wrapping nil returns a nil + require.Nil(t, WrapTemporaryError(nil)) + + mockError := errors.New("mock error") + wrappedMockError := WrapTemporaryError(pkgerrors.Wrap(mockError, "nested")) + + // temporary errors should still be able to be unwrapped + require.ErrorIs(t, wrappedMockError, mockError) + + var temporaryError TemporaryError + require.ErrorAs(t, wrappedMockError, &temporaryError) + require.True(t, temporaryError.Temporary(), "errors returned from wrap temporary error should have temporary set to true") +} + func createFunctionWithFailurePattern(errorPattern []error) func() error { s := 0 return func() error { @@ -20,8 +194,11 @@ func createFunctionWithFailurePattern(errorPattern []error) func() error { } func TestRunWithRetries(t *testing.T) { - errMock := errors.New("mock error") - runs := 4 + errMock := WrapTemporaryError(errors.New("mock error")) + retries := 3 // runs 4 times, then errors before the 5th + retrier := Retrier{ + Cooldown: Max(retries, Fixed(100*time.Millisecond)), + } tests := []struct { name string @@ -49,10 +226,12 @@ func TestRunWithRetries(t *testing.T) { wantErr: true, }, } + for _, tt := range tests { tt := tt + t.Run(tt.name, func(t *testing.T) { - err := Do(tt.f, runs, 100) + err := retrier.Do(context.Background(), tt.f) if tt.wantErr { require.Error(t, err) } else { From 524d0020b1517933bba80f655dbf57af84787fef Mon Sep 17 00:00:00 2001 From: QxBytes Date: Tue, 8 Oct 2024 12:47:42 -0700 Subject: [PATCH 13/14] address linter issues --- dhcp/dhcp_windows.go | 22 +++++++++---------- network/network_windows.go | 4 ++-- .../transparent_vlan_endpointclient_linux.go | 6 ++--- retry/retry.go | 2 ++ retry/retry_test.go | 11 +++++----- 5 files changed, 24 insertions(+), 21 deletions(-) diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index da092ce3bc..7c80ba6eba 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -87,44 +87,44 @@ func (s *Socket) Close() error { func (c *DHCP) getIPv4InterfaceAddresses(ifName string) ([]net.IP, error) { nic, err := c.netioClient.GetNetworkInterfaceByName(ifName) if err != nil { - return []net.IP{}, err + return []net.IP{}, errors.Wrap(err, "failed to get interface by name to find ipv4 addresses") } addresses, err := c.netioClient.GetNetworkInterfaceAddrs(nic) if err != nil { - return []net.IP{}, err + return []net.IP{}, errors.Wrap(err, "failed to get interface addresses") } ret := []net.IP{} for _, address := range addresses { // check if the ip is ipv4 and parse it - ip, _, err := net.ParseCIDR(address.String()) - if err != nil || ip.To4() == nil { + ip, _, cidrErr := net.ParseCIDR(address.String()) + if cidrErr != nil || ip.To4() == nil { continue } ret = append(ret, ip) } c.logger.Info("Interface addresses found", zap.Any("foundIPs", addresses), zap.Any("selectedIPs", ret)) - return ret, err + return ret, nil } -func (c *DHCP) verifyIPv4InterfaceAddressCount(ifName string, count, maxRuns int, sleep time.Duration) error { +func (c *DHCP) verifyIPv4InterfaceAddressCount(ctx context.Context, ifName string, count, numRetries int, sleep time.Duration) error { retrier := retry.Retrier{ - Cooldown: retry.Max(maxRuns, retry.Fixed(sleep)), + Cooldown: retry.Max(numRetries, retry.Fixed(sleep)), } - addressCountErr := retrier.Do(context.Background(), func() error { + addressCountErr := retrier.Do(ctx, func() error { addresses, err := c.getIPv4InterfaceAddresses(ifName) if err != nil || len(addresses) != count { return errIncorrectAddressCount } return nil }) - return addressCountErr + return errors.Wrap(addressCountErr, "failed to verify interface ipv4 address count") } // issues a dhcp discover request on an interface by finding the secondary's ip and sending on its ip func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, ifName string) error { // Find the ipv4 address of the secondary interface (we're betting that this gets autoconfigured) - err := c.verifyIPv4InterfaceAddressCount(ifName, 1, retryCount, ipAssignRetryDelay) + err := c.verifyIPv4InterfaceAddressCount(ctx, ifName, 1, retryCount, ipAssignRetryDelay) if err != nil { return errors.Wrap(err, "failed to get auto ip config assigned in apipa range in time") } @@ -181,7 +181,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, Cooldown: retry.Max(retryCount, retry.Fixed(retryDelay)), } // retry sending the packet until it succeeds - err = retrier.Do(context.Background(), func() error { + err = retrier.Do(ctx, func() error { _, sockErr := sock.Write(bytesToSend) return sockErr }) diff --git a/network/network_windows.go b/network/network_windows.go index b87a00bc2b..f3614fc1c3 100644 --- a/network/network_windows.go +++ b/network/network_windows.go @@ -46,6 +46,7 @@ const ( defaultIPv6Route = "::/0" // Default IPv6 nextHop defaultIPv6NextHop = "fe80::1234:5678:9abc" + dhcpTimeout = 15 * time.Second ) // Windows implementation of route. @@ -438,8 +439,7 @@ func (nm *networkManager) newNetworkImplHnsV2(nwInfo *EndpointInfo, extIf *exter func (nm *networkManager) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { // issue dhcp discover packet to ensure mapping created for dns via wireserver to work // we do not use the response for anything - timeout := 15 * time.Second - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout)) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(dhcpTimeout)) defer cancel() logger.Info("Sending DHCP packet", zap.Any("macAddress", mac), zap.String("ifName", ifName)) err := client.DiscoverRequest(ctx, mac, ifName) diff --git a/network/transparent_vlan_endpointclient_linux.go b/network/transparent_vlan_endpointclient_linux.go index e9151ec5c0..05e7c6d458 100644 --- a/network/transparent_vlan_endpointclient_linux.go +++ b/network/transparent_vlan_endpointclient_linux.go @@ -680,10 +680,10 @@ func (client *TransparentVlanEndpointClient) Retry(f func() error) error { retrier := retry.Retrier{ Cooldown: retry.Max(numRetries, retry.Fixed(sleepDelay)), } - return retrier.Do(context.Background(), func() error { + return errors.Wrap(retrier.Do(context.Background(), func() error { // we always want to retry, so all errors are temporary errors - return retry.WrapTemporaryError(f()) - }) + return retry.WrapTemporaryError(f()) // nolint + }), "error during retry") } // Helper function that allows executing a function in a VM namespace diff --git a/retry/retry.go b/retry/retry.go index e37f86073e..741d4ac521 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -36,9 +36,11 @@ func (r RetriableError) Error() string { } return r.err.Error() } + func (r RetriableError) Unwrap() error { return r.err } + func (r RetriableError) Temporary() bool { return true } diff --git a/retry/retry_test.go b/retry/retry_test.go index 68f7dc97cc..b5c92126a4 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var errTest = errors.New("mock error") + type TestError struct{} func (t TestError) Error() string { @@ -168,13 +170,12 @@ func TestMax(t *testing.T) { func TestRetriableError(t *testing.T) { // wrapping nil returns a nil - require.Nil(t, WrapTemporaryError(nil)) + require.NoError(t, WrapTemporaryError(nil)) - mockError := errors.New("mock error") - wrappedMockError := WrapTemporaryError(pkgerrors.Wrap(mockError, "nested")) + wrappedMockError := WrapTemporaryError(pkgerrors.Wrap(errTest, "nested")) // temporary errors should still be able to be unwrapped - require.ErrorIs(t, wrappedMockError, mockError) + require.ErrorIs(t, wrappedMockError, errTest) var temporaryError TemporaryError require.ErrorAs(t, wrappedMockError, &temporaryError) @@ -194,7 +195,7 @@ func createFunctionWithFailurePattern(errorPattern []error) func() error { } func TestRunWithRetries(t *testing.T) { - errMock := WrapTemporaryError(errors.New("mock error")) + errMock := WrapTemporaryError(errTest) retries := 3 // runs 4 times, then errors before the 5th retrier := Retrier{ Cooldown: Max(retries, Fixed(100*time.Millisecond)), From bf974f60892e48f3d1d04c28219f280e4d91807a Mon Sep 17 00:00:00 2001 From: QxBytes Date: Wed, 9 Oct 2024 14:26:49 -0700 Subject: [PATCH 14/14] wrap in temporary error to retry --- dhcp/dhcp_windows.go | 4 ++-- retry/retry.go | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 7c80ba6eba..a1b66068c1 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -114,7 +114,7 @@ func (c *DHCP) verifyIPv4InterfaceAddressCount(ctx context.Context, ifName strin addressCountErr := retrier.Do(ctx, func() error { addresses, err := c.getIPv4InterfaceAddresses(ifName) if err != nil || len(addresses) != count { - return errIncorrectAddressCount + return retry.WrapTemporaryError(errIncorrectAddressCount) // nolint } return nil }) @@ -183,7 +183,7 @@ func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, // retry sending the packet until it succeeds err = retrier.Do(ctx, func() error { _, sockErr := sock.Write(bytesToSend) - return sockErr + return retry.WrapTemporaryError(sockErr) // nolint }) if err != nil { return errors.Wrap(err, "failed to write to dhcp socket") diff --git a/retry/retry.go b/retry/retry.go index 741d4ac521..7104aa91d8 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -61,7 +61,8 @@ type TemporaryError interface { } // Retrier is a construct for attempting some operation multiple times with a -// configurable backoff strategy. +// configurable backoff strategy. To retry, a returned error must implement the +// TemporaryError interface and return true type Retrier struct { Cooldown CooldownFactory }