diff --git a/cni/network/network.go b/cni/network/network.go index 6b0635e1c7..0106af68da 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -131,8 +131,10 @@ func NewPlugin(name string, } nl := netlink.NewNetlink() + plc := platform.NewExecClient(logger) + nio := &netio.NetIO{} // Setup network manager. - nm, err := network.NewNetworkManager(nl, platform.NewExecClient(logger), &netio.NetIO{}, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger)) + nm, err := network.NewNetworkManager(nl, plc, nio, network.NewNamespaceClient(), iptables.NewClient(), dhcp.New(logger, nio)) if err != nil { return nil, err } @@ -144,7 +146,7 @@ func NewPlugin(name string, nm: nm, nnsClient: client, multitenancyClient: multitenancyClient, - netClient: &netio.NetIO{}, + netClient: nio, }, nil } @@ -1526,7 +1528,7 @@ func (plugin *NetPlugin) validateArgs(args *cniSkel.CmdArgs, nwCfg *cni.NetworkC if !allowedInput.MatchString(args.ContainerID) || !allowedInput.MatchString(args.IfName) { return errors.New("invalid args value") } - if !allowedInput.MatchString(nwCfg.Bridge) { + if !allowedInput.MatchString(nwCfg.Bridge) || !allowedInput.MatchString(nwCfg.Master) { return errors.New("invalid network config value") } diff --git a/dhcp/dhcp.go b/dhcp/dhcp.go new file mode 100644 index 0000000000..e74dc1c77c --- /dev/null +++ b/dhcp/dhcp.go @@ -0,0 +1,265 @@ +package dhcp + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "io" + "net" + "time" + + "github.com/pkg/errors" + "go.uber.org/zap" + "golang.org/x/net/ipv4" +) + +const ( + dhcpDiscover = 1 + bootRequest = 1 + ethPAll = 0x0003 + MaxUDPReceivedPacketSize = 8192 + dhcpServerPort = 67 + dhcpClientPort = 68 + dhcpOpCodeReply = 2 + bootpMinLen = 300 + bytesInAddress = 4 // bytes in an ip address + macBytes = 6 // bytes in a mac address + udpProtocol = 17 + + opRequest = 1 + htypeEthernet = 1 + hlenEthernet = 6 + hops = 0 + secs = 0 + flags = 0x8000 // Broadcast flag +) + +// TransactionID represents a 4-byte DHCP transaction ID as defined in RFC 951, +// Section 3. +// +// The TransactionID is used to match DHCP replies to their original request. +type TransactionID [4]byte + +var ( + magicCookie = []byte{0x63, 0x82, 0x53, 0x63} // DHCP magic cookie + DefaultReadTimeout = 3 * time.Second + DefaultTimeout = 3 * time.Second +) + +type ExecClient interface { + ExecuteCommand(ctx context.Context, command string, args ...string) (string, error) +} + +type NetIOClient interface { + GetNetworkInterfaceByName(name string) (*net.Interface, error) + GetNetworkInterfaceAddrs(iface *net.Interface) ([]net.Addr, error) +} + +type DHCP struct { + logger *zap.Logger + netioClient NetIOClient +} + +func New(logger *zap.Logger, netio NetIOClient) *DHCP { + return &DHCP{ + logger: logger, + netioClient: netio, + } +} + +// GenerateTransactionID generates a random 32-bits number suitable for use as TransactionID +func GenerateTransactionID() (TransactionID, error) { + var xid TransactionID + _, err := rand.Read(xid[:]) + if err != nil { + return xid, errors.Errorf("could not get random number: %v", err) + } + return xid, nil +} + +// Build DHCP Discover Packet +func buildDHCPDiscover(mac net.HardwareAddr, txid TransactionID) ([]byte, error) { + if len(mac) != macBytes { + return nil, errors.Errorf("invalid MAC address length") + } + + var packet bytes.Buffer + + // BOOTP header + packet.WriteByte(opRequest) // op: BOOTREQUEST (1) + packet.WriteByte(htypeEthernet) // htype: Ethernet (1) + packet.WriteByte(hlenEthernet) // hlen: MAC address length (6) + packet.WriteByte(hops) // hops: 0 + packet.Write(txid[:]) // xid: Transaction ID (4 bytes) + err := binary.Write(&packet, binary.BigEndian, uint16(secs)) // secs: Seconds elapsed + if err != nil { + return nil, errors.Wrap(err, "failed to write seconds elapsed") + } + err = binary.Write(&packet, binary.BigEndian, uint16(flags)) // flags: Broadcast flag + if err != nil { + return nil, errors.Wrap(err, "failed to write broadcast flag") + } + + // Client IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Your IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Server IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + // Gateway IP address (0.0.0.0) + packet.Write(make([]byte, bytesInAddress)) + + // chaddr: Client hardware address (MAC address) + paddingBytes := 10 + packet.Write(mac) // MAC address (6 bytes) + packet.Write(make([]byte, paddingBytes)) // Padding to 16 bytes + + // sname: Server host name (64 bytes) + serverHostNameBytes := 64 + packet.Write(make([]byte, serverHostNameBytes)) + // file: Boot file name (128 bytes) + bootFileNameBytes := 128 + packet.Write(make([]byte, bootFileNameBytes)) + + // Magic cookie (DHCP) + err = binary.Write(&packet, binary.BigEndian, magicCookie) + if err != nil { + return nil, errors.Wrap(err, "failed to write magic cookie") + } + + // DHCP options (minimal required options for DISCOVER) + packet.Write([]byte{ + 53, 1, 1, // Option 53: DHCP Message Type (1 = DHCP Discover) + 55, 3, 1, 3, 6, // Option 55: Parameter Request List (1 = Subnet Mask, 3 = Router, 6 = DNS) + 255, // End option + }) + + // padding length to 300 bytes + var value uint8 // default is zero + if packet.Len() < bootpMinLen { + packet.Write(bytes.Repeat([]byte{value}, bootpMinLen-packet.Len())) + } + + return packet.Bytes(), nil +} + +// MakeRawUDPPacket converts a payload (a serialized packet) into a +// raw UDP packet for the specified serverAddr from the specified clientAddr. +func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { + udpBytes := 8 + udp := make([]byte, udpBytes) + binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) + binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) + totalLen := uint16(udpBytes + len(payload)) + binary.BigEndian.PutUint16(udp[4:6], totalLen) + binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum + + headerVersion := 4 + headerLen := 20 + headerTTL := 64 + + h := ipv4.Header{ + Version: headerVersion, // nolint + Len: headerLen, // nolint + TotalLen: headerLen + len(udp) + len(payload), + TTL: headerTTL, + Protocol: udpProtocol, // UDP + Dst: serverAddr.IP, + Src: clientAddr.IP, + } + ret, err := h.Marshal() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal when making udp packet") + } + ret = append(ret, udp...) + ret = append(ret, payload...) + return ret, nil +} + +// Receive DHCP response packet using reader +func (c *DHCP) receiveDHCPResponse(ctx context.Context, reader io.ReadCloser, xid TransactionID) error { + recvErrors := make(chan error, 1) + // Recvfrom is a blocking call, so if something goes wrong with its timeout it won't return. + + // Additionally, the timeout on the socket (on the Read(...)) call is how long until the socket times out and gives an error, + // but it won't error if we do get some sort of data within the time out period. + + // If we get some data (even if it is not the packet we are looking for, like wrong txid, wrong response opcode etc.) + // then we continue in the for loop. We then call recvfrom again which will reset the timeout period + // Without the secondary timeout at the bottom of the function, we could stay stuck in the for loop as long as we receive packets. + go func(errs chan<- error) { + // loop will only exit if there is an error, context canceled, or we find our reply packet + for { + if ctx.Err() != nil { + errs <- ctx.Err() + return + } + + buf := make([]byte, MaxUDPReceivedPacketSize) + // Blocks until data received or timeout period is reached + n, innerErr := reader.Read(buf) + if innerErr != nil { + errs <- innerErr + return + } + // check header + var iph ipv4.Header + if err := iph.Parse(buf[:n]); err != nil { + // skip non-IP data + continue + } + if iph.Protocol != udpProtocol { + // skip non-UDP packets + continue + } + udph := buf[iph.Len:n] + // source is from dhcp server if receiving + srcPort := int(binary.BigEndian.Uint16(udph[0:2])) + if srcPort != dhcpServerPort { + continue + } + // client is to dhcp client if receiving + dstPort := int(binary.BigEndian.Uint16(udph[2:4])) + if dstPort != dhcpClientPort { + continue + } + // check payload + pLen := int(binary.BigEndian.Uint16(udph[4:6])) + payload := buf[iph.Len+8 : iph.Len+pLen] + + // retrieve opcode from payload + opcode := payload[0] // opcode is first byte + // retrieve txid from payload + txidOffset := 4 // after 4 bytes, the txid starts + // the txid is 4 bytes, so we take four bytes after the offset + txid := payload[txidOffset : txidOffset+4] + + c.logger.Info("Received packet", zap.Int("opCode", int(opcode)), zap.Any("transactionID", TransactionID(txid))) + if opcode != dhcpOpCodeReply { + continue // opcode is not a reply, so continue + } + c.logger.Info("Received DHCP reply packet", zap.Int("opCode", int(opcode)), zap.Any("transactionID", TransactionID(txid))) + if TransactionID(txid) == xid { + break + } + } + // only occurs if we find our reply packet successfully + // a nil error means a reply was found for this txid + recvErrors <- nil + }(recvErrors) + + // sends a message on repeat after timeout, but only the first one matters + ticker := time.NewTicker(DefaultReadTimeout) + defer ticker.Stop() + + select { + case err := <-recvErrors: + if err != nil { + return errors.Wrap(err, "error during receiving") + } + case <-ticker.C: + return errors.New("timed out waiting for replies") + } + return nil +} diff --git a/dhcp/dhcp_linux.go b/dhcp/dhcp_linux.go index 9e7a029c05..5e71cc4264 100644 --- a/dhcp/dhcp_linux.go +++ b/dhcp/dhcp_linux.go @@ -4,9 +4,7 @@ package dhcp import ( - "bytes" "context" - "crypto/rand" "encoding/binary" "io" "net" @@ -14,53 +12,9 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" - "golang.org/x/net/ipv4" "golang.org/x/sys/unix" ) -const ( - dhcpDiscover = 1 - bootRequest = 1 - ethPAll = 0x0003 - MaxUDPReceivedPacketSize = 8192 - dhcpServerPort = 67 - dhcpClientPort = 68 - dhcpOpCodeReply = 2 - bootpMinLen = 300 - bytesInAddress = 4 // bytes in an ip address - macBytes = 6 // bytes in a mac address - udpProtocol = 17 - - opRequest = 1 - htypeEthernet = 1 - hlenEthernet = 6 - hops = 0 - secs = 0 - flags = 0x8000 // Broadcast flag -) - -// TransactionID represents a 4-byte DHCP transaction ID as defined in RFC 951, -// Section 3. -// -// The TransactionID is used to match DHCP replies to their original request. -type TransactionID [4]byte - -var ( - magicCookie = []byte{0x63, 0x82, 0x53, 0x63} // DHCP magic cookie - DefaultReadTimeout = 3 * time.Second - DefaultTimeout = 3 * time.Second -) - -type DHCP struct { - logger *zap.Logger -} - -func New(logger *zap.Logger) *DHCP { - return &DHCP{ - logger: logger, - } -} - type Socket struct { fd int remoteAddr unix.SockaddrInet4 @@ -122,16 +76,6 @@ func (s *Socket) Close() error { return nil } -// GenerateTransactionID generates a random 32-bits number suitable for use as TransactionID -func GenerateTransactionID() (TransactionID, error) { - var xid TransactionID - _, err := rand.Read(xid[:]) - if err != nil { - return xid, errors.Errorf("could not get random number: %v", err) - } - return xid, nil -} - func makeListeningSocket(ifname string, timeout time.Duration) (int, error) { // reference: https://manned.org/packet.7 // starts listening to the specified protocol, or none if zero @@ -204,192 +148,6 @@ func makeRawSocket(ifname string) (int, error) { return fd, nil } -// Build DHCP Discover Packet -func buildDHCPDiscover(mac net.HardwareAddr, txid TransactionID) ([]byte, error) { - if len(mac) != macBytes { - return nil, errors.Errorf("invalid MAC address length") - } - - var packet bytes.Buffer - - // BOOTP header - packet.WriteByte(opRequest) // op: BOOTREQUEST (1) - packet.WriteByte(htypeEthernet) // htype: Ethernet (1) - packet.WriteByte(hlenEthernet) // hlen: MAC address length (6) - packet.WriteByte(hops) // hops: 0 - packet.Write(txid[:]) // xid: Transaction ID (4 bytes) - err := binary.Write(&packet, binary.BigEndian, uint16(secs)) // secs: Seconds elapsed - if err != nil { - return nil, errors.Wrap(err, "failed to write seconds elapsed") - } - err = binary.Write(&packet, binary.BigEndian, uint16(flags)) // flags: Broadcast flag - if err != nil { - return nil, errors.Wrap(err, "failed to write broadcast flag") - } - - // Client IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - // Your IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - // Server IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - // Gateway IP address (0.0.0.0) - packet.Write(make([]byte, bytesInAddress)) - - // chaddr: Client hardware address (MAC address) - paddingBytes := 10 - packet.Write(mac) // MAC address (6 bytes) - packet.Write(make([]byte, paddingBytes)) // Padding to 16 bytes - - // sname: Server host name (64 bytes) - serverHostNameBytes := 64 - packet.Write(make([]byte, serverHostNameBytes)) - // file: Boot file name (128 bytes) - bootFileNameBytes := 128 - packet.Write(make([]byte, bootFileNameBytes)) - - // Magic cookie (DHCP) - err = binary.Write(&packet, binary.BigEndian, magicCookie) - if err != nil { - return nil, errors.Wrap(err, "failed to write magic cookie") - } - - // DHCP options (minimal required options for DISCOVER) - packet.Write([]byte{ - 53, 1, 1, // Option 53: DHCP Message Type (1 = DHCP Discover) - 55, 3, 1, 3, 6, // Option 55: Parameter Request List (1 = Subnet Mask, 3 = Router, 6 = DNS) - 255, // End option - }) - - // padding length to 300 bytes - var value uint8 // default is zero - if packet.Len() < bootpMinLen { - packet.Write(bytes.Repeat([]byte{value}, bootpMinLen-packet.Len())) - } - - return packet.Bytes(), nil -} - -// MakeRawUDPPacket converts a payload (a serialized packet) into a -// raw UDP packet for the specified serverAddr from the specified clientAddr. -func MakeRawUDPPacket(payload []byte, serverAddr, clientAddr net.UDPAddr) ([]byte, error) { - udpBytes := 8 - udp := make([]byte, udpBytes) - binary.BigEndian.PutUint16(udp[:2], uint16(clientAddr.Port)) - binary.BigEndian.PutUint16(udp[2:4], uint16(serverAddr.Port)) - totalLen := uint16(udpBytes + len(payload)) - binary.BigEndian.PutUint16(udp[4:6], totalLen) - binary.BigEndian.PutUint16(udp[6:8], 0) // try to offload the checksum - - headerVersion := 4 - headerLen := 20 - headerTTL := 64 - - h := ipv4.Header{ - Version: headerVersion, // nolint - Len: headerLen, // nolint - TotalLen: headerLen + len(udp) + len(payload), - TTL: headerTTL, - Protocol: udpProtocol, // UDP - Dst: serverAddr.IP, - Src: clientAddr.IP, - } - ret, err := h.Marshal() - if err != nil { - return nil, errors.Wrap(err, "failed to marshal when making udp packet") - } - ret = append(ret, udp...) - ret = append(ret, payload...) - return ret, nil -} - -// Receive DHCP response packet using reader -func (c *DHCP) receiveDHCPResponse(ctx context.Context, reader io.ReadCloser, xid TransactionID) error { - recvErrors := make(chan error, 1) - // Recvfrom is a blocking call, so if something goes wrong with its timeout it won't return. - - // Additionally, the timeout on the socket (on the Read(...)) call is how long until the socket times out and gives an error, - // but it won't error if we do get some sort of data within the time out period. - - // If we get some data (even if it is not the packet we are looking for, like wrong txid, wrong response opcode etc.) - // then we continue in the for loop. We then call recvfrom again which will reset the timeout period - // Without the secondary timeout at the bottom of the function, we could stay stuck in the for loop as long as we receive packets. - go func(errs chan<- error) { - // loop will only exit if there is an error, context canceled, or we find our reply packet - for { - if ctx.Err() != nil { - errs <- ctx.Err() - return - } - - buf := make([]byte, MaxUDPReceivedPacketSize) - // Blocks until data received or timeout period is reached - n, innerErr := reader.Read(buf) - if innerErr != nil { - errs <- innerErr - return - } - // check header - var iph ipv4.Header - if err := iph.Parse(buf[:n]); err != nil { - // skip non-IP data - continue - } - if iph.Protocol != udpProtocol { - // skip non-UDP packets - continue - } - udph := buf[iph.Len:n] - // source is from dhcp server if receiving - srcPort := int(binary.BigEndian.Uint16(udph[0:2])) - if srcPort != dhcpServerPort { - continue - } - // client is to dhcp client if receiving - dstPort := int(binary.BigEndian.Uint16(udph[2:4])) - if dstPort != dhcpClientPort { - continue - } - // check payload - pLen := int(binary.BigEndian.Uint16(udph[4:6])) - payload := buf[iph.Len+8 : iph.Len+pLen] - - // retrieve opcode from payload - opcode := payload[0] // opcode is first byte - // retrieve txid from payload - txidOffset := 4 // after 4 bytes, the txid starts - // the txid is 4 bytes, so we take four bytes after the offset - txid := payload[txidOffset : txidOffset+4] - - c.logger.Info("Received packet", zap.Int("opCode", int(opcode)), zap.Any("transactionID", TransactionID(txid))) - if opcode != dhcpOpCodeReply { - continue // opcode is not a reply, so continue - } - - if TransactionID(txid) == xid { - break - } - } - // only occurs if we find our reply packet successfully - // a nil error means a reply was found for this txid - recvErrors <- nil - }(recvErrors) - - // sends a message on repeat after timeout, but only the first one matters - ticker := time.NewTicker(DefaultReadTimeout) - defer ticker.Stop() - - select { - case err := <-recvErrors: - if err != nil { - return errors.Wrap(err, "error during receiving") - } - case <-ticker.C: - return errors.New("timed out waiting for replies") - } - return nil -} - // Issues a DHCP Discover packet from the nic specified by mac and name ifname // Returns nil if a reply to the transaction was received, or error if time out // Does not return the DHCP Offer that was received from the DHCP server diff --git a/dhcp/dhcp_windows.go b/dhcp/dhcp_windows.go index 7b23dbeeff..a1b66068c1 100644 --- a/dhcp/dhcp_windows.go +++ b/dhcp/dhcp_windows.go @@ -3,20 +3,199 @@ package dhcp import ( "context" "net" + "time" + "github.com/Azure/azure-container-networking/retry" + "github.com/pkg/errors" "go.uber.org/zap" + "golang.org/x/sys/windows" ) -type DHCP struct { - logger *zap.Logger +const ( + retryCount = 4 + retryDelay = 500 * time.Millisecond + ipAssignRetryDelay = 2000 * time.Millisecond + socketTimeoutMillis = 1000 +) + +var ( + errInvalidIPv4Address = errors.New("invalid ipv4 address") + errIncorrectAddressCount = errors.New("address count found not equal to expected") +) + +type Socket struct { + fd windows.Handle + destAddr windows.SockaddrInet4 } -func New(logger *zap.Logger) *DHCP { - return &DHCP{ - logger: logger, +func NewSocket(destAddr windows.SockaddrInet4) (*Socket, error) { + // Create a raw socket using windows.WSASocket + fd, err := windows.WSASocket(windows.AF_INET, windows.SOCK_RAW, windows.IPPROTO_UDP, nil, 0, windows.WSA_FLAG_OVERLAPPED) + ret := &Socket{ + fd: fd, + destAddr: destAddr, + } + if err != nil { + return ret, errors.Wrap(err, "error creating socket") + } + + // Set IP_HDRINCL to indicate that we are including our own IP header + err = windows.SetsockoptInt(fd, windows.IPPROTO_IP, windows.IP_HDRINCL, 1) + if err != nil { + return ret, errors.Wrap(err, "error setting IP_HDRINCL") } + // Set the SO_BROADCAST option or else we get an error saying that we access a socket in a way forbidden by its access perms + err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_BROADCAST, 1) + if err != nil { + return ret, errors.Wrap(err, "error setting SO_BROADCAST") + } + // Set timeout + if err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVTIMEO, socketTimeoutMillis); err != nil { + return ret, errors.Wrap(err, "error setting receive timeout") + } + return ret, nil +} + +func (s *Socket) Write(packetBytes []byte) (int, error) { + err := windows.Sendto(s.fd, packetBytes, 0, &s.destAddr) + if err != nil { + return 0, errors.Wrap(err, "failed windows send to") + } + return len(packetBytes), nil } -func (c *DHCP) DiscoverRequest(_ context.Context, _ net.HardwareAddr, _ string) error { +func (s *Socket) Read(p []byte) (n int, err error) { + n, _, innerErr := windows.Recvfrom(s.fd, p, 0) + if innerErr != nil { + return 0, errors.Wrap(err, "failed windows recv from") + } + return n, nil +} + +func (s *Socket) Close() error { + // do not attempt to close invalid fd (happens on socket creation failure) + if s.fd == windows.InvalidHandle { + return nil + } + // Ensure the file descriptor is closed when done + if err := windows.Closesocket(s.fd); err != nil { + return errors.Wrap(err, "error closing dhcp windows socket") + } + return nil +} + +func (c *DHCP) getIPv4InterfaceAddresses(ifName string) ([]net.IP, error) { + nic, err := c.netioClient.GetNetworkInterfaceByName(ifName) + if err != nil { + return []net.IP{}, errors.Wrap(err, "failed to get interface by name to find ipv4 addresses") + } + addresses, err := c.netioClient.GetNetworkInterfaceAddrs(nic) + if err != nil { + return []net.IP{}, errors.Wrap(err, "failed to get interface addresses") + } + ret := []net.IP{} + for _, address := range addresses { + // check if the ip is ipv4 and parse it + ip, _, cidrErr := net.ParseCIDR(address.String()) + if cidrErr != nil || ip.To4() == nil { + continue + } + ret = append(ret, ip) + } + + c.logger.Info("Interface addresses found", zap.Any("foundIPs", addresses), zap.Any("selectedIPs", ret)) + return ret, nil +} + +func (c *DHCP) verifyIPv4InterfaceAddressCount(ctx context.Context, ifName string, count, numRetries int, sleep time.Duration) error { + retrier := retry.Retrier{ + Cooldown: retry.Max(numRetries, retry.Fixed(sleep)), + } + addressCountErr := retrier.Do(ctx, func() error { + addresses, err := c.getIPv4InterfaceAddresses(ifName) + if err != nil || len(addresses) != count { + return retry.WrapTemporaryError(errIncorrectAddressCount) // nolint + } + return nil + }) + return errors.Wrap(addressCountErr, "failed to verify interface ipv4 address count") +} + +// issues a dhcp discover request on an interface by finding the secondary's ip and sending on its ip +func (c *DHCP) DiscoverRequest(ctx context.Context, macAddress net.HardwareAddr, ifName string) error { + // Find the ipv4 address of the secondary interface (we're betting that this gets autoconfigured) + err := c.verifyIPv4InterfaceAddressCount(ctx, ifName, 1, retryCount, ipAssignRetryDelay) + if err != nil { + return errors.Wrap(err, "failed to get auto ip config assigned in apipa range in time") + } + ipv4Addresses, err := c.getIPv4InterfaceAddresses(ifName) + if err != nil || len(ipv4Addresses) == 0 { + return errors.Wrap(err, "failed to get ipv4 addresses on interface") + } + uniqueAddress := ipv4Addresses[0].To4() + if uniqueAddress == nil { + return errInvalidIPv4Address + } + uniqueAddressStr := uniqueAddress.String() + c.logger.Info("Retrieved automatic ip configuration: ", zap.Any("ip", uniqueAddress), zap.String("ipStr", uniqueAddressStr)) + + // now begin the dhcp request + txid, err := GenerateTransactionID() + if err != nil { + return errors.Wrap(err, "failed to generate transaction id") + } + + // Prepare an IP and UDP header + raddr := &net.UDPAddr{IP: net.IPv4bcast, Port: dhcpServerPort} + laddr := &net.UDPAddr{IP: uniqueAddress, Port: dhcpClientPort} + + dhcpDiscover, err := buildDHCPDiscover(macAddress, txid) + if err != nil { + return errors.Wrap(err, "failed to build dhcp discover") + } + + // Fill out the headers, add payload, and construct the full packet + bytesToSend, err := MakeRawUDPPacket(dhcpDiscover, *raddr, *laddr) + if err != nil { + return errors.Wrap(err, "failed to make raw udp packet") + } + + destAddr := windows.SockaddrInet4{ + Addr: [4]byte{255, 255, 255, 255}, // Destination IP + Port: dhcpServerPort, // Destination Port + } + // create new socket for writing and reading + sock, err := NewSocket(destAddr) + defer func() { + // always clean up the socket, even if we fail while setting options + closeErr := sock.Close() + if closeErr != nil { + c.logger.Error("Error closing dhcp socket:", zap.Error(closeErr)) + } + }() + if err != nil { + return errors.Wrap(err, "failed to create socket") + } + + retrier := retry.Retrier{ + Cooldown: retry.Max(retryCount, retry.Fixed(retryDelay)), + } + // retry sending the packet until it succeeds + err = retrier.Do(ctx, func() error { + _, sockErr := sock.Write(bytesToSend) + return retry.WrapTemporaryError(sockErr) // nolint + }) + if err != nil { + return errors.Wrap(err, "failed to write to dhcp socket") + } + + c.logger.Info("DHCP Discover packet was sent successfully", zap.Any("transactionID", txid)) + + // Wait for DHCP response (Offer) + err = c.receiveDHCPResponse(ctx, sock, txid) + if err != nil { + return errors.Wrap(err, "failed to read from dhcp socket") + } + return nil } diff --git a/network/network_windows.go b/network/network_windows.go index 4d870827ca..f3614fc1c3 100644 --- a/network/network_windows.go +++ b/network/network_windows.go @@ -4,8 +4,10 @@ package network import ( + "context" "encoding/json" "fmt" + "net" "strconv" "strings" "time" @@ -44,6 +46,7 @@ const ( defaultIPv6Route = "::/0" // Default IPv6 nextHop defaultIPv6NextHop = "fe80::1234:5678:9abc" + dhcpTimeout = 15 * time.Second ) // Windows implementation of route. @@ -433,8 +436,29 @@ func (nm *networkManager) newNetworkImplHnsV2(nwInfo *EndpointInfo, extIf *exter return nw, nil } +func (nm *networkManager) sendDHCPDiscoverOnSecondary(client dhcpClient, mac net.HardwareAddr, ifName string) error { + // issue dhcp discover packet to ensure mapping created for dns via wireserver to work + // we do not use the response for anything + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(dhcpTimeout)) + defer cancel() + logger.Info("Sending DHCP packet", zap.Any("macAddress", mac), zap.String("ifName", ifName)) + err := client.DiscoverRequest(ctx, mac, ifName) + if err != nil { + return errors.Wrapf(err, "failed to issue dhcp discover packet to create mapping in host") + } + logger.Info("Successfully received DHCP reply packet") + return nil +} + // NewNetworkImpl creates a new container network. func (nm *networkManager) newNetworkImpl(nwInfo *EndpointInfo, extIf *externalInterface) (*network, error) { + if nwInfo.NICType == cns.NodeNetworkInterfaceFrontendNIC { + // use master interface name, interface name, or adapter name? + if err := nm.sendDHCPDiscoverOnSecondary(nm.dhcpClient, nwInfo.MacAddress, nwInfo.MasterIfName); err != nil { + return nil, err + } + } + if useHnsV2, err := UseHnsV2(nwInfo.NetNs); useHnsV2 { if err != nil { return nil, err diff --git a/network/transparent_vlan_endpointclient_linux.go b/network/transparent_vlan_endpointclient_linux.go index 731353c231..05e7c6d458 100644 --- a/network/transparent_vlan_endpointclient_linux.go +++ b/network/transparent_vlan_endpointclient_linux.go @@ -1,6 +1,7 @@ package network import ( + "context" "fmt" "net" "strings" @@ -13,6 +14,7 @@ import ( "github.com/Azure/azure-container-networking/network/networkutils" "github.com/Azure/azure-container-networking/network/snat" "github.com/Azure/azure-container-networking/platform" + "github.com/Azure/azure-container-networking/retry" "github.com/pkg/errors" vishnetlink "github.com/vishvananda/netlink" "go.uber.org/zap" @@ -26,8 +28,8 @@ const ( tunnelingTable = 2 // Packets not entering on the vlan interface go to this routing table tunnelingMark = 333 // The packets that are to tunnel will be marked with this number DisableRPFilterCmd = "sysctl -w net.ipv4.conf.all.rp_filter=0" // Command to disable the rp filter for tunneling - numRetries = 5 - sleepInMs = 100 + numRetries = 4 + sleepDelay = 100 * time.Millisecond ) var errNamespaceCreation = fmt.Errorf("network namespace creation error") @@ -192,13 +194,13 @@ func (client *TransparentVlanEndpointClient) setLinkNetNSAndConfirm(name string, } // confirm veth was moved successfully - err = RunWithRetries(func() error { + err = client.Retry(func() error { // retry checking in the namespace if the interface is not detected return ExecuteInNS(client.nsClient, client.vnetNSName, func() error { _, ifDetectedErr := client.netioshim.GetNetworkInterfaceByName(client.vlanIfName) return errors.Wrap(ifDetectedErr, "failed to get vlan veth in namespace") }) - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrapf(err, "failed to detect %v inside namespace %v", name, fd) } @@ -220,9 +222,9 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er // We assume the only possible error is that the namespace doesn't exist logger.Info("No existing NS detected. Creating the vnet namespace and switching to it", zap.String("message", existingErr.Error())) - err = RunWithRetries(func() error { + err = client.Retry(func() error { return client.createNetworkNamespace(vmNS) - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrap(err, "failed to create network namespace") } @@ -279,10 +281,10 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er }() // sometimes there is slight delay in interface creation. check if it exists - err = RunWithRetries(func() error { + err = client.Retry(func() error { _, err = client.netioshim.GetNetworkInterfaceByName(client.vlanIfName) return errors.Wrap(err, "failed to get vlan veth") - }, numRetries, sleepInMs) + }) if err != nil { deleteNSIfNotNilErr = errors.Wrapf(err, "failed to get vlan veth interface:%s", client.vlanIfName) @@ -316,21 +318,21 @@ func (client *TransparentVlanEndpointClient) PopulateVM(epInfo *EndpointInfo) er } // Ensure vnet veth is created, as there may be a slight delay - err = RunWithRetries(func() error { + err = client.Retry(func() error { _, getErr := client.netioshim.GetNetworkInterfaceByName(client.vnetVethName) return errors.Wrap(getErr, "failed to get vnet veth") - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrap(err, "vnet veth does not exist") } // Ensure container veth is created, as there may be a slight delay var containerIf *net.Interface - err = RunWithRetries(func() error { + err = client.Retry(func() error { var getErr error containerIf, getErr = client.netioshim.GetNetworkInterfaceByName(client.containerVethName) return errors.Wrap(getErr, "failed to get container veth") - }, numRetries, sleepInMs) + }) if err != nil { return errors.Wrap(err, "container veth does not exist") } @@ -673,6 +675,17 @@ func (client *TransparentVlanEndpointClient) DeleteEndpointsImpl(ep *endpoint, _ return nil } +// Creates a new retrier with a fixed delay, and treats all errors as retriable +func (client *TransparentVlanEndpointClient) Retry(f func() error) error { + retrier := retry.Retrier{ + Cooldown: retry.Max(numRetries, retry.Fixed(sleepDelay)), + } + return errors.Wrap(retrier.Do(context.Background(), func() error { + // we always want to retry, so all errors are temporary errors + return retry.WrapTemporaryError(f()) // nolint + }), "error during retry") +} + // Helper function that allows executing a function in a VM namespace // Does not work for process namespaces func ExecuteInNS(nsc NamespaceClientInterface, nsName string, f func() error) error { @@ -712,16 +725,3 @@ func ExecuteInNS(nsc NamespaceClientInterface, nsName string, f func() error) er }() return f() } - -func RunWithRetries(f func() error, maxRuns, sleepMs int) error { - var err error - for i := 0; i < maxRuns; i++ { - err = f() - if err == nil { - break - } - logger.Info("Retrying after delay...", zap.String("error", err.Error()), zap.Int("retry", i), zap.Int("sleepMs", sleepMs)) - time.Sleep(time.Duration(sleepMs) * time.Millisecond) - } - return err -} diff --git a/network/transparent_vlan_endpointclient_linux_test.go b/network/transparent_vlan_endpointclient_linux_test.go index be64142bc5..fcf0d81a21 100644 --- a/network/transparent_vlan_endpointclient_linux_test.go +++ b/network/transparent_vlan_endpointclient_linux_test.go @@ -190,7 +190,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.setLinkNetNSAndConfirm(tt.client.vlanIfName, 1) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -287,7 +287,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.ensureCleanPopulateVM() if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -429,7 +429,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { }, epInfo: &EndpointInfo{}, wantErr: true, - wantErrMsg: "container veth does not exist: failed to get container veth: B1veth0: " + errMockNetIOFail.Error() + "", + wantErrMsg: "failed to get container veth: B1veth0: " + errMockNetIOFail.Error() + "", }, { name: "Add endpoints NetNS Get fail", @@ -489,7 +489,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.PopulateVM(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -579,7 +579,7 @@ func TestTransparentVlanAddEndpoints(t *testing.T) { err := tt.client.PopulateVnet(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -695,7 +695,7 @@ func TestTransparentVlanDeleteEndpoints(t *testing.T) { err := tt.client.DeleteEndpointsImpl(tt.ep, tt.routesLeft) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -830,7 +830,7 @@ func TestTransparentVlanConfigureContainerInterfacesAndRoutes(t *testing.T) { err := tt.client.ConfigureContainerInterfacesAndRoutesImpl(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } @@ -901,62 +901,7 @@ func TestTransparentVlanConfigureContainerInterfacesAndRoutes(t *testing.T) { err := tt.client.ConfigureVnetInterfacesAndRoutesImpl(tt.epInfo) if tt.wantErr { require.Error(t, err) - require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v actual:%v", tt.wantErrMsg, err.Error()) - } else { - require.NoError(t, err) - } - }) - } -} - -func createFunctionWithFailurePattern(errorPattern []error) func() error { - s := 0 - return func() error { - if s >= len(errorPattern) { - return nil - } - result := errorPattern[s] - s++ - return result - } -} - -func TestRunWithRetries(t *testing.T) { - errMock := errors.New("mock error") - runs := 4 - - tests := []struct { - name string - wantErr bool - f func() error - }{ - { - name: "Succeed on first try", - f: createFunctionWithFailurePattern([]error{}), - wantErr: false, - }, - { - name: "Succeed on first try do not check again", - f: createFunctionWithFailurePattern([]error{nil, errMock, errMock, errMock}), - wantErr: false, - }, - { - name: "Succeed on last try", - f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, nil, errMock}), - wantErr: false, - }, - { - name: "Fail after too many attempts", - f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, errMock, nil, nil}), - wantErr: true, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - err := RunWithRetries(tt.f, runs, 100) - if tt.wantErr { - require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErrMsg, "Expected:%v \nActual:%v", tt.wantErrMsg, err.Error()) } else { require.NoError(t, err) } diff --git a/nmagent/client.go b/nmagent/client.go index 71a0810978..805e4597bb 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -12,6 +12,7 @@ import ( "time" "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/Azure/azure-container-networking/retry" "github.com/pkg/errors" ) @@ -30,9 +31,9 @@ func NewClient(c Config) (*Client, error) { host: c.Host, port: c.Port, enableTLS: c.UseTLS, - retrier: internal.Retrier{ + retrier: retry.Retrier{ // nolint:gomnd // the base parameter is explained in the function - Cooldown: internal.Exponential(1*time.Second, 2), + Cooldown: retry.Exponential(1*time.Second, 2), }, } diff --git a/nmagent/client_helpers_test.go b/nmagent/client_helpers_test.go index 4f98959742..c659ea6f75 100644 --- a/nmagent/client_helpers_test.go +++ b/nmagent/client_helpers_test.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/Azure/azure-container-networking/nmagent/internal" + "github.com/Azure/azure-container-networking/retry" ) // NewTestClient is a factory function available in tests only for creating @@ -17,8 +18,8 @@ func NewTestClient(transport http.RoundTripper) *Client { }, host: "localhost", port: 12345, - retrier: internal.Retrier{ - Cooldown: internal.AsFastAsPossible(), + retrier: retry.Retrier{ + Cooldown: retry.AsFastAsPossible(), }, } } diff --git a/nmagent/internal/retry.go b/retry/retry.go similarity index 71% rename from nmagent/internal/retry.go rename to retry/retry.go index 9491aea0e4..7104aa91d8 100644 --- a/nmagent/internal/retry.go +++ b/retry/retry.go @@ -1,4 +1,4 @@ -package internal +package retry import ( "context" @@ -17,6 +17,42 @@ const ( ErrMaxAttempts = Error("maximum attempts reached") ) +// Error represents an internal sentinal error which can be defined as a +// constant. +type Error string + +func (e Error) Error() string { + return string(e) +} + +// RetriableError is an implementation of TemporaryError that is always retriable +type RetriableError struct { + err error +} + +func (r RetriableError) Error() string { + if r.err == nil { + return "" + } + return r.err.Error() +} + +func (r RetriableError) Unwrap() error { + return r.err +} + +func (r RetriableError) Temporary() bool { + return true +} + +// Forces the error to be retriable, returns nil if error is nil +func WrapTemporaryError(err error) error { + if err == nil { + return nil + } + return RetriableError{err: err} +} + // TemporaryError is an error that can indicate whether it may be resolved with // another attempt. type TemporaryError interface { @@ -25,7 +61,8 @@ type TemporaryError interface { } // Retrier is a construct for attempting some operation multiple times with a -// configurable backoff strategy. +// configurable backoff strategy. To retry, a returned error must implement the +// TemporaryError interface and return true type Retrier struct { Cooldown CooldownFactory } @@ -47,9 +84,9 @@ func (r Retrier) Do(ctx context.Context, run func() error) error { // check to see if it's temporary. var tempErr TemporaryError if ok := errors.As(err, &tempErr); ok && tempErr.Temporary() { - delay, err := cooldown() // nolint:govet // the shadow is intentional - if err != nil { - return pkgerrors.Wrap(err, "sleeping during retry") + delay, cooldownErr := cooldown() + if cooldownErr != nil { + return pkgerrors.Wrap(cooldownErr, "sleeping during retry, last error:"+err.Error()) } time.Sleep(delay) continue @@ -72,7 +109,9 @@ type CooldownFunc func() (time.Duration, error) type CooldownFactory func() CooldownFunc // Max provides a fixed limit for the number of times a subordinate cooldown -// function can be invoked. +// function can be invoked. Note if we set the limit to 0, we still +// invoke the target method once in the retrier, but do not retry +// Read this as the max number of *retries* func Max(limit int, factory CooldownFactory) CooldownFactory { return func() CooldownFunc { cooldown := factory() diff --git a/nmagent/internal/retry_example_test.go b/retry/retry_example_test.go similarity index 98% rename from nmagent/internal/retry_example_test.go rename to retry/retry_example_test.go index c66bc194fb..919e35dcf8 100644 --- a/nmagent/internal/retry_example_test.go +++ b/retry/retry_example_test.go @@ -1,4 +1,4 @@ -package internal +package retry import ( "fmt" diff --git a/nmagent/internal/retry_test.go b/retry/retry_test.go similarity index 59% rename from nmagent/internal/retry_test.go rename to retry/retry_test.go index 55824de38b..b5c92126a4 100644 --- a/nmagent/internal/retry_test.go +++ b/retry/retry_test.go @@ -1,12 +1,17 @@ -package internal +package retry import ( "context" "errors" "testing" "time" + + pkgerrors "github.com/pkg/errors" + "github.com/stretchr/testify/require" ) +var errTest = errors.New("mock error") + type TestError struct{} func (t TestError) Error() string { @@ -162,3 +167,77 @@ func TestMax(t *testing.T) { t.Errorf("expected an error after %d invocations but received none", exp+1) } } + +func TestRetriableError(t *testing.T) { + // wrapping nil returns a nil + require.NoError(t, WrapTemporaryError(nil)) + + wrappedMockError := WrapTemporaryError(pkgerrors.Wrap(errTest, "nested")) + + // temporary errors should still be able to be unwrapped + require.ErrorIs(t, wrappedMockError, errTest) + + var temporaryError TemporaryError + require.ErrorAs(t, wrappedMockError, &temporaryError) + require.True(t, temporaryError.Temporary(), "errors returned from wrap temporary error should have temporary set to true") +} + +func createFunctionWithFailurePattern(errorPattern []error) func() error { + s := 0 + return func() error { + if s >= len(errorPattern) { + return nil + } + result := errorPattern[s] + s++ + return result + } +} + +func TestRunWithRetries(t *testing.T) { + errMock := WrapTemporaryError(errTest) + retries := 3 // runs 4 times, then errors before the 5th + retrier := Retrier{ + Cooldown: Max(retries, Fixed(100*time.Millisecond)), + } + + tests := []struct { + name string + wantErr bool + f func() error + }{ + { + name: "Succeed on first try", + f: createFunctionWithFailurePattern([]error{}), + wantErr: false, + }, + { + name: "Succeed on first try do not check again", + f: createFunctionWithFailurePattern([]error{nil, errMock, errMock, errMock}), + wantErr: false, + }, + { + name: "Succeed on last try", + f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, nil, errMock}), + wantErr: false, + }, + { + name: "Fail after too many attempts", + f: createFunctionWithFailurePattern([]error{errMock, errMock, errMock, errMock, nil, nil}), + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + err := retrier.Do(context.Background(), tt.f) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +}