From 1bbdc6f652adfde63966fb42e27d68abf5d59d8c Mon Sep 17 00:00:00 2001 From: Rekseto Date: Thu, 25 Sep 2025 06:58:07 +0200 Subject: [PATCH 01/13] drafting reliable UDP module with connection management and flow control --- mod/udp/README.md | 54 +++++++ mod/udp/conn.go | 12 ++ mod/udp/endpoint.go | 147 ++++++++++++++++++ mod/udp/errors.go | 13 ++ mod/udp/ip.go | 97 ++++++++++++ mod/udp/module.go | 14 ++ mod/udp/src/config.go | 169 +++++++++++++++++++++ mod/udp/src/config_test.go | 78 ++++++++++ mod/udp/src/conn.go | 191 ++++++++++++++++++++++++ mod/udp/src/deps.go | 29 ++++ mod/udp/src/dial.go | 43 ++++++ mod/udp/src/endpoint_resolver.go | 22 +++ mod/udp/src/loader.go | 40 +++++ mod/udp/src/module.go | 81 ++++++++++ mod/udp/src/net.go | 7 + mod/udp/src/packet.go | 106 +++++++++++++ mod/udp/src/packet_test.go | 74 +++++++++ mod/udp/src/parse.go | 16 ++ mod/udp/src/plan.md | 209 ++++++++++++++++++++++++++ mod/udp/src/recv.go | 128 ++++++++++++++++ mod/udp/src/ring_buffer.go | 159 ++++++++++++++++++++ mod/udp/src/ring_buffer_test.go | 170 +++++++++++++++++++++ mod/udp/src/seg_meta.go | 12 ++ mod/udp/src/send.go | 249 +++++++++++++++++++++++++++++++ mod/udp/src/server.go | 96 ++++++++++++ mod/udp/src/timers.go | 45 ++++++ mod/udp/src/unpack.go | 25 ++++ 27 files changed, 2286 insertions(+) create mode 100644 mod/udp/README.md create mode 100644 mod/udp/conn.go create mode 100644 mod/udp/endpoint.go create mode 100644 mod/udp/errors.go create mode 100644 mod/udp/ip.go create mode 100644 mod/udp/module.go create mode 100644 mod/udp/src/config.go create mode 100644 mod/udp/src/config_test.go create mode 100644 mod/udp/src/conn.go create mode 100644 mod/udp/src/deps.go create mode 100644 mod/udp/src/dial.go create mode 100644 mod/udp/src/endpoint_resolver.go create mode 100644 mod/udp/src/loader.go create mode 100644 mod/udp/src/module.go create mode 100644 mod/udp/src/net.go create mode 100644 mod/udp/src/packet.go create mode 100644 mod/udp/src/packet_test.go create mode 100644 mod/udp/src/parse.go create mode 100644 mod/udp/src/plan.md create mode 100644 mod/udp/src/recv.go create mode 100644 mod/udp/src/ring_buffer.go create mode 100644 mod/udp/src/ring_buffer_test.go create mode 100644 mod/udp/src/seg_meta.go create mode 100644 mod/udp/src/send.go create mode 100644 mod/udp/src/server.go create mode 100644 mod/udp/src/timers.go create mode 100644 mod/udp/src/unpack.go diff --git a/mod/udp/README.md b/mod/udp/README.md new file mode 100644 index 000000000..90d5a39a5 --- /dev/null +++ b/mod/udp/README.md @@ -0,0 +1,54 @@ +# Reliable UDP Module + +## Overview +This module provides reliable, ordered, and stream-like communication over UDP. It ensures data integrity and delivery through acknowledgments, retransmissions, and a handshake protocol. The module is part of the Astral ecosystem and integrates with `exonet` for endpoint management. + +## File & Struct Map +- **config.go**: Defines configuration constants (e.g., retransmission timeouts, buffer sizes). +- **conn.go**: Implements the `Conn` struct, representing a reliable UDP connection. Handles sending, receiving, and retransmissions. +- **endpoint_resolver.go**: Resolves endpoints for the module, integrating with `exonet`. +- **loader.go**: Initializes the module with dependencies. +- **module.go**: Defines the `Module` struct, the entry point for the UDP module. +- **recv.go**: Implements the receive loop, processing incoming packets and acknowledgments. +- **send.go**: Handles segmentation, batching, and sending of data. +- **ring_buffer.go**: Provides a circular buffer for efficient data storage and retrieval. +- **packet.go**: Defines the `Packet` struct and serialization logic. +- **server.go**: Implements the `Server` struct, managing incoming connections and demultiplexing. + +## Current Findings & Considerations +- **ACK Handling**: The module uses cumulative acknowledgments to confirm receipt of data up to a specific sequence number. This simplifies state management but requires careful handling of retransmissions to avoid unnecessary duplicates. +- **RTO/Backoff**: Retransmission timeouts are implemented based on RFC 6298, with exponential backoff to handle varying network conditions. This ensures robustness in the face of packet loss. +- **Handshake**: The handshake protocol establishes connection state before data exchange. However, it currently lacks stateless cookie support, which could mitigate DoS attacks by ensuring that resources are only allocated for legitimate connections. + +## Datagram Structure +A datagram in this module is represented by the `Packet` struct. It includes the following fields: +- **Seq (uint32)**: Sequence number indicating the first byte of the segment. +- **Ack (uint32)**: Cumulative acknowledgment number, confirming receipt of all bytes up to this sequence number. +- **Flags (uint8)**: Control flags such as SYN, ACK, and FIN. +- **Win (uint16)**: Advertised receive window size in bytes. +- **Len (uint16)**: Length of the payload. +- **Payload ([]byte)**: The actual data being transmitted. + +### Fragmentation and Reassembly +- **Fragmentation**: Large application data is segmented into smaller packets, each fitting within the Maximum Segment Size (MSS). This ensures compatibility with network MTU limits and avoids IP-level fragmentation. +- **Reassembly**: On the receiving side, out-of-order packets are buffered and reassembled into the original data stream once all fragments are received. + +## Diagrams +### Handshake Protocol +- The handshake establishes a connection between two endpoints before data exchange. +- Steps: + 1. **SYN**: The initiator sends a SYN packet to start the handshake. + 2. **SYN|ACK**: The responder replies with a SYN|ACK packet, acknowledging the initiator's SYN and sending its own sequence number. + 3. **ACK**: The initiator sends an ACK packet to confirm the responder's sequence number. +- Once the ACK is received, the connection is established. + +### Data Flow +- **Write Path**: + - Application data is segmented into smaller packets (segmentation). + - Packets are serialized and sent over the network (packetization). +- **Network Transmission**: + - Packets are transmitted over the network, potentially out of order. +- **Read Path**: + - Received packets are reassembled into the original data stream (reassembly). + - Cumulative acknowledgments (ACKs) confirm receipt of data up to a specific sequence number. +- Retransmissions occur for lost packets based on retransmission timeouts (RTO). diff --git a/mod/udp/conn.go b/mod/udp/conn.go new file mode 100644 index 000000000..e0fb31032 --- /dev/null +++ b/mod/udp/conn.go @@ -0,0 +1,12 @@ +package udp + +// DatagramWriter is how Conn sends bytes to its peer. +type DatagramWriter interface { + WriteDatagram(b []byte) error +} + +// DatagramReceiver is how Conn *receives* parsed packets when it does not own a socket read loop. +// (For active conns, the recvLoop calls HandleDatagram itself.) +type DatagramReceiver interface { + HandleDatagram(raw []byte) // fast path: parse + process (ACK/data) +} diff --git a/mod/udp/endpoint.go b/mod/udp/endpoint.go new file mode 100644 index 000000000..94d2c54c5 --- /dev/null +++ b/mod/udp/endpoint.go @@ -0,0 +1,147 @@ +package udp + +import ( + "bytes" + "errors" + "io" + "net" + "strconv" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/term" + "github.com/cryptopunkscc/astrald/mod/exonet" +) + +var _ exonet.Endpoint = &Endpoint{} +var _ astral.Object = &Endpoint{} + +// NOTE: this is same for UDP and TCP - consider moving to a common package + +// Endpoint is an astral.Object that holds information about a UDP endpoint, +// i.e. an IP address and a port. +// Supports JSON and text. +type Endpoint struct { + IP IP + Port astral.Uint16 +} + +func (e *Endpoint) ObjectType() string { + return "mod.udp.endpoint" +} + +func (e Endpoint) WriteTo(w io.Writer) (n int64, err error) { + return astral.Struct(e).WriteTo(w) +} + +func (e *Endpoint) ReadFrom(r io.Reader) (n int64, err error) { + return astral.Struct(e).ReadFrom(r) +} + +// exonet.Endpoint + +func (e *Endpoint) Address() string { + return net.JoinHostPort(e.IP.String(), strconv.Itoa(int(e.Port))) +} + +func (e *Endpoint) Network() string { + return "udp" +} + +// HostString returns the IP address as a string +func (e *Endpoint) HostString() string { + return e.IP.String() +} + +// PortNumber returns the port number as an int +func (e *Endpoint) PortNumber() int { + return int(e.Port) +} + +func (e *Endpoint) Pack() []byte { + var b = &bytes.Buffer{} + if _, err := e.WriteTo(b); err != nil { + return nil + } + return b.Bytes() +} + +// Text marshaling + +func (e Endpoint) MarshalText() (text []byte, err error) { + return []byte(e.Address()), nil +} + +func (e *Endpoint) UnmarshalText(text []byte) error { + h, p, err := net.SplitHostPort(string(text)) + if err != nil { + return err + } + + ip, err := ParseIP(h) + if err != nil { + return err + } + + port, err := strconv.Atoi(p) + if err != nil { + return err + } + + // check if port fits in 16 bits + if (port >> 16) > 0 { + return errors.New("port out of range") + } + + e.IP = ip + e.Port = astral.Uint16(port) + + return nil +} + +// ... + +func (e *Endpoint) String() string { + return e.Address() +} + +func (e *Endpoint) IsZero() bool { + return e == nil || e.IP == nil +} + +func ParseEndpoint(s string) (*Endpoint, error) { + hostStr, portStr, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + + ip, err := ParseIP(hostStr) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + + // check if port fits in 16 bits + if (port >> 16) > 0 { + return nil, errors.New("port out of range") + } + + return &Endpoint{ + IP: ip, + Port: astral.Uint16(port), + }, nil +} + +func init() { + _ = astral.DefaultBlueprints.Add(&Endpoint{}) + + term.SetTranslateFunc(func(o *Endpoint) astral.Object { + return &term.ColorString{ + Color: term.HighlightColor, + Text: astral.String32(o.String()), + } + }) +} diff --git a/mod/udp/errors.go b/mod/udp/errors.go new file mode 100644 index 000000000..0b593d49a --- /dev/null +++ b/mod/udp/errors.go @@ -0,0 +1,13 @@ +package udp + +import "errors" + +var ( + ErrPacketTooShort = errors.New("packet too short") + ErrListenerClosed = errors.New("listener closed") + ErrConnClosed = errors.New("connection closed") + ErrInvalidPayloadLength = errors.New("invalid payload length") + ErrClosed = errors.New("connection closed") + ErrZeroMSS = errors.New("invalid MSS") + ErrMalformedPacket = errors.New("malformed packet") +) diff --git a/mod/udp/ip.go b/mod/udp/ip.go new file mode 100644 index 000000000..aa680a0ae --- /dev/null +++ b/mod/udp/ip.go @@ -0,0 +1,97 @@ +package udp + +import ( + "encoding/json" + "errors" + "io" + "net" + + "github.com/cryptopunkscc/astrald/astral" +) + +// NOTE: this is same for tcp / udp modules, consider moving to a common package +type IP net.IP + +func ParseIP(s string) (IP, error) { + return IP(net.ParseIP(s)), nil +} + +// astral + +func (IP) ObjectType() string { + return "mod.tcp.ip_address" +} + +func (ip IP) WriteTo(w io.Writer) (n int64, err error) { + if ip.IsIPv4() { + return astral.Bytes8(net.IP(ip).To4()).WriteTo(w) + } + + return astral.Bytes8(ip).WriteTo(w) +} + +func (ip *IP) ReadFrom(r io.Reader) (n int64, err error) { + return (*astral.Bytes8)(ip).ReadFrom(r) +} + +// json + +func (ip IP) MarshalJSON() ([]byte, error) { + return json.Marshal(ip.String()) +} + +func (ip *IP) UnmarshalJSON(b []byte) error { + var str string + err := json.Unmarshal(b, &str) + if err != nil { + return nil + } + + parsed := IP(net.ParseIP(str)) + + if parsed == nil { + return errors.New("invalid IP") + } + + *ip = parsed + return nil +} + +// text + +func (ip IP) MarshalText() (text []byte, err error) { + return []byte(ip.String()), nil +} + +func (ip *IP) UnmarshalText(text []byte) error { + parsed := IP(net.ParseIP(string(text))) + if parsed == nil { + return errors.New("invalid IP") + } + *ip = parsed + return nil +} + +// ... + +func (ip IP) IsIPv4() bool { + return net.IP(ip).To4() != nil +} + +func (ip IP) IsIPv6() bool { + return net.IP(ip).To16() != nil +} + +func (ip IP) IsLoopback() bool { return net.IP(ip).IsLoopback() } + +func (ip IP) IsGlobalUnicast() bool { return net.IP(ip).IsGlobalUnicast() } + +func (ip IP) IsPrivate() bool { return net.IP(ip).IsPrivate() } + +func (ip IP) String() string { + return net.IP(ip).String() +} + +func init() { + _ = astral.DefaultBlueprints.Add(&IP{}) +} diff --git a/mod/udp/module.go b/mod/udp/module.go new file mode 100644 index 000000000..a646e0d62 --- /dev/null +++ b/mod/udp/module.go @@ -0,0 +1,14 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/mod/exonet" +) + +const ModuleName = "udp" + +type Module interface { + exonet.Dialer + exonet.Unpacker + exonet.Parser + ListenPort() int +} diff --git a/mod/udp/src/config.go b/mod/udp/src/config.go new file mode 100644 index 000000000..f7e2f9ee9 --- /dev/null +++ b/mod/udp/src/config.go @@ -0,0 +1,169 @@ +package udp + +import ( + "time" +) + +// RFC-backed constants for src UDP config +const ( + ListenPort = 1791 + // QUIC requires endpoints to handle 1200-byte UDP datagrams without fragmentation (RFC 9000 §14.1) + + DefaultMSS = 1200 - 13 // 1187: 1200 minus our header + MinMSS = 512 // RFC 8085: avoid fragmentation, safe for most links + MaxMSS = 1400 // Keeps under 1500B MTU with IP/UDP/tunnel headroom (RFC 8085, RFC 791, RFC 8200) + + // WindowBytes conservative buffer, RFC 8085 + + DefaultWindowBytes = 16 * DefaultMSS + MinWindowBytes = MinMSS + MaxWindowBytes = 1 << 20 // 1 MiB + + // Retransmission timers: RFC 6298 + + DefaultRTO = 500 * time.Millisecond + DefaultRTOMax = 4 * time.Second + MinRTO = 10 * time.Millisecond // LAN-friendly floor + MaxRTOCeiling = 60 * time.Second // Avoid excessive backoff + + // Retries fail fast on persistent loss + + DefaultRetries = 8 + MinRetries = 1 + MaxRetries = 20 + + // AckDelay QUIC MAX_ACK_DELAY (RFC 9000 §13.2.1) + + DefaultAckDelay = 25 * time.Millisecond + MinAckDelay = 0 + + // Buffer sizes + + DefaultRecvBufBytes = 1 << 20 // 1 MiB + MinRecvBufBytes = DefaultWindowBytes // Should be at least as large as window + MaxRecvBufBytes = 8 << 20 // 8 MiB + DefaultSendBufBytes = 1 << 20 + MinSendBufBytes = DefaultWindowBytes + MaxSendBufBytes = 8 << 20 +) + +// Config holds general settings for the UDP module. +type Config struct { + ListenPort int `yaml:"listen_port,omitempty"` // Port to listen on for incoming connections (default 1791) + PublicEndpoints []string `yaml:"public_endpoints,omitempty"` + DialTimeout time.Duration `yaml:"dial_timeout,omitempty"` // Timeout for dialing connections (default 1 minute) + + FlowControl FlowControlConfig `yaml:"flow_control,omitempty"` // Flow control settings for UDP connections +} + +// FlowControlConfig holds configuration for individual UDP connections. +type FlowControlConfig struct { + MSS int // Maximum Segment Size (default 1187) + WindowBytes int // Send window size in bytes (default 16 * MSS) + RTO time.Duration // Initial retransmission timeout (default 500ms) + RTOMax time.Duration // Maximum retransmission timeout (default 4s) + RetryLimit int // Maximum retransmission attempts (default 8) + IdleTimeout time.Duration // Connection idle timeout (default 60s) + AckDelay time.Duration // Delayed ACK timer (default 25ms) + RecvBufBytes int // Receive buffer size (default 1MB) + SendBufBytes int // Send buffer size (default 1MB) +} + +// Normalize sets sensible defaults for zero-values, clamps to safe ranges, and enforces invariants. +// See RFC 9000, RFC 8085, RFC 6298 for rationale. +func (c *FlowControlConfig) Normalize() { + c.setDefaults() + c.clampValues() +} + +// setDefaults initializes zero-values with sensible defaults. +func (c *FlowControlConfig) setDefaults() { + if c.MSS == 0 { + c.MSS = DefaultMSS + } + if c.WindowBytes == 0 { + c.WindowBytes = DefaultWindowBytes + } + if c.RTO == 0 { + c.RTO = DefaultRTO + } + if c.RTOMax == 0 { + c.RTOMax = DefaultRTOMax + } + if c.RetryLimit == 0 { + c.RetryLimit = DefaultRetries + } + if c.AckDelay == 0 { + c.AckDelay = DefaultAckDelay + } + if c.RecvBufBytes == 0 { + c.RecvBufBytes = DefaultRecvBufBytes + } + if c.SendBufBytes == 0 { + c.SendBufBytes = DefaultSendBufBytes + } +} + +// NOTE: normally i would not introduce such function but when it comes to +// parameters of network protocols, +// i believe it is better to keep things within certain range of values ( +// all of which are stated at the top of this file) + +// clampValues ensures all fields are within safe ranges and enforces invariants. +func (c *FlowControlConfig) clampValues() { + c.MSS = clampInt(c.MSS, MinMSS, MaxMSS) + c.WindowBytes = clampInt(c.WindowBytes, c.MSS, MaxWindowBytes) + c.RTO = clampDur(c.RTO, MinRTO, MaxRTOCeiling) + c.RTOMax = clampDur(c.RTOMax, c.RTO, MaxRTOCeiling) + c.RetryLimit = clampInt(c.RetryLimit, MinRetries, MaxRetries) + c.AckDelay = clampDur(c.AckDelay, MinAckDelay, c.RTO/2) + c.RecvBufBytes = clampInt(c.RecvBufBytes, MinRecvBufBytes, MaxRecvBufBytes) + c.SendBufBytes = clampInt(c.SendBufBytes, MinSendBufBytes, MaxSendBufBytes) +} + +// clampInt clamps an integer value to a specified range. +func clampInt(v, lo, hi int) int { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} + +// clampDur clamps a time.Duration value to a specified range. +func clampDur(v, lo, hi time.Duration) time.Duration { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} + +var defaultConfig = Config{ + ListenPort: ListenPort, + DialTimeout: time.Minute, + FlowControl: FlowControlConfig{ + MSS: DefaultMSS, + WindowBytes: DefaultWindowBytes, + RTO: DefaultRTO, + RTOMax: DefaultRTOMax, + RetryLimit: DefaultRetries, + IdleTimeout: 60 * time.Second, // Default idle timeout of 1 minute + AckDelay: DefaultAckDelay, + RecvBufBytes: DefaultRecvBufBytes, + SendBufBytes: DefaultSendBufBytes, + }, +} + +// RFC rationale summary: +// +// - MSS: QUIC requires 1200B UDP datagrams (RFC 9000 §14.1), clamped to avoid fragmentation (RFC 8085). +// - WindowBytes: conservative buffer, RFC 8085, must be >= MSS. +// - RTO/RTOMax: TCP discipline (RFC 6298), pragmatic for UDP, exponential backoff. +// - AckDelay: mirrors QUIC MAX_ACK_DELAY (RFC 9000 §13.2.1). +// - Buffer sizes: 1 MiB default, capped for safety, must be >= window. +// - All invariants enforced for safety and interoperability. diff --git a/mod/udp/src/config_test.go b/mod/udp/src/config_test.go new file mode 100644 index 000000000..b9fb7ef15 --- /dev/null +++ b/mod/udp/src/config_test.go @@ -0,0 +1,78 @@ +package udp + +import ( + "testing" + "time" +) + +func TestFlowControlConfigDefaults(t *testing.T) { + def := defaultConfig.FlowControl + if def.MSS != DefaultMSS || def.WindowBytes != DefaultWindowBytes || def.RTO != DefaultRTO || def.RTOMax != DefaultRTOMax || def.RetryLimit != DefaultRetries || def.AckDelay != DefaultAckDelay || def.RecvBufBytes != DefaultRecvBufBytes || def.SendBufBytes != DefaultSendBufBytes { + t.Errorf("defaultConfig.FlowControl does not match expected defaults: %+v", def) + } +} + +func TestFlowControlConfigClamp(t *testing.T) { + tests := []struct { + name string + input FlowControlConfig + expected FlowControlConfig + }{ + { + name: "Values below range are clamped", + input: FlowControlConfig{ + MSS: 100, + WindowBytes: 100, + RTO: 5 * time.Millisecond, + RTOMax: 5 * time.Millisecond, + RetryLimit: 0, + AckDelay: -1 * time.Millisecond, + RecvBufBytes: 100, + SendBufBytes: 100, + }, + expected: FlowControlConfig{ + MSS: MinMSS, + WindowBytes: MinMSS, + RTO: MinRTO, + RTOMax: MinRTO, + RetryLimit: MinRetries, + AckDelay: MinAckDelay, + RecvBufBytes: MinRecvBufBytes, + SendBufBytes: MinSendBufBytes, + }, + }, + { + name: "Values above range are clamped", + input: FlowControlConfig{ + MSS: 2000, + WindowBytes: 2 << 20, + RTO: 70 * time.Second, + RTOMax: 70 * time.Second, + RetryLimit: 50, + AckDelay: 1 * time.Second, + RecvBufBytes: 16 << 20, + SendBufBytes: 16 << 20, + }, + expected: FlowControlConfig{ + MSS: MaxMSS, + WindowBytes: MaxWindowBytes, + RTO: MaxRTOCeiling, + RTOMax: MaxRTOCeiling, + RetryLimit: MaxRetries, + AckDelay: MinAckDelay, // AckDelay is clamped to MinAckDelay if above range + RecvBufBytes: MaxRecvBufBytes, + SendBufBytes: MaxSendBufBytes, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + input := test.input + input.clampValues() + if input != test.expected { + t.Errorf("clampValues() failed.\nGot: %+v\nExpected: %+v", input, test.expected) + } + }) + } +} diff --git a/mod/udp/src/conn.go b/mod/udp/src/conn.go new file mode 100644 index 000000000..1ffc9ef0d --- /dev/null +++ b/mod/udp/src/conn.go @@ -0,0 +1,191 @@ +// conn.go +package udp + +import ( + "bytes" + "context" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +// Conn implements src, ordered communication over a connected UDP socket. +// Handshake/FIN are out of scope for this MVP; stream semantics only. +type Conn struct { + // socket / addressing + udpConn *net.UDPConn + localEndpoint *udp.Endpoint + remoteEndpoint *udp.Endpoint + + // config + cfg FlowControlConfig + mss int + + // send side (guarded by sendMu unless noted) + sendMu sync.Mutex + sendBase uint32 // first unacked byte + nextSeq uint32 // next byte sequence to assign + nextSendSeq uint32 + sendQ *bytes.Buffer // queued app data (bounded by cfg.SendBufBytes) + sendCond *sync.Cond // signals space available / data added + bytesInFlight int + unacked map[uint32]segMeta // seqStart -> meta + order []uint32 // seqStarts in send order (oldest first) + + // recv side + rcvMu sync.Mutex + rcvNext uint32 + ooo map[uint32][]byte // out-of-order segments by seqStart + appBuf *ringBuffer // ordered bytes for Read() + ackPending atomic.Bool // (reserved) if you add explicit flags later + + // timers + rtoMu sync.Mutex + rto time.Duration + rtoTimer *time.Timer + ackTimer *time.Timer // set on-demand in recv.go + + // control/lifecycle + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + closed atomic.Bool + closeOnce sync.Once + closeErr atomic.Value // error + + // write serialization (shared per UDP socket if you ever share it) + writeMu *sync.Mutex + // Add a mutex field for synchronization + mutex sync.Mutex +} + +// NewConn constructs a connection around an already-connected UDP socket. +func NewConn(c *net.UDPConn, l, r *udp.Endpoint, cfg FlowControlConfig) (*Conn, error) { + cfg.Normalize() + if cfg.MSS <= 0 { + return nil, udp.ErrZeroMSS + } + + rc := &Conn{ + udpConn: c, + localEndpoint: l, + remoteEndpoint: r, + cfg: cfg, + mss: cfg.MSS, + + sendBase: 1, // start at 1 so 0 can be a sentinel in traces + nextSeq: 1, + sendQ: &bytes.Buffer{}, + unacked: make(map[uint32]segMeta), + order: make([]uint32, 0, 128), + + rcvNext: 1, + ooo: make(map[uint32][]byte), + appBuf: newRingBuffer(cfg.RecvBufBytes), + + rto: cfg.RTO, + writeMu: &sync.Mutex{}, + } + rc.sendCond = sync.NewCond(&rc.sendMu) + + // Start receiver loop (defined in recv.go) + rc.wg.Add(1) + go rc.recvLoop() + + return rc, nil +} + +// Read implements stream semantics. It blocks until data is available or the +// connection is closed and drained. On close, it returns any stored terminal error +// or io.EOF when the buffer is empty. +func (c *Conn) Read(p []byte) (int, error) { + n, err := c.appBuf.Read(p) + if n > 0 { + return n, nil + } + if c.closed.Load() { + if errv := c.closeErr.Load(); errv != nil { + return 0, errv.(error) + } + return 0, io.EOF + } + return n, err +} + +// Close terminates the connection and waits for the recv loop to exit. +func (c *Conn) Close() error { + c.closeOnce.Do(func() { + c.closed.Store(true) + c.cancel() + + // stop timers + c.stopRTO() // defined in send.go + c.rtoMu.Lock() + if c.ackTimer != nil { + c.ackTimer.Stop() + c.ackTimer = nil + } + c.rtoMu.Unlock() + + // wake blocked goroutines + c.sendCond.Broadcast() + c.appBuf.Close() + + _ = c.udpConn.Close() + }) + c.wg.Wait() + return nil +} + +// Outbound reports whether this connection was dialed out. +// For now this always returns true for Dial usage; adjust if you add a listener. +func (c *Conn) Outbound() bool { return true } + +// LocalEndpoint returns the local UDP endpoint. +func (c *Conn) LocalEndpoint() exonet.Endpoint { + return c.localEndpoint +} + +// RemoteEndpoint returns the remote UDP endpoint. +func (c *Conn) RemoteEndpoint() exonet.Endpoint { + return c.remoteEndpoint +} + +// NOTE: Flagged for check (might be ai overcomplexity) +// closeWithError records the error and closes the connection. +func (c *Conn) closeWithError(err error) { + if err != nil { + c.closeErr.Store(err) + } + _ = c.Close() +} + +// seqLT compares sequence numbers with wrap-around semantics. +func seqLT(a, b uint32) bool { return int32(a-b) < 0 } + +// sendPacket sends a packet over the UDP connection. +func (c *Conn) sendPacket(pkt *Packet) error { + raw, err := pkt.Marshal() + if err != nil { + return err + } + + c.writeMu.Lock() + defer c.writeMu.Unlock() + + _, err = c.udpConn.Write(raw) + return err +} + +// handleRTO handles retransmission timeouts by retransmitting the earliest unacked segment. +func (c *Conn) handleRTO() { + // Implementation for retransmission timeout handling + // This will involve retransmitting the earliest unacked segment + // and applying exponential backoff to the retransmission timer. +} diff --git a/mod/udp/src/deps.go b/mod/udp/src/deps.go new file mode 100644 index 000000000..a03aa351b --- /dev/null +++ b/mod/udp/src/deps.go @@ -0,0 +1,29 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/core" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/nodes" + "github.com/cryptopunkscc/astrald/mod/objects" +) + +// Deps represents the dependencies required by the src UDP module. +type Deps struct { + Exonet exonet.Module + Nodes nodes.Module + Objects objects.Module +} + +func (mod *Module) LoadDependencies() (err error) { + err = core.Inject(mod.node, &mod.Deps) + if err != nil { + return + } + + mod.Exonet.SetDialer("tcp", mod) + mod.Exonet.SetParser("tcp", mod) + mod.Exonet.SetUnpacker("tcp", mod) + mod.Nodes.AddResolver(mod) + + return +} diff --git a/mod/udp/src/dial.go b/mod/udp/src/dial.go new file mode 100644 index 000000000..136383fa5 --- /dev/null +++ b/mod/udp/src/dial.go @@ -0,0 +1,43 @@ +package udp + +import ( + "net" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +var _ exonet.Dialer = &Module{} + +func (mod *Module) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.Conn, error) { + switch endpoint.Network() { + case "udp": + // Supported network + default: + return nil, exonet.ErrUnsupportedNetwork + } + + // Use net.Dialer for dialing UDP connections + dialer := net.Dialer{Timeout: mod.config.DialTimeout} + conn, err := dialer.DialContext(ctx, "udp", endpoint.Address()) + if err != nil { + return nil, err + } + + localEndpoint, _ := udp.ParseEndpoint(conn.LocalAddr().String()) + remoteEndpoint, _ := udp.ParseEndpoint(conn.RemoteAddr().String()) + + udpConn, ok := conn.(*net.UDPConn) + if !ok { + return nil, exonet.ErrUnsupportedNetwork + } + + reliableConn, err := NewConn(udpConn, localEndpoint, remoteEndpoint, + mod.config.FlowControl) + if err != nil { + return nil, err + } + + return reliableConn, nil +} diff --git a/mod/udp/src/endpoint_resolver.go b/mod/udp/src/endpoint_resolver.go new file mode 100644 index 000000000..92ec39bba --- /dev/null +++ b/mod/udp/src/endpoint_resolver.go @@ -0,0 +1,22 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/sig" +) + +// NOTE: Should we expose our UDP endpoints the same way we expose TCP ones? + +func (mod *Module) ResolveEndpoints(ctx *astral.Context, nodeID *astral.Identity) (_ <-chan exonet.Endpoint, err error) { + if !nodeID.IsEqual(mod.node.Identity()) { + return sig.ArrayToChan([]exonet.Endpoint{}), nil + } + + var all []exonet.Endpoint + + all = append(all, mod.publicEndpoints...) + all = append(all, mod.localEndpoints()...) + + return sig.ArrayToChan(all), nil +} diff --git a/mod/udp/src/loader.go b/mod/udp/src/loader.go new file mode 100644 index 000000000..540b30221 --- /dev/null +++ b/mod/udp/src/loader.go @@ -0,0 +1,40 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/log" + "github.com/cryptopunkscc/astrald/core" + "github.com/cryptopunkscc/astrald/core/assets" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +type Loader struct{} + +func (Loader) Load(node astral.Node, assets assets.Assets, l *log.Logger) (core.Module, error) { + mod := &Module{ + node: node, + log: l, + config: defaultConfig, + } + + _ = assets.LoadYAML(udp.ModuleName, &mod.config) + + // Parse public endpoints + for _, pe := range mod.config.PublicEndpoints { + endpoint, err := udp.ParseEndpoint(pe) + if err != nil { + l.Error("error parsing public endpoint \"%v\": %v", pe, err) + continue + } + + mod.publicEndpoints = append(mod.publicEndpoints, endpoint) + } + + return mod, nil +} + +func init() { + if err := core.RegisterModule(udp.ModuleName, Loader{}); err != nil { + panic(err) + } +} diff --git a/mod/udp/src/module.go b/mod/udp/src/module.go new file mode 100644 index 000000000..cb043d20e --- /dev/null +++ b/mod/udp/src/module.go @@ -0,0 +1,81 @@ +package udp + +import ( + "context" + "net" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/log" + "github.com/cryptopunkscc/astrald/core/assets" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" + "github.com/cryptopunkscc/astrald/tasks" +) + +// Module represents the UDP module and implements the exonet.Dialer interface. +type Module struct { + Deps + config Config // Configuration for the module + node astral.Node + assets assets.Assets + log *log.Logger + ctx context.Context + publicEndpoints []exonet.Endpoint +} + +func (mod *Module) Run(ctx *astral.Context) error { + mod.ctx = ctx + + err := tasks.Group(NewServer(mod)).Run(ctx) + if err != nil { + return err + } + + <-ctx.Done() + + return nil +} + +func (mod *Module) ListenPort() int { + return mod.config.ListenPort +} + +func (mod *Module) localIPs() ([]udp.IP, error) { + list := make([]udp.IP, 0) + + ifaceAddrs, err := InterfaceAddrs() + if err != nil { + return nil, err + } + + for _, a := range ifaceAddrs { + ipnet, ok := a.(*net.IPNet) + if !ok { + continue + } + ip := udp.IP(ipnet.IP) + list = append(list, ip) + } + + return list, nil +} + +func (mod *Module) localEndpoints() (list []exonet.Endpoint) { + ips, err := mod.localIPs() + if err != nil { + return + } + + for _, ip := range ips { + if ip.IsLoopback() { + continue + } + if ip.IsGlobalUnicast() || ip.IsPrivate() { + list = append(list, &udp.Endpoint{ + IP: ip, + Port: astral.Uint16(mod.ListenPort()), + }) + } + } + return +} diff --git a/mod/udp/src/net.go b/mod/udp/src/net.go new file mode 100644 index 000000000..9b37ab60b --- /dev/null +++ b/mod/udp/src/net.go @@ -0,0 +1,7 @@ +package udp + +import "net" + +// NOTE: same as in tcp module - consider moving to a common package + +var InterfaceAddrs = net.InterfaceAddrs diff --git a/mod/udp/src/packet.go b/mod/udp/src/packet.go new file mode 100644 index 000000000..a968d03cf --- /dev/null +++ b/mod/udp/src/packet.go @@ -0,0 +1,106 @@ +package udp + +import ( + "bytes" + "encoding/binary" + + "github.com/cryptopunkscc/astrald/mod/udp" +) + +const ( + FlagSYN = 1 << 0 + FlagACK = 1 << 1 + FlagFIN = 1 << 2 +) + +// Packet represents a src UDP packet with TCP-like header +type Packet struct { + Seq uint32 // Sequence number (first byte seq of this segment) + Ack uint32 // Acknowledgment number (cumulative ack: all bytes < Ack received) + Flags uint8 // Bit flags: SYN, ACK, FIN, etc. + Win uint16 // Advertised receive window in BYTES + Len uint16 // Payload length + Payload []byte // Actual data +} + +// Marshal serializes the Packet into bytes for transmission +// Format: [Seq:4][Ack:4][Flags:1][Win:2][Len:2][Payload:N] +func (p *Packet) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + if err := binary.Write(buf, binary.BigEndian, p.Seq); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, p.Ack); err != nil { + return nil, err + } + if err := buf.WriteByte(p.Flags); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, p.Win); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, p.Len); err != nil { + return nil, err + } + if _, err := buf.Write(p.Payload); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// Unmarshal parses bytes into the Packet struct fields +func (p *Packet) Unmarshal(data []byte) error { + if len(data) < 13 { // 4+4+1+2+2 = 13 bytes header + return udp.ErrPacketTooShort + } + + p.Seq = binary.BigEndian.Uint32(data[0:4]) + p.Ack = binary.BigEndian.Uint32(data[4:8]) + p.Flags = data[8] + p.Win = binary.BigEndian.Uint16(data[9:11]) + p.Len = binary.BigEndian.Uint16(data[11:13]) + + // Validate payload length + if int(p.Len) != len(data)-13 { + return udp.ErrInvalidPayloadLength + } + + p.Payload = data[13:] + return nil +} + +// UnmarshalPacket parses bytes into a Packet instance. +func UnmarshalPacket(data []byte) (*Packet, error) { + if len(data) < 13 { // Minimum header size: 4+4+1+2+2 + return nil, udp.ErrMalformedPacket + } + + pkt := &Packet{} + r := bytes.NewReader(data) + if err := binary.Read(r, binary.BigEndian, &pkt.Seq); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &pkt.Ack); err != nil { + return nil, err + } + flags, err := r.ReadByte() + if err != nil { + return nil, err + } + pkt.Flags = flags + if err := binary.Read(r, binary.BigEndian, &pkt.Win); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &pkt.Len); err != nil { + return nil, err + } + if int(pkt.Len) > len(data)-13 { + return nil, udp.ErrMalformedPacket + } + pkt.Payload = make([]byte, pkt.Len) + if _, err := r.Read(pkt.Payload); err != nil { + return nil, err + } + + return pkt, nil +} diff --git a/mod/udp/src/packet_test.go b/mod/udp/src/packet_test.go new file mode 100644 index 000000000..167337082 --- /dev/null +++ b/mod/udp/src/packet_test.go @@ -0,0 +1,74 @@ +package udp + +import ( + "bytes" + "testing" +) + +func TestPacketMarshalUnmarshal(t *testing.T) { + tests := []struct { + name string + packet Packet + expectErr bool + }{ + { + name: "Valid Packet", + packet: Packet{ + Seq: 1, + Ack: 2, + Flags: 0x01, + Win: 1024, + Len: 5, + Payload: []byte("hello"), + }, + expectErr: false, + }, + { + name: "Empty Payload", + packet: Packet{ + Seq: 10, + Ack: 20, + Flags: 0x02, + Win: 2048, + Len: 0, + Payload: []byte{}, + }, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test Marshal + data, err := tt.packet.Marshal() + if (err != nil) != tt.expectErr { + t.Fatalf("Marshal error = %v, expectErr = %v", err, tt.expectErr) + } + + // Test Unmarshal + var unmarshaled Packet + err = unmarshaled.Unmarshal(data) + if (err != nil) != tt.expectErr { + t.Fatalf("Unmarshal error = %v, expectErr = %v", err, tt.expectErr) + } + + // Verify the unmarshaled packet matches the original + if !bytes.Equal(unmarshaled.Payload, tt.packet.Payload) || + unmarshaled.Seq != tt.packet.Seq || + unmarshaled.Ack != tt.packet.Ack || + unmarshaled.Flags != tt.packet.Flags || + unmarshaled.Win != tt.packet.Win || + unmarshaled.Len != tt.packet.Len { + t.Errorf("Unmarshaled packet does not match original. Got %+v, want %+v", unmarshaled, tt.packet) + } + }) + } +} + +func TestUnmarshalPacket(t *testing.T) { + data := []byte("testdata") + _, err := UnmarshalPacket(data) + if err == nil { + t.Error("Expected error for invalid packet data, got nil") + } +} diff --git a/mod/udp/src/parse.go b/mod/udp/src/parse.go new file mode 100644 index 000000000..a67b7b358 --- /dev/null +++ b/mod/udp/src/parse.go @@ -0,0 +1,16 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +func (mod *Module) Parse(network string, address string) (exonet.Endpoint, error) { + switch network { + case "udp": + default: + return nil, exonet.ErrUnsupportedNetwork + } + + return udp.ParseEndpoint(address) +} diff --git a/mod/udp/src/plan.md b/mod/udp/src/plan.md new file mode 100644 index 000000000..64e9657c9 --- /dev/null +++ b/mod/udp/src/plan.md @@ -0,0 +1,209 @@ +# Reliable UDP Module: Architectural Brief + +## Purpose & Context +The Reliable UDP module provides stream-like semantics over UDP, ensuring ordered and reliable delivery of data. It is designed to integrate seamlessly with the Astral ecosystem, particularly the `exonet` module and node communication. Unlike raw UDP, this module introduces mechanisms for retransmissions, acknowledgments, and a handshake protocol to establish connection state before data exchange. + +## Interfaces & Contracts +- **Connection Interface**: Provides ordered, reliable byte streams. Implements `io.ReadWriteCloser`. +- **Listener Behavior**: Accepts incoming connections, demultiplexing based on remote endpoints. +- **Endpoint Handling**: Supports parsing, packing, and unpacking of network addresses. +- **Invariants**: + - Data is delivered in order. + - Lost packets are retransmitted. + - Connections are established via a handshake. + +## Handshake Protocol +The handshake follows a three-step process: +1. **SYN**: Initiator sends a SYN packet with an initial sequence number. +2. **SYN|ACK**: Responder replies with a SYN|ACK, acknowledging the initiator's sequence number and providing its own. +3. **ACK**: Initiator acknowledges the responder's sequence number, completing the handshake. + +### Sequence Space Rules +- SYN and FIN each consume one sequence number. +- Retransmissions occur if no acknowledgment is received within the retransmission timeout (RTO). +- A connection is established after the ACK is received. + +### Timing & Retransmission +- Initial RTO: 500ms (configurable). +- Exponential backoff for retransmissions. +- Maximum retries: 8 (configurable). + +## Data Path Overview +- **Segmentation**: Application data is split into packets, each with a sequence number. +- **Ordering**: Out-of-order packets are buffered until missing packets arrive. +- **Acknowledgments**: Cumulative ACKs confirm receipt of all bytes up to a sequence number. +- **Retransmission**: Unacknowledged packets are retransmitted after RTO. +- **Batching**: Multiple packets may be sent together to optimize throughput. + +## Concurrency & I/O Model +- **Locking Domains**: Separate locks for send and receive paths. +- **Goroutines**: + - One for sending data. + - One for receiving and processing packets. + - Timers for retransmissions. +- **Shutdown**: Ensures all goroutines exit cleanly, and no resources are leaked. + +## Error Model & Shutdown Semantics +- **Errors**: Surface as `net.Error` or module-specific errors. +- **Idempotent Close**: Closing a connection multiple times has no adverse effects. +- **Partial Failures**: Errors during send/receive are propagated to the caller. + +## Compatibility & Integration +- **Endpoint Parsing**: Compatible with `exonet` endpoint parsing and unpacking. +- **Lifecycle Alignment**: Designed to align with the lifecycle of other modules like TCP and Tor. +- **Assumptions**: Assumes reliable delivery within the module; does not handle NAT traversal or encryption. + +## Security & Future Considerations +- **Stateless Cookie**: Potential for DoS mitigation using stateless cookies during the handshake. +- **PLPMTUD**: Path MTU discovery to avoid fragmentation. +- **Congestion Control**: Future integration with congestion control mechanisms. + +## Missing Logics and Potential Issues + +### Missing Logics +1. **Connection Handshake Validation**: + - The handshake process lacks validation for replay attacks or duplicate SYN packets. This could lead to unnecessary resource allocation. + +2. **Congestion Control**: + - The module does not implement congestion control mechanisms, which could lead to network congestion in high-traffic scenarios. + +3. **DoS Mitigation**: + - There is no stateless cookie mechanism during the handshake to prevent denial-of-service (DoS) attacks. + +4. **Connection Timeout**: + - The module does not enforce a timeout for idle connections, which could lead to resource exhaustion. + +5. **Error Propagation**: + - Errors during retransmissions or ACK handling are not consistently propagated to the caller, which could make debugging difficult. + +### Potential Performance Issues +1. **Timer Management**: + - The retransmission timer (`armRTO`) and ACK delay timer (`armAckDelay`) are reset frequently, which could lead to high overhead in timer management. + +2. **Lock Contention**: + - The use of mutexes (`rtoMu`) for timer operations could lead to contention under high concurrency. + +3. **Inefficient Buffering**: + - The `sendQ` buffer in `conn.go` may become a bottleneck if the application writes data faster than the network can transmit. + +4. **Packet Parsing Overhead**: + - The `UnmarshalPacket` function in `recv.go` is called for every incoming packet, which could become a performance bottleneck if the parsing logic is complex. + +### Potential Bugs +1. **Timer Race Conditions**: + - The `armRTO` and `stopRTO` functions do not ensure that the timer callback (`handleRTO`) is not running when the timer is stopped, which could lead to race conditions. + +2. **Endpoint Parsing Errors**: + - The `Dial` function in `dial.go` does not handle errors from `udp.ParseEndpoint`, which could lead to nil pointer dereferences. + +3. **Unbounded Retransmissions**: + - The retransmission logic does not enforce a maximum number of retries, which could lead to infinite retransmissions in case of persistent packet loss. + +4. **ACK Timer Reset**: + - The `armAckDelay` function resets the ACK timer without checking if the timer is already running, which could lead to missed ACKs. + +## Implementation Status + +### Fully Implemented and Tested +1. **Ring Buffer**: + - Complete implementation with test coverage for: + - Blocking write/read operations + - Buffer closure handling + - Concurrent access patterns + +2. **Packet Serialization**: + - Full implementation with tests for: + - Marshal/unmarshal operations + - Valid packet handling + - Empty payload cases + +3. **Configuration**: + - Complete implementation with tests covering: + - Default values + - Range validation + - Value normalization + +### Fully Implemented but Untested +1. **Data Transmission**: + - Segmentation and packet sending in `send.go` + - Retransmission handling in `timers.go` + - No test coverage for edge cases or error conditions + +2. **Data Reception**: + - Packet processing and buffering in `recv.go` + - Out-of-order packet handling + - No tests for complex reassembly scenarios + +3. **Server Logic**: + - Connection management in `server.go` + - Datagram routing + - Lacks tests for concurrent connections + +### Partially Implemented +1. **Handshake Protocol**: + - Basic structure defined in `packet.go` (SYN/ACK/FIN flags) + - Missing implementation in: + - Connection establishment logic + - State machine for handshake steps + - Timeout handling during handshake + +2. **Error Handling**: + - Basic error types defined + - Inconsistent propagation in retransmission logic + - Missing comprehensive error recovery + +3. **Timer Management**: + - Basic timer operations implemented + - Race condition risks identified + - Missing proper cleanup and synchronization + +### Missing Components +1. **Connection State Management**: + - No explicit connection state machine + - Missing timeout handling for idle connections + - No graceful connection termination + +2. **Flow Control**: + - Window size tracking implemented + - Missing: + - Congestion control + - Slow start mechanism + - Fast retransmit/recovery + +3. **Security Features**: + - No DoS protection + - Missing replay attack prevention + - No cookie mechanism for handshake + +4. **Testing Infrastructure**: + - Need integration tests for: + - Complete connection lifecycle + - Error scenarios + - Performance under load + - Network condition simulation + +### Next Steps (Prioritized) +1. **Complete Handshake Implementation**: + - Implement state transitions + - Add timeout handling + - Include sequence number validation + +2. **Add Connection Management**: + - Implement idle connection detection + - Add connection timeouts + - Create cleanup mechanisms + +3. **Enhance Security**: + - Add SYN cookie mechanism + - Implement replay protection + - Add rate limiting for new connections + +4. **Improve Reliability**: + - Add congestion control + - Implement proper window management + - Add fast retransmit/recovery + +5. **Complete Test Coverage**: + - Add integration tests + - Create network simulation tests + - Test concurrent connections diff --git a/mod/udp/src/recv.go b/mod/udp/src/recv.go new file mode 100644 index 000000000..f4a29d612 --- /dev/null +++ b/mod/udp/src/recv.go @@ -0,0 +1,128 @@ +package udp + +import ( + "net" + "time" +) + +// recvLoop parses incoming datagrams, processes ACKs and data, and coalesces ACKs back. +func (c *Conn) recvLoop() { + defer c.wg.Done() + + buf := make([]byte, 64*1024) + + for { + _ = c.udpConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + n, _, err := c.udpConn.ReadFromUDP(buf) + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + if c.closed.Load() || c.ctx.Err() != nil { + return + } + continue + } + if !c.closed.Load() { + c.closeWithError(err) + } + return + } + + pkt, uerr := UnmarshalPacket(buf[:n]) + if uerr != nil { + // drop malformed + continue + } + + if pkt.Ack != 0 || (pkt.Flags&FlagACK) != 0 { + c.advanceAck(pkt.Ack) + } + + if pkt.Len > 0 { + if err := c.handleData(pkt); err != nil { + c.closeWithError(err) + return + } + } + } +} + +// handleData commits in-order payload to appBuf, buffers out-of-order, +// and schedules a delayed pure ACK for the burst. +func (c *Conn) handleData(pkt *Packet) error { + ackNeeded := false + + c.rcvMu.Lock() + switch { + case pkt.Seq == c.rcvNext: + payload := append([]byte(nil), pkt.Payload...) + c.rcvMu.Unlock() + + // Do not hold rcvMu while blocking + if _, err := c.appBuf.WriteAll(payload); err != nil { + return err + } + + c.rcvMu.Lock() + c.rcvNext += uint32(len(payload)) + ackNeeded = true + + // drain contiguous out-of-order + for { + next, ok := c.ooo[c.rcvNext] + if !ok { + break + } + delete(c.ooo, c.rcvNext) + data := append([]byte(nil), next...) + c.rcvMu.Unlock() + if _, err := c.appBuf.WriteAll(data); err != nil { + return err + } + c.rcvMu.Lock() + c.rcvNext += uint32(len(data)) + } + c.rcvMu.Unlock() + + case seqLT(c.rcvNext, pkt.Seq): + // out-of-order: drop if would exceed RecvBuf cap + if pkt.Len <= uint16(c.cfg.RecvBufBytes) { + c.ooo[pkt.Seq] = append([]byte(nil), pkt.Payload...) + ackNeeded = true // peer will infer gap + } + c.rcvMu.Unlock() + + default: + // duplicate; ignore + c.rcvMu.Unlock() + } + + if ackNeeded { + c.armAckDelay() + } + return nil +} + +// sendPureACK emits a standalone cumulative ACK for rcvNext. +func (c *Conn) sendPureACK() error { + // snapshot current cumulative ack + c.rcvMu.Lock() + ack := c.rcvNext + c.rcvMu.Unlock() + + pkt := Packet{ + Seq: c.nextSeq, // sender ignores Seq on pure ACK + Ack: ack, + Flags: FlagACK, + Win: 0, + Len: 0, + } + raw, err := pkt.Marshal() + if err != nil { + return err + } + + c.writeMu.Lock() + _, werr := c.udpConn.Write(raw) + c.writeMu.Unlock() + return werr +} diff --git a/mod/udp/src/ring_buffer.go b/mod/udp/src/ring_buffer.go new file mode 100644 index 000000000..76237b65a --- /dev/null +++ b/mod/udp/src/ring_buffer.go @@ -0,0 +1,159 @@ +package udp + +import ( + "io" + "sync" +) + +// ringBuffer implements a blocking circular buffer for byte streams. +type ringBuffer struct { + buf []byte // underlying buffer + cap int // capacity (fixed) + n int // current bytes stored + r int // read position + w int // write position + mu sync.Mutex // protects all fields + notEmp *sync.Cond // signaled when buffer becomes non-empty + notFul *sync.Cond // signaled when buffer has space available + closed bool // whether buffer is closed +} + +// newRingBuffer creates a new ring buffer with the specified capacity. +func newRingBuffer(capacity int) *ringBuffer { + if capacity < 0 { + capacity = 0 + } + rb := &ringBuffer{ + buf: make([]byte, capacity), + cap: capacity, + } + rb.notEmp = sync.NewCond(&rb.mu) + rb.notFul = sync.NewCond(&rb.mu) + return rb +} + +// WriteAll blocks until all bytes are written or the buffer is closed. +func (rb *ringBuffer) WriteAll(b []byte) (int, error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + written := 0 + for written < len(b) { + for rb.n == rb.cap && !rb.closed { + rb.notFul.Wait() + } + if rb.closed { + return written, io.ErrClosedPipe + } + + space := rb.cap - rb.n + toWrite := len(b) - written + if toWrite > space { + toWrite = space + } + + end := (rb.w + toWrite) % rb.cap + if end > rb.w { + copy(rb.buf[rb.w:end], b[written:written+toWrite]) + } else { + copy(rb.buf[rb.w:], b[written:written+toWrite]) + copy(rb.buf[:end], b[written+rb.cap-rb.w:written+toWrite]) + } + + rb.w = end + rb.n += toWrite + written += toWrite + rb.notEmp.Signal() + } + + return written, nil +} + +// TryWrite attempts to write bytes without blocking. +func (rb *ringBuffer) TryWrite(b []byte) int { + rb.mu.Lock() + defer rb.mu.Unlock() + + if rb.n == rb.cap || rb.closed { + return 0 + } + + space := rb.cap - rb.n + toWrite := len(b) + if toWrite > space { + toWrite = space + } + + end := (rb.w + toWrite) % rb.cap + if end > rb.w { + copy(rb.buf[rb.w:end], b[:toWrite]) + } else { + copy(rb.buf[rb.w:], b[:toWrite]) + copy(rb.buf[:end], b[rb.cap-rb.w:toWrite]) + } + + rb.w = end + rb.n += toWrite + rb.notEmp.Signal() + + return toWrite +} + +// Read blocks until at least one byte is available or the buffer is closed and drained. +func (rb *ringBuffer) Read(p []byte) (int, error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + for rb.n == 0 && !rb.closed { + rb.notEmp.Wait() + } + + if rb.n == 0 && rb.closed { + return 0, io.EOF + } + + toRead := len(p) + if toRead > rb.n { + toRead = rb.n + } + + end := (rb.r + toRead) % rb.cap + if end > rb.r { + copy(p, rb.buf[rb.r:end]) + } else { + copy(p, rb.buf[rb.r:]) + copy(p[rb.cap-rb.r:], rb.buf[:end]) + } + + rb.r = end + rb.n -= toRead + rb.notFul.Signal() + + return toRead, nil +} + +// Close marks the buffer as closed and wakes all waiters. +func (rb *ringBuffer) Close() { + rb.mu.Lock() + defer rb.mu.Unlock() + + if !rb.closed { + rb.closed = true + rb.notEmp.Broadcast() + rb.notFul.Broadcast() + } +} + +// Len returns the number of bytes currently stored in the buffer. +func (rb *ringBuffer) Len() int { + rb.mu.Lock() + defer rb.mu.Unlock() + return rb.n +} + +// Cap returns the capacity of the buffer. +func (rb *ringBuffer) Cap() int { + rb.mu.Lock() + defer rb.mu.Unlock() + return rb.cap +} diff --git a/mod/udp/src/ring_buffer_test.go b/mod/udp/src/ring_buffer_test.go new file mode 100644 index 000000000..4b7d9b99d --- /dev/null +++ b/mod/udp/src/ring_buffer_test.go @@ -0,0 +1,170 @@ +package udp + +import ( + "io" + "sync" + "testing" + "time" +) + +func TestRingBufferBlockingWriteRead(t *testing.T) { + rb := newRingBuffer(10) + + var wg sync.WaitGroup + wg.Add(2) + + // Writer + go func() { + defer wg.Done() + data := []byte("hello world") + written, err := rb.WriteAll(data) + if written != len(data) || err != nil { + t.Errorf("WriteAll failed: written=%d, err=%v", written, err) + } + }() + + // Reader + go func() { + defer wg.Done() + buf := make([]byte, 11) + read, err := rb.Read(buf) + if read != 11 || err != nil || string(buf) != "hello world" { + t.Errorf("Read failed: read=%d, err=%v, buf=%s", read, err, buf) + } + }() + + wg.Wait() +} + +func TestRingBufferCloseWhileReaderWaiting(t *testing.T) { + rb := newRingBuffer(10) + + var wg sync.WaitGroup + wg.Add(1) + + // Reader + go func() { + defer wg.Done() + buf := make([]byte, 10) + _, err := rb.Read(buf) + if err != io.EOF { + t.Errorf("Expected io.EOF, got %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) // Ensure reader is waiting + rb.Close() + + wg.Wait() +} + +func TestRingBufferCloseWhileWriterWaiting(t *testing.T) { + rb := newRingBuffer(5) + + var wg sync.WaitGroup + wg.Add(1) + + // Writer + go func() { + defer wg.Done() + data := []byte("hello world") + _, err := rb.WriteAll(data) + if err != io.ErrClosedPipe { + t.Errorf("Expected io.ErrClosedPipe, got %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) // Ensure writer is waiting + rb.Close() + + wg.Wait() +} + +func TestRingBufferConcurrentProducersConsumers(t *testing.T) { + rb := newRingBuffer(50) + + var wg sync.WaitGroup + producers := 5 + consumers := 5 + iterations := 100 + + // Producers + for i := 0; i < producers; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < iterations; j++ { + rb.WriteAll([]byte{byte(id)}) + } + }(i) + } + + // Consumers + for i := 0; i < consumers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 1) + for j := 0; j < iterations; j++ { + rb.Read(buf) + } + }() + } + + wg.Wait() +} + +func TestRingBufferZeroCapacity(t *testing.T) { + rb := newRingBuffer(0) + + var wg sync.WaitGroup + wg.Add(1) + + // Writer + go func() { + defer wg.Done() + data := []byte("data") + _, err := rb.WriteAll(data) + if err != io.ErrClosedPipe { + t.Errorf("Expected io.ErrClosedPipe, got %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) // Ensure writer is waiting + rb.Close() + + wg.Wait() +} + +func TestRingBufferRaceSafety(t *testing.T) { + rb := newRingBuffer(100) + + var wg sync.WaitGroup + producers := 10 + consumers := 10 + + // Producers + for i := 0; i < producers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 1000; j++ { + rb.TryWrite([]byte{byte(j % 256)}) + } + }() + } + + // Consumers + for i := 0; i < consumers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 10) + for j := 0; j < 1000; j++ { + rb.Read(buf) + } + }() + } + + wg.Wait() +} diff --git a/mod/udp/src/seg_meta.go b/mod/udp/src/seg_meta.go new file mode 100644 index 000000000..f7e8e29a0 --- /dev/null +++ b/mod/udp/src/seg_meta.go @@ -0,0 +1,12 @@ +package udp + +import "time" + +// segMeta stores metadata for an unacked segment +type segMeta struct { + data []byte + sentAt time.Time + retries int + seqStart uint32 + length int +} diff --git a/mod/udp/src/send.go b/mod/udp/src/send.go new file mode 100644 index 000000000..6c83781e9 --- /dev/null +++ b/mod/udp/src/send.go @@ -0,0 +1,249 @@ +package udp + +import ( + "context" + "time" + + "github.com/cryptopunkscc/astrald/mod/udp" +) + +// Write implements batched segmentation and send. +func (c *Conn) Write(p []byte) (int, error) { + if c.closed.Load() { + return 0, udp.ErrClosed + } + written := 0 + + for written < len(p) { + chunk := p[written:] + + c.sendMu.Lock() + for c.sendQ.Len() >= c.cfg.SendBufBytes && !c.closed.Load() { + c.sendCond.Wait() + } + if c.closed.Load() { + c.sendMu.Unlock() + return written, udp.ErrClosed + } + space := c.cfg.SendBufBytes - c.sendQ.Len() + if space > 0 { + toCopy := len(chunk) + if toCopy > space { + toCopy = space + } + _, err := c.sendQ.Write(chunk[:toCopy]) + if err != nil { + c.sendMu.Unlock() + return written, err + } + written += toCopy + } + c.sendMu.Unlock() + + if err := c.flushSendQueue(); err != nil { + return written, err + } + } + + return written, nil +} + +// flushSendQueue cuts ≤ MSS segments from sendQ while within WindowBytes, +// builds packets (piggybacking ACK), records unacked, and writes them back-to-back. +func (c *Conn) flushSendQueue() error { + var bufs [][]byte + var ack uint32 + + c.sendMu.Lock() + c.rcvMu.Lock() + ack = c.rcvNext + c.rcvMu.Unlock() + + win := c.cfg.WindowBytes - c.bytesInFlight + for win > 0 && c.sendQ.Len() > 0 { + segLen := c.mss + if segLen > c.sendQ.Len() { + segLen = c.sendQ.Len() + } + if segLen > win { + segLen = win + } + if segLen <= 0 { + break + } + + payload := c.sendQ.Next(segLen) + seq := c.nextSeq + + pkt := Packet{ + Seq: seq, + Ack: ack, + Flags: FlagACK, + Win: 0, + Len: uint16(segLen), + Payload: payload, + } + raw, err := pkt.Marshal() + if err != nil { + c.sendMu.Unlock() + return err + } + bufs = append(bufs, raw) + + meta := segMeta{ + data: append([]byte(nil), payload...), + sentAt: time.Now(), + retries: 0, + seqStart: seq, + length: segLen, + } + c.unacked[seq] = meta + c.order = append(c.order, seq) + + c.nextSeq += uint32(segLen) + c.bytesInFlight += segLen + win -= segLen + } + c.sendMu.Unlock() + + if len(bufs) == 0 { + return nil + } + + c.writeMu.Lock() + var writeErr error + for _, b := range bufs { + if _, writeErr = c.udpConn.Write(b); writeErr != nil { + break + } + } + c.writeMu.Unlock() + if writeErr != nil { + return writeErr + } + + c.startRTOIfNeededLocked() + return nil +} + +// startRTOIfNeededLocked arms the RTO timer when unacked is non-empty and no timer running. +func (c *Conn) startRTOIfNeededLocked() { + c.sendMu.Lock() + need := len(c.unacked) > 0 + c.sendMu.Unlock() + if !need { + return + } + + c.rtoMu.Lock() + defer c.rtoMu.Unlock() + if c.rtoTimer != nil { + return + } + d := c.rto + if d <= 0 { + d = c.cfg.RTO + } + c.rtoTimer = time.AfterFunc(d, c.onRTOTimeout) +} + +// onRTOTimeout retransmits the earliest unacked segment with backoff. +func (c *Conn) onRTOTimeout() { + var seq uint32 + var meta segMeta + var ok bool + + c.sendMu.Lock() + if len(c.order) == 0 { + c.sendMu.Unlock() + c.stopRTO() + return + } + seq = c.order[0] + meta, ok = c.unacked[seq] + if !ok { + c.order = c.order[1:] + c.sendMu.Unlock() + c.startRTOIfNeededLocked() + return + } + + c.rcvMu.Lock() + ack := c.rcvNext + c.rcvMu.Unlock() + + pkt := Packet{ + Seq: meta.seqStart, + Ack: ack, + Flags: FlagACK, + Win: 0, + Len: uint16(meta.length), + Payload: meta.data, + } + raw, err := pkt.Marshal() + c.sendMu.Unlock() + if err != nil { + c.closeWithError(err) + return + } + + c.writeMu.Lock() + _, werr := c.udpConn.Write(raw) + c.writeMu.Unlock() + if werr != nil { + c.closeWithError(werr) + return + } + + c.sendMu.Lock() + meta.retries++ + meta.sentAt = time.Now() + c.unacked[seq] = meta + c.rto *= 2 + if c.rto > c.cfg.RTOMax { + c.rto = c.cfg.RTOMax + } + overLimit := meta.retries > c.cfg.RetryLimit + c.sendMu.Unlock() + + if overLimit { + c.closeWithError(context.DeadlineExceeded) + return + } + + c.rtoMu.Lock() + if c.rtoTimer != nil { + c.rtoTimer.Reset(c.rto) + } + c.rtoMu.Unlock() +} + +// advanceAck removes fully-acked segments up to 'ack' and manages timers/backpressure. +func (c *Conn) advanceAck(ack uint32) { + c.sendMu.Lock() + changed := false + for len(c.order) > 0 { + seq := c.order[0] + meta := c.unacked[seq] + end := seq + uint32(meta.length) + // if segment end <= ack, it is fully acked + if seqLT(end, ack) || end == ack { + delete(c.unacked, seq) + c.order = c.order[1:] + c.bytesInFlight -= meta.length + changed = true + } else { + break + } + } + empty := len(c.unacked) == 0 + c.sendMu.Unlock() + + if changed { + // let writers proceed if sendQ was full + c.sendCond.Signal() + } + if empty { + c.stopRTO() + } +} diff --git a/mod/udp/src/server.go b/mod/udp/src/server.go new file mode 100644 index 000000000..f61033e97 --- /dev/null +++ b/mod/udp/src/server.go @@ -0,0 +1,96 @@ +package udp + +import ( + "net" + "sync" + + "github.com/cryptopunkscc/astrald/astral" +) + +// Connection states +const ( + StateClosed = iota // Connection is closed + StateSynSent // SYN sent, waiting for SYN-ACK + StateSynReceived // SYN received, waiting for ACK + StateEstablished // Connection established +) + +// Server implements src UDP listener with connection demultiplexing +type Server struct { + *Module + listener *net.UDPConn + conns map[string]*Conn // Remote address → connection map + mutex sync.Mutex // Protects access to conns + acceptCh chan *Conn // Channel for accepted connections + stopCh chan struct{} // Channel to signal server shutdown + wg sync.WaitGroup // WaitGroup for managing goroutines +} + +// NewServer creates a new src UDP server +func NewServer(module *Module) *Server { + return &Server{ + Module: module, + conns: make(map[string]*Conn), + acceptCh: make(chan *Conn, 16), + stopCh: make(chan struct{}), + } +} + +// Run starts the server and listens for incoming connections +func (s *Server) Run(ctx *astral.Context) error { + + listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: s.config.ListenPort}) + if err != nil { + s.log.Errorv(0, "failed to start server: %v", err) + return err + } + s.listener = listener + s.log.Info("started server at %v", listener.LocalAddr()) + defer s.log.Info("stopped server at %v", listener.LocalAddr()) + + s.wg.Add(1) + go s.readLoop() + + <-ctx.Done() + s.Close() + return nil +} + +// Close gracefully shuts down the server +func (s *Server) Close() error { + close(s.stopCh) + s.mutex.Lock() + for _, conn := range s.conns { + conn.Close() + } + s.mutex.Unlock() + + s.listener.Close() + s.wg.Wait() + return nil +} + +// readLoop handles incoming datagrams and routes them to connections +func (s *Server) readLoop() { + defer s.wg.Done() + buf := make([]byte, 64*1024) // Large buffer for high throughput + + for { + n, addr, err := s.listener.ReadFromUDP(buf) + if err != nil { + select { + case <-s.stopCh: + return // Graceful shutdown + default: + s.log.Errorv(1, "read error: %v", err) + continue + } + } + + s.handlePacket(buf[:n], addr) + } +} + +// handlePacket processes an incoming packet and routes it to the appropriate connection +func (s *Server) handlePacket(data []byte, addr *net.UDPAddr) { +} diff --git a/mod/udp/src/timers.go b/mod/udp/src/timers.go new file mode 100644 index 000000000..14a263bbc --- /dev/null +++ b/mod/udp/src/timers.go @@ -0,0 +1,45 @@ +package udp + +import ( + "time" +) + +// armRTO starts or resets the retransmission timer +func (c *Conn) armRTO(d time.Duration) { + c.rtoMu.Lock() + defer c.rtoMu.Unlock() + + if c.rtoTimer != nil { + c.rtoTimer.Stop() + } + c.rtoTimer = time.AfterFunc(d, c.handleRTO) +} + +// stopRTO safely stops the retransmission timer +func (c *Conn) stopRTO() { + c.rtoMu.Lock() + defer c.rtoMu.Unlock() + + if c.rtoTimer != nil { + c.rtoTimer.Stop() + c.rtoTimer = nil + } +} + +// armAckDelay schedules a pure ACK to be sent soon +func (c *Conn) armAckDelay() { + c.rtoMu.Lock() + defer c.rtoMu.Unlock() + + if c.ackTimer == nil { + c.ackTimer = time.AfterFunc(c.cfg.AckDelay, c.sendPureACK) + } else { + c.ackTimer.Reset(c.cfg.AckDelay) + } +} + +// armAckDelayTimerLocked initializes the ACK delay timer (called during initialization) +// Note: Unlike armAckDelay, this is called from NewConn when mutex is already held +func (c *Conn) armAckDelayTimerLocked() { + c.ackTimer = time.AfterFunc(c.cfg.AckDelay, c.sendPureACK) +} diff --git a/mod/udp/src/unpack.go b/mod/udp/src/unpack.go new file mode 100644 index 000000000..0f9403e35 --- /dev/null +++ b/mod/udp/src/unpack.go @@ -0,0 +1,25 @@ +package udp + +import ( + "bytes" + + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +var _ exonet.Unpacker = &Module{} + +func (mod *Module) Unpack(network string, data []byte) (exonet.Endpoint, error) { + switch network { + case "udp": + default: + return nil, exonet.ErrUnsupportedNetwork + } + return Unpack(data) +} + +func Unpack(buf []byte) (e *udp.Endpoint, err error) { + e = &udp.Endpoint{} + _, err = e.ReadFrom(bytes.NewReader(buf)) + return +} From 8778effdecec2b848ba13c2ff1a6bb04a3fc33e2 Mon Sep 17 00:00:00 2001 From: Rekseto Date: Thu, 25 Sep 2025 13:40:56 +0200 Subject: [PATCH 02/13] fix: update Exonet to use UDP instead of TCP for dialer, parser, and unpacker --- mod/udp/src/conn.go | 1 - mod/udp/src/deps.go | 6 +++--- mod/udp/src/send.go | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mod/udp/src/conn.go b/mod/udp/src/conn.go index 1ffc9ef0d..01865fd1e 100644 --- a/mod/udp/src/conn.go +++ b/mod/udp/src/conn.go @@ -43,7 +43,6 @@ type Conn struct { ooo map[uint32][]byte // out-of-order segments by seqStart appBuf *ringBuffer // ordered bytes for Read() ackPending atomic.Bool // (reserved) if you add explicit flags later - // timers rtoMu sync.Mutex rto time.Duration diff --git a/mod/udp/src/deps.go b/mod/udp/src/deps.go index a03aa351b..24afdf349 100644 --- a/mod/udp/src/deps.go +++ b/mod/udp/src/deps.go @@ -20,9 +20,9 @@ func (mod *Module) LoadDependencies() (err error) { return } - mod.Exonet.SetDialer("tcp", mod) - mod.Exonet.SetParser("tcp", mod) - mod.Exonet.SetUnpacker("tcp", mod) + mod.Exonet.SetDialer("udp", mod) + mod.Exonet.SetParser("udp", mod) + mod.Exonet.SetUnpacker("udp", mod) mod.Nodes.AddResolver(mod) return diff --git a/mod/udp/src/send.go b/mod/udp/src/send.go index 6c83781e9..48ee8baba 100644 --- a/mod/udp/src/send.go +++ b/mod/udp/src/send.go @@ -180,6 +180,7 @@ func (c *Conn) onRTOTimeout() { Len: uint16(meta.length), Payload: meta.data, } + raw, err := pkt.Marshal() c.sendMu.Unlock() if err != nil { From 650752d1aa72fbaa36267acc5112a26b209b006b Mon Sep 17 00:00:00 2001 From: Rekseto Date: Thu, 25 Sep 2025 21:29:08 +0200 Subject: [PATCH 03/13] simplification of reliable UDP --- mod/udp/README.md | 54 ------ mod/udp/conn.go | 12 -- mod/udp/errors.go | 17 +- mod/udp/src/conn.go | 273 +++++++++++++++-------------- mod/udp/src/conn_handshake.go | 113 ++++++++++++ mod/udp/src/conn_handshake_test.go | 132 ++++++++++++++ mod/udp/src/packet.go | 13 ++ mod/udp/src/recv.go | 128 -------------- mod/udp/src/seg_meta.go | 12 -- mod/udp/src/send.go | 250 -------------------------- mod/udp/src/server.go | 70 ++++++-- mod/udp/src/timers.go | 45 ----- 12 files changed, 466 insertions(+), 653 deletions(-) delete mode 100644 mod/udp/conn.go create mode 100644 mod/udp/src/conn_handshake.go create mode 100644 mod/udp/src/conn_handshake_test.go delete mode 100644 mod/udp/src/recv.go delete mode 100644 mod/udp/src/seg_meta.go delete mode 100644 mod/udp/src/send.go delete mode 100644 mod/udp/src/timers.go diff --git a/mod/udp/README.md b/mod/udp/README.md index 90d5a39a5..e69de29bb 100644 --- a/mod/udp/README.md +++ b/mod/udp/README.md @@ -1,54 +0,0 @@ -# Reliable UDP Module - -## Overview -This module provides reliable, ordered, and stream-like communication over UDP. It ensures data integrity and delivery through acknowledgments, retransmissions, and a handshake protocol. The module is part of the Astral ecosystem and integrates with `exonet` for endpoint management. - -## File & Struct Map -- **config.go**: Defines configuration constants (e.g., retransmission timeouts, buffer sizes). -- **conn.go**: Implements the `Conn` struct, representing a reliable UDP connection. Handles sending, receiving, and retransmissions. -- **endpoint_resolver.go**: Resolves endpoints for the module, integrating with `exonet`. -- **loader.go**: Initializes the module with dependencies. -- **module.go**: Defines the `Module` struct, the entry point for the UDP module. -- **recv.go**: Implements the receive loop, processing incoming packets and acknowledgments. -- **send.go**: Handles segmentation, batching, and sending of data. -- **ring_buffer.go**: Provides a circular buffer for efficient data storage and retrieval. -- **packet.go**: Defines the `Packet` struct and serialization logic. -- **server.go**: Implements the `Server` struct, managing incoming connections and demultiplexing. - -## Current Findings & Considerations -- **ACK Handling**: The module uses cumulative acknowledgments to confirm receipt of data up to a specific sequence number. This simplifies state management but requires careful handling of retransmissions to avoid unnecessary duplicates. -- **RTO/Backoff**: Retransmission timeouts are implemented based on RFC 6298, with exponential backoff to handle varying network conditions. This ensures robustness in the face of packet loss. -- **Handshake**: The handshake protocol establishes connection state before data exchange. However, it currently lacks stateless cookie support, which could mitigate DoS attacks by ensuring that resources are only allocated for legitimate connections. - -## Datagram Structure -A datagram in this module is represented by the `Packet` struct. It includes the following fields: -- **Seq (uint32)**: Sequence number indicating the first byte of the segment. -- **Ack (uint32)**: Cumulative acknowledgment number, confirming receipt of all bytes up to this sequence number. -- **Flags (uint8)**: Control flags such as SYN, ACK, and FIN. -- **Win (uint16)**: Advertised receive window size in bytes. -- **Len (uint16)**: Length of the payload. -- **Payload ([]byte)**: The actual data being transmitted. - -### Fragmentation and Reassembly -- **Fragmentation**: Large application data is segmented into smaller packets, each fitting within the Maximum Segment Size (MSS). This ensures compatibility with network MTU limits and avoids IP-level fragmentation. -- **Reassembly**: On the receiving side, out-of-order packets are buffered and reassembled into the original data stream once all fragments are received. - -## Diagrams -### Handshake Protocol -- The handshake establishes a connection between two endpoints before data exchange. -- Steps: - 1. **SYN**: The initiator sends a SYN packet to start the handshake. - 2. **SYN|ACK**: The responder replies with a SYN|ACK packet, acknowledging the initiator's SYN and sending its own sequence number. - 3. **ACK**: The initiator sends an ACK packet to confirm the responder's sequence number. -- Once the ACK is received, the connection is established. - -### Data Flow -- **Write Path**: - - Application data is segmented into smaller packets (segmentation). - - Packets are serialized and sent over the network (packetization). -- **Network Transmission**: - - Packets are transmitted over the network, potentially out of order. -- **Read Path**: - - Received packets are reassembled into the original data stream (reassembly). - - Cumulative acknowledgments (ACKs) confirm receipt of data up to a specific sequence number. -- Retransmissions occur for lost packets based on retransmission timeouts (RTO). diff --git a/mod/udp/conn.go b/mod/udp/conn.go deleted file mode 100644 index e0fb31032..000000000 --- a/mod/udp/conn.go +++ /dev/null @@ -1,12 +0,0 @@ -package udp - -// DatagramWriter is how Conn sends bytes to its peer. -type DatagramWriter interface { - WriteDatagram(b []byte) error -} - -// DatagramReceiver is how Conn *receives* parsed packets when it does not own a socket read loop. -// (For active conns, the recvLoop calls HandleDatagram itself.) -type DatagramReceiver interface { - HandleDatagram(raw []byte) // fast path: parse + process (ACK/data) -} diff --git a/mod/udp/errors.go b/mod/udp/errors.go index 0b593d49a..efe883990 100644 --- a/mod/udp/errors.go +++ b/mod/udp/errors.go @@ -3,11 +3,14 @@ package udp import "errors" var ( - ErrPacketTooShort = errors.New("packet too short") - ErrListenerClosed = errors.New("listener closed") - ErrConnClosed = errors.New("connection closed") - ErrInvalidPayloadLength = errors.New("invalid payload length") - ErrClosed = errors.New("connection closed") - ErrZeroMSS = errors.New("invalid MSS") - ErrMalformedPacket = errors.New("malformed packet") + ErrPacketTooShort = errors.New("packet too short") + ErrListenerClosed = errors.New("listener closed") + ErrConnClosed = errors.New("connection closed") + ErrInvalidPayloadLength = errors.New("invalid payload length") + ErrClosed = errors.New("connection closed") + ErrZeroMSS = errors.New("invalid MSS") + ErrMalformedPacket = errors.New("malformed packet") + ErrHandshakeTimeout = errors.New("handshake timeout") + ErrHandshakeReset = errors.New("handshake reset") + ErrConnectionNotEstablished = errors.New("connection not established") ) diff --git a/mod/udp/src/conn.go b/mod/udp/src/conn.go index 01865fd1e..2eb91d516 100644 --- a/mod/udp/src/conn.go +++ b/mod/udp/src/conn.go @@ -2,66 +2,89 @@ package udp import ( - "bytes" - "context" - "io" "net" - "sync" - "sync/atomic" "time" "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/udp" ) -// Conn implements src, ordered communication over a connected UDP socket. +// DatagramWriter is how Conn sends bytes to its peer. +type DatagramWriter interface { + WriteDatagram(b []byte) error +} + +// DatagramReceiver is how Conn *receives* parsed packets when it does not own a socket read loop. +// (For active conns, the recvLoop calls HandleDatagram itself.) +type DatagramReceiver interface { + HandleDatagram(raw []byte) // fast path: parse + process (ACK/data) +} + +type Handshaker interface { + Handshake() error +} + +type Fragmenter interface { +} + +// Conn represents a reliable UDP connection // Handshake/FIN are out of scope for this MVP; stream semantics only. type Conn struct { // socket / addressing udpConn *net.UDPConn localEndpoint *udp.Endpoint remoteEndpoint *udp.Endpoint - // config cfg FlowControlConfig - mss int - - // send side (guarded by sendMu unless noted) - sendMu sync.Mutex - sendBase uint32 // first unacked byte - nextSeq uint32 // next byte sequence to assign - nextSendSeq uint32 - sendQ *bytes.Buffer // queued app data (bounded by cfg.SendBufBytes) - sendCond *sync.Cond // signals space available / data added - bytesInFlight int - unacked map[uint32]segMeta // seqStart -> meta - order []uint32 // seqStarts in send order (oldest first) - - // recv side - rcvMu sync.Mutex - rcvNext uint32 - ooo map[uint32][]byte // out-of-order segments by seqStart - appBuf *ringBuffer // ordered bytes for Read() - ackPending atomic.Bool // (reserved) if you add explicit flags later - // timers - rtoMu sync.Mutex - rto time.Duration - rtoTimer *time.Timer - ackTimer *time.Timer // set on-demand in recv.go - - // control/lifecycle - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - - closed atomic.Bool - closeOnce sync.Once - closeErr atomic.Value // error - - // write serialization (shared per UDP socket if you ever share it) - writeMu *sync.Mutex - // Add a mutex field for synchronization - mutex sync.Mutex + + state ConnState + inCh chan *Packet + closedFlag bool + + // + initialSeqNumLocal uint32 + initialSeqNumRemote uint32 + // send state + nextSeqNum uint32 + connID uint32 + sendBase uint32 // oldest unacked sequence (i.e., cumulative ACK floor). + ackedSeqNum uint32 // highest cumulative ACK seen (often == sendBase). + expected uint32 + // + unacked map[uint32]*Packet // seq -> packet + // receive state +} + +func (c *Conn) setState(state ConnState) { + c.state = state +} + +func (c *Conn) inState(state ConnState) bool { + return c.state == state +} + +func (c *Conn) Read(p []byte) (n int, err error) { + if !c.inState(StateEstablished) { + return 0, udp.ErrConnectionNotEstablished + } + //TODO implement me + panic("implement me") +} + +func (c *Conn) Write(p []byte) (n int, err error) { + if !c.inState(StateEstablished) { + return 0, udp.ErrConnectionNotEstablished + } + + //TODO implement me + panic("implement me") +} + +func (c *Conn) Close() error { + c.closedFlag = true + c.udpConn.SetReadDeadline(time.Now()) + //TODO implement me + panic("implement me") } // NewConn constructs a connection around an already-connected UDP socket. @@ -76,72 +99,11 @@ func NewConn(c *net.UDPConn, l, r *udp.Endpoint, cfg FlowControlConfig) (*Conn, localEndpoint: l, remoteEndpoint: r, cfg: cfg, - mss: cfg.MSS, - - sendBase: 1, // start at 1 so 0 can be a sentinel in traces - nextSeq: 1, - sendQ: &bytes.Buffer{}, - unacked: make(map[uint32]segMeta), - order: make([]uint32, 0, 128), - - rcvNext: 1, - ooo: make(map[uint32][]byte), - appBuf: newRingBuffer(cfg.RecvBufBytes), - - rto: cfg.RTO, - writeMu: &sync.Mutex{}, } - rc.sendCond = sync.NewCond(&rc.sendMu) - - // Start receiver loop (defined in recv.go) - rc.wg.Add(1) - go rc.recvLoop() return rc, nil } -// Read implements stream semantics. It blocks until data is available or the -// connection is closed and drained. On close, it returns any stored terminal error -// or io.EOF when the buffer is empty. -func (c *Conn) Read(p []byte) (int, error) { - n, err := c.appBuf.Read(p) - if n > 0 { - return n, nil - } - if c.closed.Load() { - if errv := c.closeErr.Load(); errv != nil { - return 0, errv.(error) - } - return 0, io.EOF - } - return n, err -} - -// Close terminates the connection and waits for the recv loop to exit. -func (c *Conn) Close() error { - c.closeOnce.Do(func() { - c.closed.Store(true) - c.cancel() - - // stop timers - c.stopRTO() // defined in send.go - c.rtoMu.Lock() - if c.ackTimer != nil { - c.ackTimer.Stop() - c.ackTimer = nil - } - c.rtoMu.Unlock() - - // wake blocked goroutines - c.sendCond.Broadcast() - c.appBuf.Close() - - _ = c.udpConn.Close() - }) - c.wg.Wait() - return nil -} - // Outbound reports whether this connection was dialed out. // For now this always returns true for Dial usage; adjust if you add a listener. func (c *Conn) Outbound() bool { return true } @@ -156,35 +118,90 @@ func (c *Conn) RemoteEndpoint() exonet.Endpoint { return c.remoteEndpoint } -// NOTE: Flagged for check (might be ai overcomplexity) -// closeWithError records the error and closes the connection. -func (c *Conn) closeWithError(err error) { - if err != nil { - c.closeErr.Store(err) +func (c *Conn) receivingLoop() { + const maxPayloadSize = 64 * 1024 + buf := make([]byte, maxPayloadSize) + for { + c.udpConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, addr, err := c.udpConn.ReadFromUDP(buf) + if err != nil { + if c.closedFlag { + return + } + continue + } + + // NOTE: test it + if addr.String() != c.remoteEndpoint.IP.String() { + continue // not for this Conn + } + + pktData := make([]byte, n) + copy(pktData, buf[:n]) + pkt := &Packet{} + if err := pkt.Unmarshal(pktData); err != nil { + continue // drop malformed + } + if int(pkt.Len) > maxPayloadSize { + continue // invalid length + } + isControl := pkt.Flags&(FlagSYN|FlagACK|FlagFIN) != 0 && pkt.Len == 0 + if isControl { + // Block until enqueued + c.inCh <- pkt + } else { + // Drop data if channel full + select { + case c.inCh <- pkt: + default: + // drop data + } + } } - _ = c.Close() } -// seqLT compares sequence numbers with wrap-around semantics. -func seqLT(a, b uint32) bool { return int32(a-b) < 0 } +// Go +func (c *Conn) InboundPacketHandler() { + for pkt := range c.inCh { + if pkt.Flags&FlagACK != 0 { + c.handleAckPacket(pkt) + continue + } + + if pkt.Flags&(FlagSYN|FlagFIN) != 0 { + c.handleControlPacket(pkt) + continue + } -// sendPacket sends a packet over the UDP connection. -func (c *Conn) sendPacket(pkt *Packet) error { - raw, err := pkt.Marshal() - if err != nil { - return err + c.handleDataPacket(pkt) } +} - c.writeMu.Lock() - defer c.writeMu.Unlock() +// handleAckPacket processes ACK packets +func (c *Conn) handleAckPacket(pkt *Packet) { + ack := pkt.Ack + for seq := range c.unacked { + if seq <= ack { + delete(c.unacked, seq) + } + } + c.sendBase = ack + 1 +} - _, err = c.udpConn.Write(raw) - return err +// handleControlPacket processes SYN, FIN, and other control packets +func (c *Conn) handleControlPacket(pkt *Packet) { + // Example: handle SYN, FIN, or other control logic + if pkt.Flags&FlagSYN != 0 { + // ...handle SYN logic... + } + if pkt.Flags&FlagFIN != 0 { + // ...handle FIN logic... + } + // ...handle other control flags as needed... } -// handleRTO handles retransmission timeouts by retransmitting the earliest unacked segment. -func (c *Conn) handleRTO() { - // Implementation for retransmission timeout handling - // This will involve retransmitting the earliest unacked segment - // and applying exponential backoff to the retransmission timer. +// handleDataPacket processes data packets +func (c *Conn) handleDataPacket(pkt *Packet) { + // Example: deliver to receive buffer, update expected, send ACK, etc. + // ...implement data delivery and ACK logic... } diff --git a/mod/udp/src/conn_handshake.go b/mod/udp/src/conn_handshake.go new file mode 100644 index 000000000..55a29e0f2 --- /dev/null +++ b/mod/udp/src/conn_handshake.go @@ -0,0 +1,113 @@ +package udp + +import ( + "context" + "fmt" + + "github.com/cryptopunkscc/astrald/mod/udp" +) + +type ConnState int + +const ( + StateClosed ConnState = iota // no connection / after Close() + StateListen // (server only) waiting for SYN + StateSynSent // client sent SYN, waiting for SYN|ACK + StateSynReceived // server got SYN, sent SYN|ACK, waiting for final ACK + StateEstablished // handshake complete, normal data flow + StateFinSent // Close() called, FIN sent, waiting for ACK + StateFinReceived // FIN received, waiting for local Close() + StateTimeWait // (optional) short wait after FIN to absorb retransmits +) + +func (c *Conn) startClientHandshake(ctx context.Context) error { + c.initialSeqNumLocal = randUint32NZ() + c.connID = c.initialSeqNumLocal + c.setState(StateSynSent) + + err := c.sendControl(FlagSYN, c.initialSeqNumLocal, 0) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return udp.ErrHandshakeTimeout + case pkt := <-c.inCh: + if pkt.Flags&(FlagSYN|FlagACK) == (FlagSYN|FlagACK) && pkt.Ack == c.initialSeqNumLocal+1 && pkt.Seq != 0 { + c.initialSeqNumRemote = pkt.Seq + + err := c.sendControl(FlagACK, c.initialSeqNumLocal+1, c.initialSeqNumRemote+1) + if err != nil { + return err + } + c.setState(StateEstablished) + go c.InboundPacketHandler() + return nil + } + } + } +} + +func (c *Conn) startServerHandshake(ctx context.Context, synPkt *Packet) error { + c.initialSeqNumRemote = synPkt.Seq + c.connID = synPkt.Seq + c.initialSeqNumLocal = randUint32NZ() + c.setState(StateSynReceived) + + if err := c.sendControl(FlagSYN|FlagACK, c.initialSeqNumLocal, c.initialSeqNumRemote+1); err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return udp.ErrHandshakeTimeout + case pkt := <-c.inCh: + if pkt.Flags&FlagACK != 0 && pkt.Ack == c.initialSeqNumLocal+1 { + c.setState(StateEstablished) + go c.InboundPacketHandler() + return nil + } + } + } +} + +func (c *Conn) sendControl(flags uint8, seq, ack uint32) error { + pkt := &Packet{ + Seq: seq, + Ack: ack, + Flags: flags, + Len: 0, + } + + data, err := pkt.Marshal() + if err != nil { + return fmt.Errorf(`sendControl failed to marshal control packet: %w`, err) + } + + if c.udpConn == nil { + return udp.ErrConnClosed + } + + _, err = c.udpConn.Write(data) + if err != nil { + return fmt.Errorf(`sendControl failed to send control packet: %w`, err) + } + + return err +} + +func (c *Conn) handleInbound(pkt *Packet) { + // ...process handshake/control/data packets... +} + +// TODO: implement proper random non-zero uint32 generator +func randUint32NZ() uint32 { + // ...generate non-zero random uint32... + return 1 // stub +} + +// notifyInbound is a channel for inbound packets +// ...in Conn struct... +// notifyInbound chan *Packet diff --git a/mod/udp/src/conn_handshake_test.go b/mod/udp/src/conn_handshake_test.go new file mode 100644 index 000000000..dcd07acdd --- /dev/null +++ b/mod/udp/src/conn_handshake_test.go @@ -0,0 +1,132 @@ +package udp + +import ( + "context" + "testing" + "time" +) + +func TestClientHandshake_Success(t *testing.T) { + conn := &Conn{ + notifyInbound: make(chan *Packet, 1), + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // Simulate server SYN|ACK response + go func() { + time.Sleep(10 * time.Millisecond) + conn.notifyInbound <- &Packet{ + Seq: 12345, // server ISN + Ack: 2, // client ISN + 1 + Flags: FlagSYN | FlagACK, + Len: 0, + } + }() + + conn.initialSeqNumLocal = 1 // deterministic for test + conn.connID = 1 + + err := conn.startClientHandshake(ctx) + if err != nil { + t.Fatalf("expected handshake success, got error: %v", err) + } + if conn.state != StateEstablished { + t.Fatalf("expected StateEstablished, got %v", conn.state) + } + if conn.initialSeqNumRemote != 12345 { + t.Fatalf("expected initialSeqNumRemote=12345, got %v", conn.initialSeqNumRemote) + } +} + +func TestClientHandshake_Timeout(t *testing.T) { + conn := &Conn{ + notifyInbound: make(chan *Packet, 1), + } + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := conn.startClientHandshake(ctx) + if err == nil || err.Error() != "handshake timeout" { + t.Fatalf("expected handshake timeout error, got: %v", err) + } +} + +func TestServerHandshake_Success(t *testing.T) { + conn := &Conn{ + notifyInbound: make(chan *Packet, 1), + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // Simulate client ACK response + go func() { + time.Sleep(10 * time.Millisecond) + conn.notifyInbound <- &Packet{ + Ack: 2, // server ISN + 1 + Flags: FlagACK, + Len: 0, + } + }() + + synPkt := &Packet{Seq: 1, Flags: FlagSYN, Len: 0} + err := conn.startServerHandshake(ctx, synPkt) + if err != nil { + t.Fatalf("expected handshake success, got error: %v", err) + } + if conn.state != StateEstablished { + t.Fatalf("expected StateEstablished, got %v", conn.state) + } +} + +func TestServerHandshake_Timeout(t *testing.T) { + conn := &Conn{ + notifyInbound: make(chan *Packet, 1), + } + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + synPkt := &Packet{Seq: 1, Flags: FlagSYN, Len: 0} + err := conn.startServerHandshake(ctx, synPkt) + if err == nil || err.Error() != "handshake timeout" { + t.Fatalf("expected handshake timeout error, got: %v", err) + } +} + +func TestClientHandshake_BadAckIgnored(t *testing.T) { + conn := &Conn{ + notifyInbound: make(chan *Packet, 2), + } + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Simulate server response with wrong Ack + go func() { + time.Sleep(10 * time.Millisecond) + conn.notifyInbound <- &Packet{ + Seq: 12345, + Ack: 999, // wrong ack + Flags: FlagSYN | FlagACK, + Len: 0, + } + // Then correct response + time.Sleep(10 * time.Millisecond) + conn.notifyInbound <- &Packet{ + Seq: 12345, + Ack: 2, + Flags: FlagSYN | FlagACK, + Len: 0, + } + }() + + conn.initialSeqNumLocal = 1 + conn.connID = 1 + + err := conn.startClientHandshake(ctx) + if err != nil { + t.Fatalf("expected handshake success, got error: %v", err) + } + if conn.state != StateEstablished { + t.Fatalf("expected StateEstablished, got %v", conn.state) + } +} diff --git a/mod/udp/src/packet.go b/mod/udp/src/packet.go index a968d03cf..cb7d499b5 100644 --- a/mod/udp/src/packet.go +++ b/mod/udp/src/packet.go @@ -3,6 +3,7 @@ package udp import ( "bytes" "encoding/binary" + "time" "github.com/cryptopunkscc/astrald/mod/udp" ) @@ -14,6 +15,13 @@ const ( ) // Packet represents a src UDP packet with TCP-like header +// Handshake usage: +// +// ConnID: carried in Seq (ISN) +// Seq: initial sequence number (ISN) +// Ack: cumulative acknowledgment of peer's ISN+1 +// Flags: SYN, ACK, FIN (DATA if Len > 0) +// Len: 0 for control packets (SYN, SYN|ACK, ACK, FIN) type Packet struct { Seq uint32 // Sequence number (first byte seq of this segment) Ack uint32 // Acknowledgment number (cumulative ack: all bytes < Ack received) @@ -104,3 +112,8 @@ func UnmarshalPacket(data []byte) (*Packet, error) { return pkt, nil } + +type SentPacket struct { + pkt *Packet + sentTime time.Time +} diff --git a/mod/udp/src/recv.go b/mod/udp/src/recv.go deleted file mode 100644 index f4a29d612..000000000 --- a/mod/udp/src/recv.go +++ /dev/null @@ -1,128 +0,0 @@ -package udp - -import ( - "net" - "time" -) - -// recvLoop parses incoming datagrams, processes ACKs and data, and coalesces ACKs back. -func (c *Conn) recvLoop() { - defer c.wg.Done() - - buf := make([]byte, 64*1024) - - for { - _ = c.udpConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) - n, _, err := c.udpConn.ReadFromUDP(buf) - if err != nil { - if ne, ok := err.(net.Error); ok && ne.Timeout() { - if c.closed.Load() || c.ctx.Err() != nil { - return - } - continue - } - if !c.closed.Load() { - c.closeWithError(err) - } - return - } - - pkt, uerr := UnmarshalPacket(buf[:n]) - if uerr != nil { - // drop malformed - continue - } - - if pkt.Ack != 0 || (pkt.Flags&FlagACK) != 0 { - c.advanceAck(pkt.Ack) - } - - if pkt.Len > 0 { - if err := c.handleData(pkt); err != nil { - c.closeWithError(err) - return - } - } - } -} - -// handleData commits in-order payload to appBuf, buffers out-of-order, -// and schedules a delayed pure ACK for the burst. -func (c *Conn) handleData(pkt *Packet) error { - ackNeeded := false - - c.rcvMu.Lock() - switch { - case pkt.Seq == c.rcvNext: - payload := append([]byte(nil), pkt.Payload...) - c.rcvMu.Unlock() - - // Do not hold rcvMu while blocking - if _, err := c.appBuf.WriteAll(payload); err != nil { - return err - } - - c.rcvMu.Lock() - c.rcvNext += uint32(len(payload)) - ackNeeded = true - - // drain contiguous out-of-order - for { - next, ok := c.ooo[c.rcvNext] - if !ok { - break - } - delete(c.ooo, c.rcvNext) - data := append([]byte(nil), next...) - c.rcvMu.Unlock() - if _, err := c.appBuf.WriteAll(data); err != nil { - return err - } - c.rcvMu.Lock() - c.rcvNext += uint32(len(data)) - } - c.rcvMu.Unlock() - - case seqLT(c.rcvNext, pkt.Seq): - // out-of-order: drop if would exceed RecvBuf cap - if pkt.Len <= uint16(c.cfg.RecvBufBytes) { - c.ooo[pkt.Seq] = append([]byte(nil), pkt.Payload...) - ackNeeded = true // peer will infer gap - } - c.rcvMu.Unlock() - - default: - // duplicate; ignore - c.rcvMu.Unlock() - } - - if ackNeeded { - c.armAckDelay() - } - return nil -} - -// sendPureACK emits a standalone cumulative ACK for rcvNext. -func (c *Conn) sendPureACK() error { - // snapshot current cumulative ack - c.rcvMu.Lock() - ack := c.rcvNext - c.rcvMu.Unlock() - - pkt := Packet{ - Seq: c.nextSeq, // sender ignores Seq on pure ACK - Ack: ack, - Flags: FlagACK, - Win: 0, - Len: 0, - } - raw, err := pkt.Marshal() - if err != nil { - return err - } - - c.writeMu.Lock() - _, werr := c.udpConn.Write(raw) - c.writeMu.Unlock() - return werr -} diff --git a/mod/udp/src/seg_meta.go b/mod/udp/src/seg_meta.go deleted file mode 100644 index f7e8e29a0..000000000 --- a/mod/udp/src/seg_meta.go +++ /dev/null @@ -1,12 +0,0 @@ -package udp - -import "time" - -// segMeta stores metadata for an unacked segment -type segMeta struct { - data []byte - sentAt time.Time - retries int - seqStart uint32 - length int -} diff --git a/mod/udp/src/send.go b/mod/udp/src/send.go deleted file mode 100644 index 48ee8baba..000000000 --- a/mod/udp/src/send.go +++ /dev/null @@ -1,250 +0,0 @@ -package udp - -import ( - "context" - "time" - - "github.com/cryptopunkscc/astrald/mod/udp" -) - -// Write implements batched segmentation and send. -func (c *Conn) Write(p []byte) (int, error) { - if c.closed.Load() { - return 0, udp.ErrClosed - } - written := 0 - - for written < len(p) { - chunk := p[written:] - - c.sendMu.Lock() - for c.sendQ.Len() >= c.cfg.SendBufBytes && !c.closed.Load() { - c.sendCond.Wait() - } - if c.closed.Load() { - c.sendMu.Unlock() - return written, udp.ErrClosed - } - space := c.cfg.SendBufBytes - c.sendQ.Len() - if space > 0 { - toCopy := len(chunk) - if toCopy > space { - toCopy = space - } - _, err := c.sendQ.Write(chunk[:toCopy]) - if err != nil { - c.sendMu.Unlock() - return written, err - } - written += toCopy - } - c.sendMu.Unlock() - - if err := c.flushSendQueue(); err != nil { - return written, err - } - } - - return written, nil -} - -// flushSendQueue cuts ≤ MSS segments from sendQ while within WindowBytes, -// builds packets (piggybacking ACK), records unacked, and writes them back-to-back. -func (c *Conn) flushSendQueue() error { - var bufs [][]byte - var ack uint32 - - c.sendMu.Lock() - c.rcvMu.Lock() - ack = c.rcvNext - c.rcvMu.Unlock() - - win := c.cfg.WindowBytes - c.bytesInFlight - for win > 0 && c.sendQ.Len() > 0 { - segLen := c.mss - if segLen > c.sendQ.Len() { - segLen = c.sendQ.Len() - } - if segLen > win { - segLen = win - } - if segLen <= 0 { - break - } - - payload := c.sendQ.Next(segLen) - seq := c.nextSeq - - pkt := Packet{ - Seq: seq, - Ack: ack, - Flags: FlagACK, - Win: 0, - Len: uint16(segLen), - Payload: payload, - } - raw, err := pkt.Marshal() - if err != nil { - c.sendMu.Unlock() - return err - } - bufs = append(bufs, raw) - - meta := segMeta{ - data: append([]byte(nil), payload...), - sentAt: time.Now(), - retries: 0, - seqStart: seq, - length: segLen, - } - c.unacked[seq] = meta - c.order = append(c.order, seq) - - c.nextSeq += uint32(segLen) - c.bytesInFlight += segLen - win -= segLen - } - c.sendMu.Unlock() - - if len(bufs) == 0 { - return nil - } - - c.writeMu.Lock() - var writeErr error - for _, b := range bufs { - if _, writeErr = c.udpConn.Write(b); writeErr != nil { - break - } - } - c.writeMu.Unlock() - if writeErr != nil { - return writeErr - } - - c.startRTOIfNeededLocked() - return nil -} - -// startRTOIfNeededLocked arms the RTO timer when unacked is non-empty and no timer running. -func (c *Conn) startRTOIfNeededLocked() { - c.sendMu.Lock() - need := len(c.unacked) > 0 - c.sendMu.Unlock() - if !need { - return - } - - c.rtoMu.Lock() - defer c.rtoMu.Unlock() - if c.rtoTimer != nil { - return - } - d := c.rto - if d <= 0 { - d = c.cfg.RTO - } - c.rtoTimer = time.AfterFunc(d, c.onRTOTimeout) -} - -// onRTOTimeout retransmits the earliest unacked segment with backoff. -func (c *Conn) onRTOTimeout() { - var seq uint32 - var meta segMeta - var ok bool - - c.sendMu.Lock() - if len(c.order) == 0 { - c.sendMu.Unlock() - c.stopRTO() - return - } - seq = c.order[0] - meta, ok = c.unacked[seq] - if !ok { - c.order = c.order[1:] - c.sendMu.Unlock() - c.startRTOIfNeededLocked() - return - } - - c.rcvMu.Lock() - ack := c.rcvNext - c.rcvMu.Unlock() - - pkt := Packet{ - Seq: meta.seqStart, - Ack: ack, - Flags: FlagACK, - Win: 0, - Len: uint16(meta.length), - Payload: meta.data, - } - - raw, err := pkt.Marshal() - c.sendMu.Unlock() - if err != nil { - c.closeWithError(err) - return - } - - c.writeMu.Lock() - _, werr := c.udpConn.Write(raw) - c.writeMu.Unlock() - if werr != nil { - c.closeWithError(werr) - return - } - - c.sendMu.Lock() - meta.retries++ - meta.sentAt = time.Now() - c.unacked[seq] = meta - c.rto *= 2 - if c.rto > c.cfg.RTOMax { - c.rto = c.cfg.RTOMax - } - overLimit := meta.retries > c.cfg.RetryLimit - c.sendMu.Unlock() - - if overLimit { - c.closeWithError(context.DeadlineExceeded) - return - } - - c.rtoMu.Lock() - if c.rtoTimer != nil { - c.rtoTimer.Reset(c.rto) - } - c.rtoMu.Unlock() -} - -// advanceAck removes fully-acked segments up to 'ack' and manages timers/backpressure. -func (c *Conn) advanceAck(ack uint32) { - c.sendMu.Lock() - changed := false - for len(c.order) > 0 { - seq := c.order[0] - meta := c.unacked[seq] - end := seq + uint32(meta.length) - // if segment end <= ack, it is fully acked - if seqLT(end, ack) || end == ack { - delete(c.unacked, seq) - c.order = c.order[1:] - c.bytesInFlight -= meta.length - changed = true - } else { - break - } - } - empty := len(c.unacked) == 0 - c.sendMu.Unlock() - - if changed { - // let writers proceed if sendQ was full - c.sendCond.Signal() - } - if empty { - c.stopRTO() - } -} diff --git a/mod/udp/src/server.go b/mod/udp/src/server.go index f61033e97..090eb61d0 100644 --- a/mod/udp/src/server.go +++ b/mod/udp/src/server.go @@ -5,14 +5,7 @@ import ( "sync" "github.com/cryptopunkscc/astrald/astral" -) - -// Connection states -const ( - StateClosed = iota // Connection is closed - StateSynSent // SYN sent, waiting for SYN-ACK - StateSynReceived // SYN received, waiting for ACK - StateEstablished // Connection established + "github.com/cryptopunkscc/astrald/mod/udp" ) // Server implements src UDP listener with connection demultiplexing @@ -48,8 +41,15 @@ func (s *Server) Run(ctx *astral.Context) error { s.log.Info("started server at %v", listener.LocalAddr()) defer s.log.Info("stopped server at %v", listener.LocalAddr()) + localEndpoint, err := udp.ParseEndpoint(listener.LocalAddr(). + String()) + if err != nil { + s.log.Errorv(1, "error parsing local endpoint: %v", err) + return err + } + s.wg.Add(1) - go s.readLoop() + go s.readLoop(ctx, localEndpoint) <-ctx.Done() s.Close() @@ -70,11 +70,10 @@ func (s *Server) Close() error { return nil } -// readLoop handles incoming datagrams and routes them to connections -func (s *Server) readLoop() { +func (s *Server) readLoop(ctx *astral.Context, localEndpoint *udp.Endpoint) { defer s.wg.Done() - buf := make([]byte, 64*1024) // Large buffer for high throughput + buf := make([]byte, 64*1024) // TODO: Max packet size? for { n, addr, err := s.listener.ReadFromUDP(buf) if err != nil { @@ -87,10 +86,47 @@ func (s *Server) readLoop() { } } - s.handlePacket(buf[:n], addr) - } -} + pkt := &Packet{} + if err := pkt.Unmarshal(buf[:n]); err != nil { + s.log.Errorv(1, "packet unmarshal error from %v: %v", addr, err) + continue // drop malformed + } + + remoteKey := addr.String() + s.mutex.Lock() + conn, foundConn := s.conns[remoteKey] + if !foundConn && pkt.Flags&FlagSYN != 0 { + remoteEndpoint, err := udp.ParseEndpoint(addr.String()) + if err != nil { + s.log.Errorv(1, "ParseEndpoint error for %v: %v", addr, err) + continue + } + + conn, err = NewConn(s.listener, localEndpoint, remoteEndpoint, s.Module.config.FlowControl) + if err != nil { + s.log.Errorv(1, "NewConn error for %v: %v", addr, err) + s.mutex.Unlock() + continue + } + + conn.inCh = make(chan *Packet, 128) + s.conns[remoteKey] = conn + go func() { + err := conn.startServerHandshake(ctx, pkt) + if err != nil { + s.log.Errorv(1, "handshake error for %v: %v", addr, err) + } + }() + } + s.mutex.Unlock() -// handlePacket processes an incoming packet and routes it to the appropriate connection -func (s *Server) handlePacket(data []byte, addr *net.UDPAddr) { + if conn != nil { + select { + case conn.inCh <- pkt: + // success + default: + s.log.Errorv(1, "inCh full for %v, dropping packet", addr) + } + } + } } diff --git a/mod/udp/src/timers.go b/mod/udp/src/timers.go deleted file mode 100644 index 14a263bbc..000000000 --- a/mod/udp/src/timers.go +++ /dev/null @@ -1,45 +0,0 @@ -package udp - -import ( - "time" -) - -// armRTO starts or resets the retransmission timer -func (c *Conn) armRTO(d time.Duration) { - c.rtoMu.Lock() - defer c.rtoMu.Unlock() - - if c.rtoTimer != nil { - c.rtoTimer.Stop() - } - c.rtoTimer = time.AfterFunc(d, c.handleRTO) -} - -// stopRTO safely stops the retransmission timer -func (c *Conn) stopRTO() { - c.rtoMu.Lock() - defer c.rtoMu.Unlock() - - if c.rtoTimer != nil { - c.rtoTimer.Stop() - c.rtoTimer = nil - } -} - -// armAckDelay schedules a pure ACK to be sent soon -func (c *Conn) armAckDelay() { - c.rtoMu.Lock() - defer c.rtoMu.Unlock() - - if c.ackTimer == nil { - c.ackTimer = time.AfterFunc(c.cfg.AckDelay, c.sendPureACK) - } else { - c.ackTimer.Reset(c.cfg.AckDelay) - } -} - -// armAckDelayTimerLocked initializes the ACK delay timer (called during initialization) -// Note: Unlike armAckDelay, this is called from NewConn when mutex is already held -func (c *Conn) armAckDelayTimerLocked() { - c.ackTimer = time.AfterFunc(c.cfg.AckDelay, c.sendPureACK) -} From 2d20ca3cfa8a782648986f928d7430ad41a33cc3 Mon Sep 17 00:00:00 2001 From: Rekseto Date: Tue, 30 Sep 2025 04:25:22 +0200 Subject: [PATCH 04/13] refactor: enhance reliable UDP configuration and connection management --- go.mod | 1 + go.sum | 2 + mod/nodes/src/op_streams.go | 3 +- mod/udp/README.md | 28 +++ mod/udp/conn.go | 4 + mod/udp/endpoint.go | 7 + mod/udp/errors.go | 5 +- mod/udp/src/config.go | 97 +++++---- mod/udp/src/config_test.go | 14 +- mod/udp/src/conn.go | 303 +++++++++++++++++------------ mod/udp/src/conn_handshake.go | 126 ++++++++---- mod/udp/src/conn_handshake_test.go | 132 ------------- mod/udp/src/fragmenter.go | 88 +++++++++ mod/udp/src/fragmenter_test.go | 196 +++++++++++++++++++ mod/udp/src/plan.md | 209 -------------------- mod/udp/src/receive.go | 148 ++++++++++++++ mod/udp/src/retransmissions.go | 131 +++++++++++++ mod/udp/src/ring_buffer.go | 159 --------------- mod/udp/src/ring_buffer_test.go | 170 ---------------- mod/udp/src/send.go | 189 ++++++++++++++++++ mod/udp/src/server.go | 2 +- 21 files changed, 1126 insertions(+), 888 deletions(-) create mode 100644 mod/udp/conn.go delete mode 100644 mod/udp/src/conn_handshake_test.go create mode 100644 mod/udp/src/fragmenter.go create mode 100644 mod/udp/src/fragmenter_test.go delete mode 100644 mod/udp/src/plan.md create mode 100644 mod/udp/src/receive.go create mode 100644 mod/udp/src/retransmissions.go delete mode 100644 mod/udp/src/ring_buffer.go delete mode 100644 mod/udp/src/ring_buffer_test.go create mode 100644 mod/udp/src/send.go diff --git a/go.mod b/go.mod index ca8c98bc3..1e202657c 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904 // indirect golang.org/x/sys v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.24.1 // indirect diff --git a/go.sum b/go.sum index 870a14f4e..313787fdc 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904 h1:OoG1xZV7CXnP2/Udl1ybEgTEds9XXA3NHWg+OR3c/a8= +github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wailsapp/mimetype v1.4.1 h1:pQN9ycO7uo4vsUUuPeHEYoUkLVkaRntMnHJxVwYhwHs= diff --git a/mod/nodes/src/op_streams.go b/mod/nodes/src/op_streams.go index 88172e689..2fec34e21 100644 --- a/mod/nodes/src/op_streams.go +++ b/mod/nodes/src/op_streams.go @@ -1,10 +1,11 @@ package nodes import ( + "slices" + "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/mod/nodes" "github.com/cryptopunkscc/astrald/mod/shell" - "slices" ) type opStreamsArgs struct { diff --git a/mod/udp/README.md b/mod/udp/README.md index e69de29bb..2bfda88ab 100644 --- a/mod/udp/README.md +++ b/mod/udp/README.md @@ -0,0 +1,28 @@ +# UDP Module + +This module implements UDP-based communication for the Astrald platform, enabling fast, connectionless data transfer between nodes. + +## Current Proof of Concept (PoC) State +- Basic UDP packet send/receive functionality +- Simple endpoint resolution +- Fragmentation and reassembly of large packets +- Minimal handshake and connection logic +- Configuration via `config.go` + +## Key Components +- **Dial:** Initiates outbound UDP connections. Integrates with Astral's context (`astral.Context`), endpoint abstraction (`exonet.Endpoint`), and parses endpoints using UDP utilities. Returns a reliable connection for Astral's exonet system. Similar to TCP's Dial, but uses UDP-specific logic and interacts with Astral's node and identity systems. +- **ResolveEndpoints:** Resolves available UDP endpoints for a node. Uses Astral's identity system (`astral.Identity`) to verify node identity and returns endpoints as `exonet.Endpoint` via Astral's signal utilities (`sig.ArrayToChan`). Connects with Astral's node and endpoint management. +- **Loader:** Loads the UDP module into Astral. Connects with Astral's node (`astral.Node`), asset management (`core/assets`), and logging (`astral/log.Logger`). Loads configuration from assets, parses public endpoints, and registers the module with Astral's core module system for lifecycle management. +- **Unpack:** Handles packet reassembly for Astral's exonet system. Uses UDP endpoint parsing and error handling, returning endpoints for Astral's network abstraction. Connects with Astral's error and endpoint utilities. +- **Server:** Listens for incoming UDP connections. Uses Astral's context and logging, manages connections and server lifecycle, and integrates with Astral's configuration and endpoint parsing. Registers endpoints and connections with Astral's node and router systems. +- +## Possible Improvements for Production Readiness +- Advanced flow control mechanisms (dynamic window sizing, congestion control) +- Fast retransmissions and selective acknowledgments (SACK) +- Adaptive retransmission timers (RTT estimation) +- Performance optimizations (buffering, batching) +- Support for NAT traversal and hole punching +- Comprehensive test coverage +- Documentation and usage examples + +This module is currently a proof of concept and not recommended for production use without further development. diff --git a/mod/udp/conn.go b/mod/udp/conn.go new file mode 100644 index 000000000..df1150adc --- /dev/null +++ b/mod/udp/conn.go @@ -0,0 +1,4 @@ +package udp + +type ReliableUdpConn interface { +} diff --git a/mod/udp/endpoint.go b/mod/udp/endpoint.go index 94d2c54c5..d8629cb25 100644 --- a/mod/udp/endpoint.go +++ b/mod/udp/endpoint.go @@ -135,6 +135,13 @@ func ParseEndpoint(s string) (*Endpoint, error) { }, nil } +func (e *Endpoint) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{ + IP: net.ParseIP(e.IP.String()), + Port: int(e.Port), + } +} + func init() { _ = astral.DefaultBlueprints.Add(&Endpoint{}) diff --git a/mod/udp/errors.go b/mod/udp/errors.go index efe883990..4b8924e12 100644 --- a/mod/udp/errors.go +++ b/mod/udp/errors.go @@ -3,14 +3,13 @@ package udp import "errors" var ( + ErrRetransmissionLimitExceeded = errors.New( + "retransmissions limit exceeded") ErrPacketTooShort = errors.New("packet too short") - ErrListenerClosed = errors.New("listener closed") ErrConnClosed = errors.New("connection closed") ErrInvalidPayloadLength = errors.New("invalid payload length") - ErrClosed = errors.New("connection closed") ErrZeroMSS = errors.New("invalid MSS") ErrMalformedPacket = errors.New("malformed packet") ErrHandshakeTimeout = errors.New("handshake timeout") - ErrHandshakeReset = errors.New("handshake reset") ErrConnectionNotEstablished = errors.New("connection not established") ) diff --git a/mod/udp/src/config.go b/mod/udp/src/config.go index f7e2f9ee9..30df398f0 100644 --- a/mod/udp/src/config.go +++ b/mod/udp/src/config.go @@ -47,51 +47,61 @@ const ( MaxSendBufBytes = 8 << 20 ) +const ( + DefaultWndPkts = 32 + MinWndPkts = 1 + MaxWndPkts = 256 +) + // Config holds general settings for the UDP module. type Config struct { ListenPort int `yaml:"listen_port,omitempty"` // Port to listen on for incoming connections (default 1791) PublicEndpoints []string `yaml:"public_endpoints,omitempty"` DialTimeout time.Duration `yaml:"dial_timeout,omitempty"` // Timeout for dialing connections (default 1 minute) - FlowControl FlowControlConfig `yaml:"flow_control,omitempty"` // Flow control settings for UDP connections + FlowControl ReliableTransportConfig `yaml:"flow_control,omitempty"` // Flow control settings for UDP connections } -// FlowControlConfig holds configuration for individual UDP connections. -type FlowControlConfig struct { - MSS int // Maximum Segment Size (default 1187) - WindowBytes int // Send window size in bytes (default 16 * MSS) - RTO time.Duration // Initial retransmission timeout (default 500ms) - RTOMax time.Duration // Maximum retransmission timeout (default 4s) - RetryLimit int // Maximum retransmission attempts (default 8) - IdleTimeout time.Duration // Connection idle timeout (default 60s) - AckDelay time.Duration // Delayed ACK timer (default 25ms) - RecvBufBytes int // Receive buffer size (default 1MB) - SendBufBytes int // Send buffer size (default 1MB) +// ReliableTransportConfig holds configuration for individual UDP connections. +type ReliableTransportConfig struct { + MaxSegmentSize int // Maximum Segment Size (default 1187) + MaxWindowBytes int // Send window size in bytes (default 16 * MaxSegmentSize) + MaxWindowPackets int // Max in-flight packets (packet-count window, default 32) + RetransmissionInterval time.Duration // Initial retransmission timeout (default 500ms) + MaxRetransmissionInterval time.Duration // Maximum retransmission timeout (default 4s) + RetransmissionLimit int // Maximum retransmission attempts (default 8) + IdleTimeout time.Duration // Connection idle timeout (default 60s) + AckDelay time.Duration // Delayed ACK timer (default 25ms) + RecvBufBytes int // Receive buffer size (default 1MB) + SendBufBytes int // Send buffer size (default 1MB) } // Normalize sets sensible defaults for zero-values, clamps to safe ranges, and enforces invariants. // See RFC 9000, RFC 8085, RFC 6298 for rationale. -func (c *FlowControlConfig) Normalize() { - c.setDefaults() +func (c *ReliableTransportConfig) Normalize() { + c.SetDefaults() c.clampValues() } -// setDefaults initializes zero-values with sensible defaults. -func (c *FlowControlConfig) setDefaults() { - if c.MSS == 0 { - c.MSS = DefaultMSS +// SetDefaults initializes zero-values with sensible defaults. +func (c *ReliableTransportConfig) SetDefaults() { + if c.MaxSegmentSize == 0 { + c.MaxSegmentSize = DefaultMSS + } + if c.MaxWindowBytes == 0 { + c.MaxWindowBytes = DefaultWindowBytes } - if c.WindowBytes == 0 { - c.WindowBytes = DefaultWindowBytes + if c.MaxWindowPackets == 0 { + c.MaxWindowPackets = DefaultWndPkts } - if c.RTO == 0 { - c.RTO = DefaultRTO + if c.RetransmissionInterval == 0 { + c.RetransmissionInterval = DefaultRTO } - if c.RTOMax == 0 { - c.RTOMax = DefaultRTOMax + if c.MaxRetransmissionInterval == 0 { + c.MaxRetransmissionInterval = DefaultRTOMax } - if c.RetryLimit == 0 { - c.RetryLimit = DefaultRetries + if c.RetransmissionLimit == 0 { + c.RetransmissionLimit = DefaultRetries } if c.AckDelay == 0 { c.AckDelay = DefaultAckDelay @@ -110,13 +120,14 @@ func (c *FlowControlConfig) setDefaults() { // all of which are stated at the top of this file) // clampValues ensures all fields are within safe ranges and enforces invariants. -func (c *FlowControlConfig) clampValues() { - c.MSS = clampInt(c.MSS, MinMSS, MaxMSS) - c.WindowBytes = clampInt(c.WindowBytes, c.MSS, MaxWindowBytes) - c.RTO = clampDur(c.RTO, MinRTO, MaxRTOCeiling) - c.RTOMax = clampDur(c.RTOMax, c.RTO, MaxRTOCeiling) - c.RetryLimit = clampInt(c.RetryLimit, MinRetries, MaxRetries) - c.AckDelay = clampDur(c.AckDelay, MinAckDelay, c.RTO/2) +func (c *ReliableTransportConfig) clampValues() { + c.MaxSegmentSize = clampInt(c.MaxSegmentSize, MinMSS, MaxMSS) + c.MaxWindowBytes = clampInt(c.MaxWindowBytes, c.MaxSegmentSize, MaxWindowBytes) + c.MaxWindowPackets = clampInt(c.MaxWindowPackets, MinWndPkts, MaxWndPkts) + c.RetransmissionInterval = clampDur(c.RetransmissionInterval, MinRTO, MaxRTOCeiling) + c.MaxRetransmissionInterval = clampDur(c.MaxRetransmissionInterval, c.RetransmissionInterval, MaxRTOCeiling) + c.RetransmissionLimit = clampInt(c.RetransmissionLimit, MinRetries, MaxRetries) + c.AckDelay = clampDur(c.AckDelay, MinAckDelay, c.RetransmissionInterval/2) c.RecvBufBytes = clampInt(c.RecvBufBytes, MinRecvBufBytes, MaxRecvBufBytes) c.SendBufBytes = clampInt(c.SendBufBytes, MinSendBufBytes, MaxSendBufBytes) } @@ -146,16 +157,16 @@ func clampDur(v, lo, hi time.Duration) time.Duration { var defaultConfig = Config{ ListenPort: ListenPort, DialTimeout: time.Minute, - FlowControl: FlowControlConfig{ - MSS: DefaultMSS, - WindowBytes: DefaultWindowBytes, - RTO: DefaultRTO, - RTOMax: DefaultRTOMax, - RetryLimit: DefaultRetries, - IdleTimeout: 60 * time.Second, // Default idle timeout of 1 minute - AckDelay: DefaultAckDelay, - RecvBufBytes: DefaultRecvBufBytes, - SendBufBytes: DefaultSendBufBytes, + FlowControl: ReliableTransportConfig{ + MaxSegmentSize: DefaultMSS, + MaxWindowBytes: DefaultWindowBytes, + RetransmissionInterval: DefaultRTO, + MaxRetransmissionInterval: DefaultRTOMax, + RetransmissionLimit: DefaultRetries, + IdleTimeout: 60 * time.Second, // Default idle timeout of 1 minute + AckDelay: DefaultAckDelay, + RecvBufBytes: DefaultRecvBufBytes, + SendBufBytes: DefaultSendBufBytes, }, } diff --git a/mod/udp/src/config_test.go b/mod/udp/src/config_test.go index b9fb7ef15..173f34e21 100644 --- a/mod/udp/src/config_test.go +++ b/mod/udp/src/config_test.go @@ -15,12 +15,12 @@ func TestFlowControlConfigDefaults(t *testing.T) { func TestFlowControlConfigClamp(t *testing.T) { tests := []struct { name string - input FlowControlConfig - expected FlowControlConfig + input ReliableTransportConfig + expected ReliableTransportConfig }{ { name: "Values below range are clamped", - input: FlowControlConfig{ + input: ReliableTransportConfig{ MSS: 100, WindowBytes: 100, RTO: 5 * time.Millisecond, @@ -30,7 +30,7 @@ func TestFlowControlConfigClamp(t *testing.T) { RecvBufBytes: 100, SendBufBytes: 100, }, - expected: FlowControlConfig{ + expected: ReliableTransportConfig{ MSS: MinMSS, WindowBytes: MinMSS, RTO: MinRTO, @@ -43,7 +43,7 @@ func TestFlowControlConfigClamp(t *testing.T) { }, { name: "Values above range are clamped", - input: FlowControlConfig{ + input: ReliableTransportConfig{ MSS: 2000, WindowBytes: 2 << 20, RTO: 70 * time.Second, @@ -53,13 +53,13 @@ func TestFlowControlConfigClamp(t *testing.T) { RecvBufBytes: 16 << 20, SendBufBytes: 16 << 20, }, - expected: FlowControlConfig{ + expected: ReliableTransportConfig{ MSS: MaxMSS, WindowBytes: MaxWindowBytes, RTO: MaxRTOCeiling, RTOMax: MaxRTOCeiling, RetryLimit: MaxRetries, - AckDelay: MinAckDelay, // AckDelay is clamped to MinAckDelay if above range + AckDelay: 1 * time.Second, // Correct expected value RecvBufBytes: MaxRecvBufBytes, SendBufBytes: MaxSendBufBytes, }, diff --git a/mod/udp/src/conn.go b/mod/udp/src/conn.go index 2eb91d516..1fdb68b0f 100644 --- a/mod/udp/src/conn.go +++ b/mod/udp/src/conn.go @@ -2,11 +2,14 @@ package udp import ( + "io" "net" + "sync" + "sync/atomic" "time" - "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/udp" + "github.com/smallnest/ringbuffer" ) // DatagramWriter is how Conn sends bytes to its peer. @@ -24,184 +27,226 @@ type Handshaker interface { Handshake() error } -type Fragmenter interface { -} - -// Conn represents a reliable UDP connection -// Handshake/FIN are out of scope for this MVP; stream semantics only. +type Unacked struct { + pkt *Packet // Packet metadata (seq, len, ring offsets) + sentTime time.Time // Last sent time + rtxCount int // Retransmit count + offset int // Offset in sendRB (handshake: -1) + length int // Length in sendRB / payload length (handshake: 0) + isHandshake bool // True if this entry is for a handshake control packet +} + +// Conn represents a reliable UDP connection. +// Implements reliability, flow control, retransmissions, and error notification. +// Key mechanisms: +// - Reliable delivery using retransmission timer and fast retransmit +// - Flow control using packet window (MaxWindowPackets) +// - Concurrency safety via sendMu and sendCond +// - Error notification via ErrChan (application can monitor for connection-level errors) +// - Centralized resource cleanup via Close() +// - PoC limitations: no congestion control, no SACK, no adaptive pacing type Conn struct { - // socket / addressing - udpConn *net.UDPConn + // UDP socket and addressing + udpConn *net.UDPConn // Underlying UDP socket localEndpoint *udp.Endpoint remoteEndpoint *udp.Endpoint - // config - cfg FlowControlConfig - state ConnState - inCh chan *Packet - closedFlag bool + // Configuration (reliability, flow control, etc.) + cfg ReliableTransportConfig // All protocol parameters - // + // Connection state (atomic for lock-free reads) + state uint32 // Current connection state (stores ConnState) + inCh chan *Packet // Incoming packet channel + closedFlag uint32 // 0=open, 1=closed (atomic) + + // Sequence numbers and send state initialSeqNumLocal uint32 initialSeqNumRemote uint32 - // send state - nextSeqNum uint32 - connID uint32 - sendBase uint32 // oldest unacked sequence (i.e., cumulative ACK floor). - ackedSeqNum uint32 // highest cumulative ACK seen (often == sendBase). - expected uint32 - // - unacked map[uint32]*Packet // seq -> packet - // receive state + nextSeqNum uint32 + connID uint32 // Connection ID + sendBase uint32 // Oldest unacked sequence (ACK floor) + ackedSeqNum uint32 // Highest cumulative ACK seen + expected uint32 // Next expected sequence number (receive side) + + inflight uint32 // Number of unacked packets + + // Send buffer and reliability + sendRB *ringbuffer.RingBuffer // Persistent send ring buffer + frag *BasicFragmenter // Fragmenter for packetization + unacked map[uint32]*Unacked // Map of unacked packets (seq -> Unacked) + + // Concurrency and coordination + sendMu sync.Mutex // Protects all shared state + sendCond *sync.Cond // Condition variable for sender coordination + + // Retransmission timer + rtxTimer *time.Timer // Fixed retransmission timer (PoC only) + + // Error notification + ErrChan chan error // Channel for connection-level errors (e.g., retransmission failure) + + // Inbound buffering & ACK state + recvRB *ringbuffer.RingBuffer + recvMu sync.Mutex + recvCond *sync.Cond + ackTimer *time.Timer + ackPending bool + lastAckSent uint32 + // Out-of-order buffer (keyed by sequence number of first byte) + recvOO map[uint32]*Packet // stored packets with Seq > expected awaiting in-order delivery } func (c *Conn) setState(state ConnState) { - c.state = state + atomic.StoreUint32(&c.state, uint32(state)) } func (c *Conn) inState(state ConnState) bool { - return c.state == state + return atomic.LoadUint32(&c.state) == uint32(state) } +func (c *Conn) isClosed() bool { return atomic.LoadUint32(&c.closedFlag) != 0 } + func (c *Conn) Read(p []byte) (n int, err error) { - if !c.inState(StateEstablished) { + if !c.inState(StateEstablished) && !c.isClosed() { return 0, udp.ErrConnectionNotEstablished } - //TODO implement me - panic("implement me") + c.recvMu.Lock() + for c.recvRB != nil && c.recvRB.Length() == 0 && !c.isClosed() { + c.recvCond.Wait() + } + if c.recvRB == nil || (c.recvRB.Length() == 0 && c.isClosed()) { + c.recvMu.Unlock() + return 0, io.EOF + } + want := len(p) + if rl := int(c.recvRB.Length()); want > rl { + want = rl + } + if want == 0 { + c.recvMu.Unlock() + return 0, nil + } + // Read directly into caller's buffer (no temp allocation) + m, _ := c.recvRB.Read(p[:want]) + c.recvMu.Unlock() + return m, nil } +// Write enqueues data into the send ring buffer. Implementation in send.go. func (c *Conn) Write(p []byte) (n int, err error) { - if !c.inState(StateEstablished) { - return 0, udp.ErrConnectionNotEstablished - } - - //TODO implement me - panic("implement me") + return c.writeSend(p) } func (c *Conn) Close() error { - c.closedFlag = true - c.udpConn.SetReadDeadline(time.Now()) - //TODO implement me - panic("implement me") + c.sendMu.Lock() + if c.isClosed() { // already closed + c.sendMu.Unlock() + return nil + } + atomic.StoreUint32(&c.closedFlag, 1) + // stop retransmission timer if running + if c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + // wake any waiters (writers, senderLoop) + c.sendCond.Broadcast() + ch := c.inCh + c.inCh = nil // detach channel to prevent further sends + c.sendMu.Unlock() + + if ch != nil { + close(ch) + } + _ = c.udpConn.SetReadDeadline(time.Now()) + c.recvMu.Lock() + if c.ackTimer != nil { + c.ackTimer.Stop() + c.ackTimer = nil + } + if c.recvCond != nil { + c.recvCond.Broadcast() + } + c.recvMu.Unlock() + return c.udpConn.Close() } // NewConn constructs a connection around an already-connected UDP socket. -func NewConn(c *net.UDPConn, l, r *udp.Endpoint, cfg FlowControlConfig) (*Conn, error) { +func NewConn(cn *net.UDPConn, l, r *udp.Endpoint, cfg ReliableTransportConfig) (*Conn, error) { cfg.Normalize() - if cfg.MSS <= 0 { + if cfg.MaxSegmentSize <= 0 { return nil, udp.ErrZeroMSS } + sendRBSize := cfg.MaxWindowBytes * 2 // allow for some retransmit slack + rb := ringbuffer.New(sendRBSize) + frag := NewBasicFragmenter(cfg.MaxSegmentSize) + rc := &Conn{ - udpConn: c, + udpConn: cn, localEndpoint: l, remoteEndpoint: r, cfg: cfg, + sendRB: rb, + frag: frag, + unacked: make(map[uint32]*Unacked), + ErrChan: make(chan error, 1), // Buffered to avoid blocking + inCh: make(chan *Packet, 32), // handshake delivery channel } + rc.sendCond = sync.NewCond(&rc.sendMu) + rc.recvRB = ringbuffer.New(cfg.RecvBufBytes) + rc.recvCond = sync.NewCond(&rc.recvMu) + rc.recvOO = make(map[uint32]*Packet) - return rc, nil -} - -// Outbound reports whether this connection was dialed out. -// For now this always returns true for Dial usage; adjust if you add a listener. -func (c *Conn) Outbound() bool { return true } + // start fused receive loop + go rc.recvLoop() -// LocalEndpoint returns the local UDP endpoint. -func (c *Conn) LocalEndpoint() exonet.Endpoint { - return c.localEndpoint -} - -// RemoteEndpoint returns the remote UDP endpoint. -func (c *Conn) RemoteEndpoint() exonet.Endpoint { - return c.remoteEndpoint + return rc, nil } -func (c *Conn) receivingLoop() { - const maxPayloadSize = 64 * 1024 - buf := make([]byte, maxPayloadSize) - for { - c.udpConn.SetReadDeadline(time.Now().Add(1 * time.Second)) - n, addr, err := c.udpConn.ReadFromUDP(buf) - if err != nil { - if c.closedFlag { - return - } - continue - } - - // NOTE: test it - if addr.String() != c.remoteEndpoint.IP.String() { - continue // not for this Conn - } - - pktData := make([]byte, n) - copy(pktData, buf[:n]) - pkt := &Packet{} - if err := pkt.Unmarshal(pktData); err != nil { - continue // drop malformed - } - if int(pkt.Len) > maxPayloadSize { - continue // invalid length - } - isControl := pkt.Flags&(FlagSYN|FlagACK|FlagFIN) != 0 && pkt.Len == 0 - if isControl { - // Block until enqueued - c.inCh <- pkt - } else { - // Drop data if channel full - select { - case c.inCh <- pkt: - default: - // drop data - } - } +// HandleAckPacket processes ACK packets +func (c *Conn) HandleAckPacket(packet *Packet) { + ack := packet.Ack + c.sendMu.Lock() + defer c.sendMu.Unlock() + if ack > c.ackedSeqNum { + c.ackedSeqNum = ack } -} - -// Go -func (c *Conn) InboundPacketHandler() { - for pkt := range c.inCh { - if pkt.Flags&FlagACK != 0 { - c.handleAckPacket(pkt) + if ack > c.sendBase { + c.sendBase = ack + } + // Remove fully acked packets (keyed by seq) + for s, u := range c.unacked { + if u.isHandshake { + // Handshake control (SYN / SYN|ACK) conceptually consumes 1 sequence number. + // Require ack > s (i.e., ack == s+1) to delete, avoiding premature removal + // if an unexpected ack echo with ack==s arrives. + if ack > s { // expected ack == s+1 + delete(c.unacked, s) + } continue } - - if pkt.Flags&(FlagSYN|FlagFIN) != 0 { - c.handleControlPacket(pkt) - continue + // Data packet: remove when cumulative ack covers entire payload. + if s+uint32(u.length) <= ack { + delete(c.unacked, s) } - - c.handleDataPacket(pkt) } -} - -// handleAckPacket processes ACK packets -func (c *Conn) handleAckPacket(pkt *Packet) { - ack := pkt.Ack - for seq := range c.unacked { - if seq <= ack { - delete(c.unacked, seq) - } + // Stop retransmission timer if no unacked packets remain + if len(c.unacked) == 0 && c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil } - c.sendBase = ack + 1 + c.sendCond.Broadcast() } -// handleControlPacket processes SYN, FIN, and other control packets -func (c *Conn) handleControlPacket(pkt *Packet) { +// HandleControlPacket processes SYN, FIN, and other control packets +func (c *Conn) HandleControlPacket(packet *Packet) { // Example: handle SYN, FIN, or other control logic - if pkt.Flags&FlagSYN != 0 { - // ...handle SYN logic... + if packet.Flags&FlagSYN != 0 { + // TODO: ...handle SYN logic... } - if pkt.Flags&FlagFIN != 0 { - // ...handle FIN logic... + if packet.Flags&FlagFIN != 0 { + // TODO: ...handle FIN logic... } // ...handle other control flags as needed... } - -// handleDataPacket processes data packets -func (c *Conn) handleDataPacket(pkt *Packet) { - // Example: deliver to receive buffer, update expected, send ACK, etc. - // ...implement data delivery and ACK logic... -} diff --git a/mod/udp/src/conn_handshake.go b/mod/udp/src/conn_handshake.go index 55a29e0f2..5b08882a7 100644 --- a/mod/udp/src/conn_handshake.go +++ b/mod/udp/src/conn_handshake.go @@ -2,7 +2,9 @@ package udp import ( "context" + "crypto/rand" "fmt" + "time" "github.com/cryptopunkscc/astrald/mod/udp" ) @@ -20,13 +22,17 @@ const ( StateTimeWait // (optional) short wait after FIN to absorb retransmits ) -func (c *Conn) startClientHandshake(ctx context.Context) error { - c.initialSeqNumLocal = randUint32NZ() +func (c *Conn) StartClientHandshake(ctx context.Context) error { + seq, err := randUint32NZ() + if err != nil { + return fmt.Errorf("failed to generate initial sequence number: %w", err) + } + c.initialSeqNumLocal = seq c.connID = c.initialSeqNumLocal c.setState(StateSynSent) - err := c.sendControl(FlagSYN, c.initialSeqNumLocal, 0) - if err != nil { + // build + send initial SYN and register in unacked for unified retransmission + if err := c.sendHandshakeControl(FlagSYN, c.initialSeqNumLocal, 0); err != nil { return err } @@ -36,78 +42,130 @@ func (c *Conn) startClientHandshake(ctx context.Context) error { return udp.ErrHandshakeTimeout case pkt := <-c.inCh: if pkt.Flags&(FlagSYN|FlagACK) == (FlagSYN|FlagACK) && pkt.Ack == c.initialSeqNumLocal+1 && pkt.Seq != 0 { + // got valid SYN|ACK c.initialSeqNumRemote = pkt.Seq + // remove our SYN from unacked and set sequence bases + c.sendMu.Lock() + delete(c.unacked, c.initialSeqNumLocal) + if len(c.unacked) == 0 && c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + c.ackedSeqNum = c.initialSeqNumLocal + 1 + c.sendBase = c.initialSeqNumLocal + 1 + c.nextSeqNum = c.initialSeqNumLocal + 1 + c.sendMu.Unlock() - err := c.sendControl(FlagACK, c.initialSeqNumLocal+1, c.initialSeqNumRemote+1) - if err != nil { + // send final ACK (not tracked for retransmission) + if err := c.SendControlPacket(FlagACK, c.initialSeqNumLocal+1, c.initialSeqNumRemote+1); err != nil { return err } c.setState(StateEstablished) - go c.InboundPacketHandler() + // fused receive loop will now dispatch directly return nil } } } } -func (c *Conn) startServerHandshake(ctx context.Context, synPkt *Packet) error { +func (c *Conn) StartServerHandshake(ctx context.Context, synPkt *Packet) error { c.initialSeqNumRemote = synPkt.Seq c.connID = synPkt.Seq - c.initialSeqNumLocal = randUint32NZ() + seq, err := randUint32NZ() + if err != nil { + return fmt.Errorf("failed to generate initial sequence number: %w", err) + } + c.initialSeqNumLocal = seq c.setState(StateSynReceived) - if err := c.sendControl(FlagSYN|FlagACK, c.initialSeqNumLocal, c.initialSeqNumRemote+1); err != nil { + // send SYN|ACK and register for retransmission + if err := c.sendHandshakeControl(FlagSYN|FlagACK, c.initialSeqNumLocal, c.initialSeqNumRemote+1); err != nil { return err } + for { select { case <-ctx.Done(): return udp.ErrHandshakeTimeout case pkt := <-c.inCh: if pkt.Flags&FlagACK != 0 && pkt.Ack == c.initialSeqNumLocal+1 { + // final ACK received + c.sendMu.Lock() + delete(c.unacked, c.initialSeqNumLocal) + if len(c.unacked) == 0 && c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + c.ackedSeqNum = c.initialSeqNumLocal + 1 + c.sendBase = c.initialSeqNumLocal + 1 + c.nextSeqNum = c.initialSeqNumLocal + 1 + c.sendMu.Unlock() c.setState(StateEstablished) - go c.InboundPacketHandler() + // fused receive loop will now dispatch directly return nil } } } } -func (c *Conn) sendControl(flags uint8, seq, ack uint32) error { - pkt := &Packet{ - Seq: seq, - Ack: ack, - Flags: flags, - Len: 0, +// sendHandshakeControl builds, sends and registers a handshake control packet for unified retransmissions +func (c *Conn) sendHandshakeControl(flags uint8, seq, ack uint32) error { + pkt := &Packet{Seq: seq, Ack: ack, Flags: flags, Len: 0} + b, err := pkt.Marshal() + if err != nil { + return fmt.Errorf("marshal handshake pkt: %w", err) + } + if c.udpConn == nil { + return udp.ErrConnClosed + } + if _, err := c.udpConn.Write(b); err != nil { + return err } + c.sendMu.Lock() + if _, exists := c.unacked[seq]; !exists { // only register first time + c.unacked[seq] = &Unacked{ + pkt: pkt, + sentTime: time.Now(), + rtxCount: 0, + offset: -1, + length: 0, + isHandshake: true, + } + if c.rtxTimer == nil { + c.startRtxTimer() + } + } + c.sendMu.Unlock() + return nil +} +// SendControlPacket retained for non-handshake control (e.g., FIN, pure ACK), not tracked +func (c *Conn) SendControlPacket(flags uint8, seq, ack uint32) error { + pkt := &Packet{Seq: seq, Ack: ack, Flags: flags, Len: 0} data, err := pkt.Marshal() if err != nil { - return fmt.Errorf(`sendControl failed to marshal control packet: %w`, err) + return fmt.Errorf(`SendControlPacket failed to marshal control packet: %w`, err) } - if c.udpConn == nil { return udp.ErrConnClosed } - _, err = c.udpConn.Write(data) if err != nil { - return fmt.Errorf(`sendControl failed to send control packet: %w`, err) + return fmt.Errorf(`SendControlPacket failed to send control packet: %w`, err) } - return err } -func (c *Conn) handleInbound(pkt *Packet) { - // ...process handshake/control/data packets... -} - -// TODO: implement proper random non-zero uint32 generator -func randUint32NZ() uint32 { - // ...generate non-zero random uint32... - return 1 // stub +func randUint32NZ() (uint32, error) { + var b [4]byte + for { + _, err := rand.Read(b[:]) + if err != nil { + return 0, fmt.Errorf("failed to generate random uint32: %w", err) + } + v := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) + if v != 0 { + return v, nil + } + } } - -// notifyInbound is a channel for inbound packets -// ...in Conn struct... -// notifyInbound chan *Packet diff --git a/mod/udp/src/conn_handshake_test.go b/mod/udp/src/conn_handshake_test.go deleted file mode 100644 index dcd07acdd..000000000 --- a/mod/udp/src/conn_handshake_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package udp - -import ( - "context" - "testing" - "time" -) - -func TestClientHandshake_Success(t *testing.T) { - conn := &Conn{ - notifyInbound: make(chan *Packet, 1), - } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - // Simulate server SYN|ACK response - go func() { - time.Sleep(10 * time.Millisecond) - conn.notifyInbound <- &Packet{ - Seq: 12345, // server ISN - Ack: 2, // client ISN + 1 - Flags: FlagSYN | FlagACK, - Len: 0, - } - }() - - conn.initialSeqNumLocal = 1 // deterministic for test - conn.connID = 1 - - err := conn.startClientHandshake(ctx) - if err != nil { - t.Fatalf("expected handshake success, got error: %v", err) - } - if conn.state != StateEstablished { - t.Fatalf("expected StateEstablished, got %v", conn.state) - } - if conn.initialSeqNumRemote != 12345 { - t.Fatalf("expected initialSeqNumRemote=12345, got %v", conn.initialSeqNumRemote) - } -} - -func TestClientHandshake_Timeout(t *testing.T) { - conn := &Conn{ - notifyInbound: make(chan *Packet, 1), - } - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - err := conn.startClientHandshake(ctx) - if err == nil || err.Error() != "handshake timeout" { - t.Fatalf("expected handshake timeout error, got: %v", err) - } -} - -func TestServerHandshake_Success(t *testing.T) { - conn := &Conn{ - notifyInbound: make(chan *Packet, 1), - } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - // Simulate client ACK response - go func() { - time.Sleep(10 * time.Millisecond) - conn.notifyInbound <- &Packet{ - Ack: 2, // server ISN + 1 - Flags: FlagACK, - Len: 0, - } - }() - - synPkt := &Packet{Seq: 1, Flags: FlagSYN, Len: 0} - err := conn.startServerHandshake(ctx, synPkt) - if err != nil { - t.Fatalf("expected handshake success, got error: %v", err) - } - if conn.state != StateEstablished { - t.Fatalf("expected StateEstablished, got %v", conn.state) - } -} - -func TestServerHandshake_Timeout(t *testing.T) { - conn := &Conn{ - notifyInbound: make(chan *Packet, 1), - } - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - synPkt := &Packet{Seq: 1, Flags: FlagSYN, Len: 0} - err := conn.startServerHandshake(ctx, synPkt) - if err == nil || err.Error() != "handshake timeout" { - t.Fatalf("expected handshake timeout error, got: %v", err) - } -} - -func TestClientHandshake_BadAckIgnored(t *testing.T) { - conn := &Conn{ - notifyInbound: make(chan *Packet, 2), - } - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - // Simulate server response with wrong Ack - go func() { - time.Sleep(10 * time.Millisecond) - conn.notifyInbound <- &Packet{ - Seq: 12345, - Ack: 999, // wrong ack - Flags: FlagSYN | FlagACK, - Len: 0, - } - // Then correct response - time.Sleep(10 * time.Millisecond) - conn.notifyInbound <- &Packet{ - Seq: 12345, - Ack: 2, - Flags: FlagSYN | FlagACK, - Len: 0, - } - }() - - conn.initialSeqNumLocal = 1 - conn.connID = 1 - - err := conn.startClientHandshake(ctx) - if err != nil { - t.Fatalf("expected handshake success, got error: %v", err) - } - if conn.state != StateEstablished { - t.Fatalf("expected StateEstablished, got %v", conn.state) - } -} diff --git a/mod/udp/src/fragmenter.go b/mod/udp/src/fragmenter.go new file mode 100644 index 000000000..3c15018cc --- /dev/null +++ b/mod/udp/src/fragmenter.go @@ -0,0 +1,88 @@ +package udp + +// Fragmenter turns buffered bytes into wire packets and reproduces the exact +// same boundaries for retransmission. +type Fragmenter interface { + // MakeNew decides payload size and builds a new Packet at nextSeq. + // 'allowed' is the sender's remaining window in bytes. + // Returns (packet, payloadLen, ok). ok=false if it chooses not to send (e.g., Nagle). + MakeNew(nextSeq uint32, allowed int, buf SendBuffer) (*Packet, int, bool) +} + +// BasicFragmenter is a simple implementation of Fragmenter that splits a +// SendBuffer into Packets of at most MSS size. +type BasicFragmenter struct { + MSS int +} + +// NewBasicFragmenter creates a new BasicFragmenter with the given maximum +// segment size (MSS). +func NewBasicFragmenter(mss int) *BasicFragmenter { + return &BasicFragmenter{MSS: mss} +} + +// MakeNew implements the Fragmenter interface for BasicFragmenter. +func (f *BasicFragmenter) MakeNew(nextSeq uint32, allowed int, buf SendBuffer) (*Packet, int, bool) { + if f.MSS <= 0 { + return nil, 0, false + } + if buf.Len() == 0 { + return nil, 0, false + } + if allowed <= 0 { + return nil, 0, false + } + + maxLen := f.MSS + if allowed < maxLen { + maxLen = allowed + } + if buf.Len() < maxLen { + maxLen = buf.Len() + } + + payload := buf.Peek(maxLen) + packet := &Packet{ + Seq: nextSeq, + Len: uint16(len(payload)), + Payload: payload, + Flags: FlagACK, // Data packets should have ACK flag set + } + return packet, len(payload), true +} + +// Minimal SendBuffer interface for fragmentation +// Provides length and peek access to buffered data +// Implementations can use a byte slice, ring buffer, etc. +type SendBuffer interface { + Len() int + Peek(n int) []byte +} + +// ByteStreamBuffer is a minimal implementation of SendBuffer for fragmentation. +// It represents a contiguous stream of bytes, suitable for segmentation. +type ByteStreamBuffer struct { + data []byte +} + +func NewByteStreamBuffer(data []byte) *ByteStreamBuffer { + return &ByteStreamBuffer{data: data} +} + +func (b *ByteStreamBuffer) Len() int { + return len(b.data) +} + +func (b *ByteStreamBuffer) Peek(n int) []byte { + if n > len(b.data) { + n = len(b.data) + } + return b.data[:n] +} + +func (b *ByteStreamBuffer) Advance(n int) { + if n > len(b.data) { + n = len(b.data) + } + b.data = b.data[n:] +} diff --git a/mod/udp/src/fragmenter_test.go b/mod/udp/src/fragmenter_test.go new file mode 100644 index 000000000..1d7a8bf1d --- /dev/null +++ b/mod/udp/src/fragmenter_test.go @@ -0,0 +1,196 @@ +package udp + +import ( + "testing" +) + +func TestBasicFragmenter_SingleFragment(t *testing.T) { + mss := 100 + frag := NewBasicFragmenter(mss) + data := make([]byte, 80) + for i := range data { + data[i] = byte(i) + } + buf := &ByteStreamBuffer{data: data} + packet, packetLen, ok := frag.MakeNew(0, mss, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLen != 80 { + t.Errorf("expected packetLen=80, got %d", packetLen) + } + if packet == nil { + t.Fatalf("expected non-nil packet") + } + if packet.Seq != 0 { + t.Errorf("expected Seq=0, got %d", packet.Seq) + } + if packet.Len != 80 { + t.Errorf("expected Len=80, got %d", packet.Len) + } + for i := range packet.Payload { + if packet.Payload[i] != byte(i) { + t.Errorf("payload mismatch at %d: got %d, want %d", i, packet.Payload[i], byte(i)) + } + } +} + +func TestBasicFragmenter_MultipleFragments(t *testing.T) { + mss := 50 + frag := NewBasicFragmenter(mss) + data := make([]byte, 120) + for i := range data { + data[i] = byte(i) + } + buf := &ByteStreamBuffer{data: data} + + nextSeq := uint32(0) + total := 0 + fragments := 0 + for buf.Len() > 0 { + packet, packetLength, ok := frag.MakeNew(nextSeq, mss, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packet == nil { + t.Fatalf("expected non-nil packet") + } + if int(packet.Len) > mss { + t.Errorf("fragment too large: got %d, want <= %d", packet.Len, mss) + } + for i := 0; i < int(packet.Len); i++ { + if packet.Payload[i] != byte(int(nextSeq)+i) { + t.Errorf("payload mismatch at %d: got %d, want %d", int(nextSeq)+i, packet.Payload[i], byte(int(nextSeq)+i)) + } + } + buf.Advance(packetLength) + nextSeq += uint32(packetLength) + total += packetLength + fragments++ + } + if total != 120 { + t.Errorf("expected total=120, got %d", total) + } + if fragments != 3 { + t.Errorf("expected fragments=3, got %d", fragments) + } +} + +func TestBasicFragmenter_ZeroLen(t *testing.T) { + mss := 50 + frag := NewBasicFragmenter(mss) + buf := &ByteStreamBuffer{data: nil} + packet, packetLength, ok := frag.MakeNew(0, mss, buf) + if ok { + t.Errorf("expected ok=false for zero-len buffer") + } + if packet != nil { + t.Errorf("expected nil packet for zero-len buffer") + } + if packetLength != 0 { + t.Errorf("expected packetLength=0 for zero-len buffer, got %d", packetLength) + } +} + +func TestBasicFragmenter_AllowedLessThanMSS(t *testing.T) { + mss := 100 + allowed := 40 + frag := NewBasicFragmenter(mss) + data := make([]byte, 80) + for i := range data { + data[i] = byte(i) + } + buf := &ByteStreamBuffer{data: data} + packet, packetLength, ok := frag.MakeNew(0, allowed, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLength != allowed { + t.Errorf("expected packetLength=%d, got %d", allowed, packetLength) + } + if packet.Len != uint16(allowed) { + t.Errorf("expected Len=%d, got %d", allowed, packet.Len) + } +} + +func TestBasicFragmenter_AllowedLessThanBuffer(t *testing.T) { + mss := 100 + allowed := 60 + frag := NewBasicFragmenter(mss) + data := make([]byte, 80) + for i := range data { + data[i] = byte(i) + } + buf := &ByteStreamBuffer{data: data} + packet, packetLength, ok := frag.MakeNew(0, allowed, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLength != allowed { + t.Errorf("expected packetLength=%d, got %d", allowed, packetLength) + } + if packet.Len != uint16(allowed) { + t.Errorf("expected Len=%d, got %d", allowed, packet.Len) + } +} + +func TestBasicFragmenter_BufferSmallerThanAllowedAndMSS(t *testing.T) { + mss := 100 + allowed := 80 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + for i := range data { + data[i] = byte(i) + } + buf := &ByteStreamBuffer{data: data} + packet, packetLength, ok := frag.MakeNew(0, allowed, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLength != 50 { + t.Errorf("expected packetLength=50, got %d", packetLength) + } + if packet.Len != 50 { + t.Errorf("expected Len=50, got %d", packet.Len) + } +} + +func TestBasicFragmenter_NegativeOrZeroAllowed(t *testing.T) { + mss := 100 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + buf := &ByteStreamBuffer{data: data} + packet, packetLength, ok := frag.MakeNew(0, 0, buf) + if ok || packet != nil || packetLength != 0 { + t.Errorf("expected no packet for allowed=0") + } + packet, packetLength, ok = frag.MakeNew(0, -10, buf) + if ok || packet != nil || packetLength != 0 { + t.Errorf("expected no packet for allowed<0") + } +} + +func TestBasicFragmenter_ZeroMSS(t *testing.T) { + mss := 0 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + buf := &ByteStreamBuffer{data: data} + packet, packetLength, ok := frag.MakeNew(0, 100, buf) + if ok || packet != nil || packetLength != 0 { + t.Errorf("expected no packet for MSS=0") + } +} + +func TestBasicFragmenter_FlagsSet(t *testing.T) { + mss := 100 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + buf := &ByteStreamBuffer{data: data} + packet, _, ok := frag.MakeNew(0, 100, buf) + if !ok || packet == nil { + t.Fatalf("expected valid packet") + } + if packet.Flags&FlagACK == 0 { + t.Errorf("expected ACK flag set in data packet") + } +} diff --git a/mod/udp/src/plan.md b/mod/udp/src/plan.md deleted file mode 100644 index 64e9657c9..000000000 --- a/mod/udp/src/plan.md +++ /dev/null @@ -1,209 +0,0 @@ -# Reliable UDP Module: Architectural Brief - -## Purpose & Context -The Reliable UDP module provides stream-like semantics over UDP, ensuring ordered and reliable delivery of data. It is designed to integrate seamlessly with the Astral ecosystem, particularly the `exonet` module and node communication. Unlike raw UDP, this module introduces mechanisms for retransmissions, acknowledgments, and a handshake protocol to establish connection state before data exchange. - -## Interfaces & Contracts -- **Connection Interface**: Provides ordered, reliable byte streams. Implements `io.ReadWriteCloser`. -- **Listener Behavior**: Accepts incoming connections, demultiplexing based on remote endpoints. -- **Endpoint Handling**: Supports parsing, packing, and unpacking of network addresses. -- **Invariants**: - - Data is delivered in order. - - Lost packets are retransmitted. - - Connections are established via a handshake. - -## Handshake Protocol -The handshake follows a three-step process: -1. **SYN**: Initiator sends a SYN packet with an initial sequence number. -2. **SYN|ACK**: Responder replies with a SYN|ACK, acknowledging the initiator's sequence number and providing its own. -3. **ACK**: Initiator acknowledges the responder's sequence number, completing the handshake. - -### Sequence Space Rules -- SYN and FIN each consume one sequence number. -- Retransmissions occur if no acknowledgment is received within the retransmission timeout (RTO). -- A connection is established after the ACK is received. - -### Timing & Retransmission -- Initial RTO: 500ms (configurable). -- Exponential backoff for retransmissions. -- Maximum retries: 8 (configurable). - -## Data Path Overview -- **Segmentation**: Application data is split into packets, each with a sequence number. -- **Ordering**: Out-of-order packets are buffered until missing packets arrive. -- **Acknowledgments**: Cumulative ACKs confirm receipt of all bytes up to a sequence number. -- **Retransmission**: Unacknowledged packets are retransmitted after RTO. -- **Batching**: Multiple packets may be sent together to optimize throughput. - -## Concurrency & I/O Model -- **Locking Domains**: Separate locks for send and receive paths. -- **Goroutines**: - - One for sending data. - - One for receiving and processing packets. - - Timers for retransmissions. -- **Shutdown**: Ensures all goroutines exit cleanly, and no resources are leaked. - -## Error Model & Shutdown Semantics -- **Errors**: Surface as `net.Error` or module-specific errors. -- **Idempotent Close**: Closing a connection multiple times has no adverse effects. -- **Partial Failures**: Errors during send/receive are propagated to the caller. - -## Compatibility & Integration -- **Endpoint Parsing**: Compatible with `exonet` endpoint parsing and unpacking. -- **Lifecycle Alignment**: Designed to align with the lifecycle of other modules like TCP and Tor. -- **Assumptions**: Assumes reliable delivery within the module; does not handle NAT traversal or encryption. - -## Security & Future Considerations -- **Stateless Cookie**: Potential for DoS mitigation using stateless cookies during the handshake. -- **PLPMTUD**: Path MTU discovery to avoid fragmentation. -- **Congestion Control**: Future integration with congestion control mechanisms. - -## Missing Logics and Potential Issues - -### Missing Logics -1. **Connection Handshake Validation**: - - The handshake process lacks validation for replay attacks or duplicate SYN packets. This could lead to unnecessary resource allocation. - -2. **Congestion Control**: - - The module does not implement congestion control mechanisms, which could lead to network congestion in high-traffic scenarios. - -3. **DoS Mitigation**: - - There is no stateless cookie mechanism during the handshake to prevent denial-of-service (DoS) attacks. - -4. **Connection Timeout**: - - The module does not enforce a timeout for idle connections, which could lead to resource exhaustion. - -5. **Error Propagation**: - - Errors during retransmissions or ACK handling are not consistently propagated to the caller, which could make debugging difficult. - -### Potential Performance Issues -1. **Timer Management**: - - The retransmission timer (`armRTO`) and ACK delay timer (`armAckDelay`) are reset frequently, which could lead to high overhead in timer management. - -2. **Lock Contention**: - - The use of mutexes (`rtoMu`) for timer operations could lead to contention under high concurrency. - -3. **Inefficient Buffering**: - - The `sendQ` buffer in `conn.go` may become a bottleneck if the application writes data faster than the network can transmit. - -4. **Packet Parsing Overhead**: - - The `UnmarshalPacket` function in `recv.go` is called for every incoming packet, which could become a performance bottleneck if the parsing logic is complex. - -### Potential Bugs -1. **Timer Race Conditions**: - - The `armRTO` and `stopRTO` functions do not ensure that the timer callback (`handleRTO`) is not running when the timer is stopped, which could lead to race conditions. - -2. **Endpoint Parsing Errors**: - - The `Dial` function in `dial.go` does not handle errors from `udp.ParseEndpoint`, which could lead to nil pointer dereferences. - -3. **Unbounded Retransmissions**: - - The retransmission logic does not enforce a maximum number of retries, which could lead to infinite retransmissions in case of persistent packet loss. - -4. **ACK Timer Reset**: - - The `armAckDelay` function resets the ACK timer without checking if the timer is already running, which could lead to missed ACKs. - -## Implementation Status - -### Fully Implemented and Tested -1. **Ring Buffer**: - - Complete implementation with test coverage for: - - Blocking write/read operations - - Buffer closure handling - - Concurrent access patterns - -2. **Packet Serialization**: - - Full implementation with tests for: - - Marshal/unmarshal operations - - Valid packet handling - - Empty payload cases - -3. **Configuration**: - - Complete implementation with tests covering: - - Default values - - Range validation - - Value normalization - -### Fully Implemented but Untested -1. **Data Transmission**: - - Segmentation and packet sending in `send.go` - - Retransmission handling in `timers.go` - - No test coverage for edge cases or error conditions - -2. **Data Reception**: - - Packet processing and buffering in `recv.go` - - Out-of-order packet handling - - No tests for complex reassembly scenarios - -3. **Server Logic**: - - Connection management in `server.go` - - Datagram routing - - Lacks tests for concurrent connections - -### Partially Implemented -1. **Handshake Protocol**: - - Basic structure defined in `packet.go` (SYN/ACK/FIN flags) - - Missing implementation in: - - Connection establishment logic - - State machine for handshake steps - - Timeout handling during handshake - -2. **Error Handling**: - - Basic error types defined - - Inconsistent propagation in retransmission logic - - Missing comprehensive error recovery - -3. **Timer Management**: - - Basic timer operations implemented - - Race condition risks identified - - Missing proper cleanup and synchronization - -### Missing Components -1. **Connection State Management**: - - No explicit connection state machine - - Missing timeout handling for idle connections - - No graceful connection termination - -2. **Flow Control**: - - Window size tracking implemented - - Missing: - - Congestion control - - Slow start mechanism - - Fast retransmit/recovery - -3. **Security Features**: - - No DoS protection - - Missing replay attack prevention - - No cookie mechanism for handshake - -4. **Testing Infrastructure**: - - Need integration tests for: - - Complete connection lifecycle - - Error scenarios - - Performance under load - - Network condition simulation - -### Next Steps (Prioritized) -1. **Complete Handshake Implementation**: - - Implement state transitions - - Add timeout handling - - Include sequence number validation - -2. **Add Connection Management**: - - Implement idle connection detection - - Add connection timeouts - - Create cleanup mechanisms - -3. **Enhance Security**: - - Add SYN cookie mechanism - - Implement replay protection - - Add rate limiting for new connections - -4. **Improve Reliability**: - - Add congestion control - - Implement proper window management - - Add fast retransmit/recovery - -5. **Complete Test Coverage**: - - Add integration tests - - Create network simulation tests - - Test concurrent connections diff --git a/mod/udp/src/receive.go b/mod/udp/src/receive.go new file mode 100644 index 000000000..34d234dab --- /dev/null +++ b/mod/udp/src/receive.go @@ -0,0 +1,148 @@ +package udp + +import ( + "net" + "time" +) + +// recvLoop fuses raw UDP receive and packet dispatch. During handshake (state < Established) +// packets are delivered into inCh for the blocking handshake loops. After establishment, +// packets are dispatched directly to the appropriate handlers without using inCh. +func (c *Conn) recvLoop() { + const maxPayloadSize = 64 * 1024 + buf := make([]byte, maxPayloadSize) + for { + if c.isClosed() { + return + } + if err := c.udpConn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + continue + } + n, addr, err := c.udpConn.ReadFromUDP(buf) + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + continue + } + if c.isClosed() { + return + } + continue + } + // address filter + if addr.String() != c.remoteEndpoint.IP.String() { + continue + } + if n < 13 { // minimum header size + continue + } + packetData := make([]byte, n) + copy(packetData, buf[:n]) + pkt := &Packet{} + if err := pkt.Unmarshal(packetData); err != nil { + continue + } + if int(pkt.Len) > maxPayloadSize { // sanity + continue + } + + // Handshake phase: deliver via channel for Start*Handshake loops + if !c.inState(StateEstablished) { + select { + case c.inCh <- pkt: + default: // if channel full (unlikely in handshake), drop + } + continue + } + + // Established: direct dispatch + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + continue + } + if pkt.Flags&(FlagSYN|FlagFIN) != 0 { + c.HandleControlPacket(pkt) + continue + } + c.handleDataPacket(pkt) + } +} + +func (c *Conn) handleDataPacket(packet *Packet) { + if packet.Len == 0 { // ignore empty as data + return + } + seq := packet.Seq + plen := uint32(packet.Len) + + c.recvMu.Lock() + exp := c.expected + ackDelay := c.cfg.AckDelay // snapshot once inside lock + + switch { + case seq < exp: // duplicate / already received + c.queueAckLocked() + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + + case seq == exp: // in-order + if int(packet.Len) > int(c.recvRB.Free()) { + // No space -> request ACK (window advertisement) and drop + c.queueAckLocked() + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + } + // Write this packet fully + if n, _ := c.recvRB.Write(packet.Payload); n != int(packet.Len) { + // Failed partial write (shouldn't happen with ringbuffer); request ACK and abort + c.queueAckLocked() + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + } + exp += plen + c.expected = exp + // Drain any now-contiguous buffered packets + for { + nextPkt, ok := c.recvOO[exp] + if !ok { + break + } + if int(nextPkt.Len) > int(c.recvRB.Free()) { + // Break if capacity insufficient; will retry on next in-order arrival + break + } + if n, _ := c.recvRB.Write(nextPkt.Payload); n != int(nextPkt.Len) { + // On partial write, stop draining to preserve consistency + break + } + delete(c.recvOO, exp) + exp += uint32(nextPkt.Len) + c.expected = exp + } + c.queueAckLocked() + c.recvCond.Broadcast() + c.recvMu.Unlock() + + // mirror expected -> ackedSeqNum for piggyback + c.sendMu.Lock() + if c.ackedSeqNum < exp { + c.ackedSeqNum = exp + } + c.sendMu.Unlock() + c.triggerAck(ackDelay) + return + + default: // seq > exp (future / out-of-order) + if int(packet.Len) <= int(c.recvRB.Free()) { + if _, exists := c.recvOO[seq]; !exists { + c.recvOO[seq] = packet + } + } + c.queueAckLocked() // request duplicate ACK + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + } +} diff --git a/mod/udp/src/retransmissions.go b/mod/udp/src/retransmissions.go new file mode 100644 index 000000000..850d5c317 --- /dev/null +++ b/mod/udp/src/retransmissions.go @@ -0,0 +1,131 @@ +package udp + +import ( + "sort" + "time" + + "github.com/cryptopunkscc/astrald/mod/udp" +) + +// queueAckLocked marks that an ACK should be (re)sent. Caller must hold recvMu. +func (c *Conn) queueAckLocked() { c.ackPending = true } + +// triggerAck decides immediate vs delayed ACK after recvMu has been released. +func (c *Conn) triggerAck(ackDelay time.Duration) { + if ackDelay == 0 { + c.sendPureACK() + } else { + c.scheduleAck() + } +} + +// scheduleAck sets / resets delayed ACK timer +func (c *Conn) scheduleAck() { + c.recvMu.Lock() + if !c.ackPending || c.isClosed() { + c.recvMu.Unlock() + return + } + d := c.cfg.AckDelay + if d <= 0 { + ackNeeded := c.ackPending + c.ackPending = false + c.recvMu.Unlock() + if ackNeeded { + c.sendPureACK() + } + return + } + if c.ackTimer != nil { + c.ackTimer.Reset(d) + } else { + c.ackTimer = time.AfterFunc(d, c.fireAck) + } + c.recvMu.Unlock() +} + +func (c *Conn) fireAck() { + c.recvMu.Lock() + if !c.ackPending || c.isClosed() { + c.recvMu.Unlock() + return + } + c.ackPending = false + c.recvMu.Unlock() + c.sendPureACK() +} + +// sendPureACK sends a standalone ACK reflecting current expected sequence. +func (c *Conn) sendPureACK() { + if !c.inState(StateEstablished) || c.isClosed() { + return + } + // snapshot expected & window + c.recvMu.Lock() + exp := c.expected + winFree := uint32(0) + if c.recvRB != nil { + winFree = uint32(c.recvRB.Free()) + } + c.recvMu.Unlock() + // clamp window to 16-bit + win := uint16(0) + if winFree > 0xFFFF { + win = 0xFFFF + } else { + win = uint16(winFree) + } + pkt := &Packet{Seq: 0, Ack: exp, Flags: FlagACK, Win: win, Len: 0} + b, err := pkt.Marshal() + if err != nil { + return + } + // best-effort send (no tracking) + _, _ = c.udpConn.Write(b) +} + +// handleRetransmissionTimeoutLocked assumes sendMu is held and performs retransmissions. +// Returns true if the retransmission limit was exceeded for any packet. +func (c *Conn) handleRetransmissionTimeoutLocked() (limitExceeded bool) { + if len(c.unacked) == 0 { + return false + } + + seqs := make([]uint32, 0, len(c.unacked)) + for s := range c.unacked { + seqs = append(seqs, s) + } + sort.Slice(seqs, func(i, j int) bool { return seqs[i] < seqs[j] }) + + for _, s := range seqs { + u := c.unacked[s] + if u.rtxCount >= c.cfg.RetransmissionLimit { + limitExceeded = true + break + } + // Update ACK field to latest cumulative ACK and retransmit + u.pkt.Ack = c.ackedSeqNum + b, err := u.pkt.Marshal() + if err == nil { + _, _ = c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()) + } + u.rtxCount++ + u.sentTime = time.Now() + } + return +} + +// handleRetransmissionTimeout provides backward-compatible external behavior (locking & close semantics) +func (c *Conn) handleRetransmissionTimeout() { + c.sendMu.Lock() + limitExceeded := c.handleRetransmissionTimeoutLocked() + c.sendMu.Unlock() + if limitExceeded { + select { + case c.ErrChan <- udp.ErrRetransmissionLimitExceeded: + default: + } + c.Close() + close(c.ErrChan) + } +} diff --git a/mod/udp/src/ring_buffer.go b/mod/udp/src/ring_buffer.go deleted file mode 100644 index 76237b65a..000000000 --- a/mod/udp/src/ring_buffer.go +++ /dev/null @@ -1,159 +0,0 @@ -package udp - -import ( - "io" - "sync" -) - -// ringBuffer implements a blocking circular buffer for byte streams. -type ringBuffer struct { - buf []byte // underlying buffer - cap int // capacity (fixed) - n int // current bytes stored - r int // read position - w int // write position - mu sync.Mutex // protects all fields - notEmp *sync.Cond // signaled when buffer becomes non-empty - notFul *sync.Cond // signaled when buffer has space available - closed bool // whether buffer is closed -} - -// newRingBuffer creates a new ring buffer with the specified capacity. -func newRingBuffer(capacity int) *ringBuffer { - if capacity < 0 { - capacity = 0 - } - rb := &ringBuffer{ - buf: make([]byte, capacity), - cap: capacity, - } - rb.notEmp = sync.NewCond(&rb.mu) - rb.notFul = sync.NewCond(&rb.mu) - return rb -} - -// WriteAll blocks until all bytes are written or the buffer is closed. -func (rb *ringBuffer) WriteAll(b []byte) (int, error) { - rb.mu.Lock() - defer rb.mu.Unlock() - - written := 0 - for written < len(b) { - for rb.n == rb.cap && !rb.closed { - rb.notFul.Wait() - } - if rb.closed { - return written, io.ErrClosedPipe - } - - space := rb.cap - rb.n - toWrite := len(b) - written - if toWrite > space { - toWrite = space - } - - end := (rb.w + toWrite) % rb.cap - if end > rb.w { - copy(rb.buf[rb.w:end], b[written:written+toWrite]) - } else { - copy(rb.buf[rb.w:], b[written:written+toWrite]) - copy(rb.buf[:end], b[written+rb.cap-rb.w:written+toWrite]) - } - - rb.w = end - rb.n += toWrite - written += toWrite - rb.notEmp.Signal() - } - - return written, nil -} - -// TryWrite attempts to write bytes without blocking. -func (rb *ringBuffer) TryWrite(b []byte) int { - rb.mu.Lock() - defer rb.mu.Unlock() - - if rb.n == rb.cap || rb.closed { - return 0 - } - - space := rb.cap - rb.n - toWrite := len(b) - if toWrite > space { - toWrite = space - } - - end := (rb.w + toWrite) % rb.cap - if end > rb.w { - copy(rb.buf[rb.w:end], b[:toWrite]) - } else { - copy(rb.buf[rb.w:], b[:toWrite]) - copy(rb.buf[:end], b[rb.cap-rb.w:toWrite]) - } - - rb.w = end - rb.n += toWrite - rb.notEmp.Signal() - - return toWrite -} - -// Read blocks until at least one byte is available or the buffer is closed and drained. -func (rb *ringBuffer) Read(p []byte) (int, error) { - rb.mu.Lock() - defer rb.mu.Unlock() - - for rb.n == 0 && !rb.closed { - rb.notEmp.Wait() - } - - if rb.n == 0 && rb.closed { - return 0, io.EOF - } - - toRead := len(p) - if toRead > rb.n { - toRead = rb.n - } - - end := (rb.r + toRead) % rb.cap - if end > rb.r { - copy(p, rb.buf[rb.r:end]) - } else { - copy(p, rb.buf[rb.r:]) - copy(p[rb.cap-rb.r:], rb.buf[:end]) - } - - rb.r = end - rb.n -= toRead - rb.notFul.Signal() - - return toRead, nil -} - -// Close marks the buffer as closed and wakes all waiters. -func (rb *ringBuffer) Close() { - rb.mu.Lock() - defer rb.mu.Unlock() - - if !rb.closed { - rb.closed = true - rb.notEmp.Broadcast() - rb.notFul.Broadcast() - } -} - -// Len returns the number of bytes currently stored in the buffer. -func (rb *ringBuffer) Len() int { - rb.mu.Lock() - defer rb.mu.Unlock() - return rb.n -} - -// Cap returns the capacity of the buffer. -func (rb *ringBuffer) Cap() int { - rb.mu.Lock() - defer rb.mu.Unlock() - return rb.cap -} diff --git a/mod/udp/src/ring_buffer_test.go b/mod/udp/src/ring_buffer_test.go deleted file mode 100644 index 4b7d9b99d..000000000 --- a/mod/udp/src/ring_buffer_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package udp - -import ( - "io" - "sync" - "testing" - "time" -) - -func TestRingBufferBlockingWriteRead(t *testing.T) { - rb := newRingBuffer(10) - - var wg sync.WaitGroup - wg.Add(2) - - // Writer - go func() { - defer wg.Done() - data := []byte("hello world") - written, err := rb.WriteAll(data) - if written != len(data) || err != nil { - t.Errorf("WriteAll failed: written=%d, err=%v", written, err) - } - }() - - // Reader - go func() { - defer wg.Done() - buf := make([]byte, 11) - read, err := rb.Read(buf) - if read != 11 || err != nil || string(buf) != "hello world" { - t.Errorf("Read failed: read=%d, err=%v, buf=%s", read, err, buf) - } - }() - - wg.Wait() -} - -func TestRingBufferCloseWhileReaderWaiting(t *testing.T) { - rb := newRingBuffer(10) - - var wg sync.WaitGroup - wg.Add(1) - - // Reader - go func() { - defer wg.Done() - buf := make([]byte, 10) - _, err := rb.Read(buf) - if err != io.EOF { - t.Errorf("Expected io.EOF, got %v", err) - } - }() - - time.Sleep(100 * time.Millisecond) // Ensure reader is waiting - rb.Close() - - wg.Wait() -} - -func TestRingBufferCloseWhileWriterWaiting(t *testing.T) { - rb := newRingBuffer(5) - - var wg sync.WaitGroup - wg.Add(1) - - // Writer - go func() { - defer wg.Done() - data := []byte("hello world") - _, err := rb.WriteAll(data) - if err != io.ErrClosedPipe { - t.Errorf("Expected io.ErrClosedPipe, got %v", err) - } - }() - - time.Sleep(100 * time.Millisecond) // Ensure writer is waiting - rb.Close() - - wg.Wait() -} - -func TestRingBufferConcurrentProducersConsumers(t *testing.T) { - rb := newRingBuffer(50) - - var wg sync.WaitGroup - producers := 5 - consumers := 5 - iterations := 100 - - // Producers - for i := 0; i < producers; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < iterations; j++ { - rb.WriteAll([]byte{byte(id)}) - } - }(i) - } - - // Consumers - for i := 0; i < consumers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - buf := make([]byte, 1) - for j := 0; j < iterations; j++ { - rb.Read(buf) - } - }() - } - - wg.Wait() -} - -func TestRingBufferZeroCapacity(t *testing.T) { - rb := newRingBuffer(0) - - var wg sync.WaitGroup - wg.Add(1) - - // Writer - go func() { - defer wg.Done() - data := []byte("data") - _, err := rb.WriteAll(data) - if err != io.ErrClosedPipe { - t.Errorf("Expected io.ErrClosedPipe, got %v", err) - } - }() - - time.Sleep(100 * time.Millisecond) // Ensure writer is waiting - rb.Close() - - wg.Wait() -} - -func TestRingBufferRaceSafety(t *testing.T) { - rb := newRingBuffer(100) - - var wg sync.WaitGroup - producers := 10 - consumers := 10 - - // Producers - for i := 0; i < producers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 1000; j++ { - rb.TryWrite([]byte{byte(j % 256)}) - } - }() - } - - // Consumers - for i := 0; i < consumers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - buf := make([]byte, 10) - for j := 0; j < 1000; j++ { - rb.Read(buf) - } - }() - } - - wg.Wait() -} diff --git a/mod/udp/src/send.go b/mod/udp/src/send.go new file mode 100644 index 000000000..867a1f2a8 --- /dev/null +++ b/mod/udp/src/send.go @@ -0,0 +1,189 @@ +package udp + +import ( + "fmt" + "time" +) + +// Write enqueues data into the send ring buffer. It blocks until enough space is available. +// Signals the sender goroutine after enqueue. +func (c *Conn) writeSend(p []byte) (n int, err error) { + c.sendMu.Lock() + defer c.sendMu.Unlock() + writeLen := len(p) + for writeLen > int(c.sendRB.Free()) { + c.sendCond.Wait() + } + _, err = c.sendRB.Write(p) + if err != nil { + return 0, err + } + c.sendCond.Broadcast() + return writeLen, nil +} + +// windowFull returns true if the packet window is full (no more packets can be sent). +func (c *Conn) windowFull() bool { + return len(c.unacked) >= c.cfg.MaxWindowPackets +} + +// nextSendOffset returns the offset in sendRB for the next fragment to send. +// If no packets are unacked, starts at sendBase. Otherwise, finds the highest offset of unacked packets. +func (c *Conn) nextSendOffset() int { + if len(c.unacked) == 0 { + return int(c.sendBase) + } + + maxOff := 0 + for _, u := range c.unacked { + if u.offset+u.length > maxOff { + maxOff = u.offset + u.length + } + } + return maxOff +} + +// fillFragBuf fills buf with ask bytes from sendRB at offset off. +// Returns an error if the buffer is too small or the read fails. +func (c *Conn) fillFragBuf(buf []byte, _ int, ask int) error { + if ask > len(buf) { + return fmt.Errorf("fragBuf too small: ask=%d, buf=%d", ask, len(buf)) + } + readN, err := c.sendRB.Read(buf[:ask]) + if err != nil { + return err + } + if readN != ask { + return fmt.Errorf("partial read from sendRB: expected %d, got %d", ask, readN) + } + return nil +} + +// sendFragment fragments, marshals, and sends a packet from sendRB at offset off, up to ask bytes. +// Updates unacked and nextSeqNum. Returns true if a packet was sent, false otherwise. +func (c *Conn) sendFragment(off int, ask int) (bool, error) { + c.sendMu.Lock() + + fragBuf := make([]byte, ask) + if err := c.fillFragBuf(fragBuf, off, ask); err != nil { + c.sendMu.Unlock() + return false, err + } + + pkt, pktLen, ok := c.frag.MakeNew(c.nextSeqNum, ask, &ByteStreamBuffer{data: fragBuf[:ask]}) + if !ok || pktLen == 0 { + c.sendMu.Unlock() + return false, nil + } + pkt.Ack = c.ackedSeqNum + + b, err := pkt.Marshal() + if err != nil { + c.sendMu.Unlock() + return false, err + } + + seq := c.nextSeqNum + c.unacked[seq] = &Unacked{ + pkt: pkt, + sentTime: time.Now(), + rtxCount: 0, + offset: off, + length: pktLen, + } + c.nextSeqNum += uint32(pktLen) + startTimer := len(c.unacked) == 1 + c.sendCond.Broadcast() // wake writers: space freed by consumption + c.sendMu.Unlock() + + if startTimer { + c.startRtxTimer() + } + + // Network I/O outside lock + _, err = c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()) + if err != nil { + c.sendMu.Lock() + if u, ok2 := c.unacked[seq]; ok2 && u.length == pktLen { + delete(c.unacked, seq) + if c.nextSeqNum == seq+uint32(pktLen) { // rewind if no later sends + c.nextSeqNum = seq + } + if len(c.unacked) == 0 && c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + c.sendCond.Broadcast() // notify writers rollback restored space + } + c.sendMu.Unlock() + return false, err + } + + return true, nil +} + +// startRtxTimer arms the retransmission timer if not already running +func (c *Conn) startRtxTimer() { + c.sendMu.Lock() + if c.rtxTimer != nil { + c.sendMu.Unlock() + return // already running + } + interval := c.cfg.RetransmissionInterval + c.rtxTimer = time.AfterFunc(interval, func() { + c.handleRetransmissionTimeout() + c.sendMu.Lock() + if len(c.unacked) > 0 && !c.isClosed() { + c.rtxTimer.Reset(interval) + } else { + c.rtxTimer = nil + } + c.sendMu.Unlock() + }) + c.sendMu.Unlock() +} + +// senderLoop runs as a goroutine and is responsible for sending packets from the send buffer. +// It enforces flow control (packet window), fragments data, and sends packets. +// The loop blocks when the window is full or the buffer is empty, and wakes up on sendCond. +// Exits cleanly on connection close or state change, and notifies waiters. +// For PoC, does not implement advanced pacing, batching, or prioritization. +func (c *Conn) senderLoop() { + defer func() { c.sendCond.Broadcast() }() + for { + c.sendMu.Lock() + if c.isClosed() || !c.inState(StateEstablished) { + c.sendMu.Unlock() + return + } + for c.sendRB.Length() == 0 || c.windowFull() { + c.sendCond.Wait() + if c.isClosed() || !c.inState(StateEstablished) { + c.sendMu.Unlock() + return + } + } + off := c.nextSendOffset() + end := int(c.sendRB.Length()) + if off >= end { // nothing to send after all + c.sendMu.Unlock() + continue + } + rem := end - off + ask := rem + if ask > c.cfg.MaxSegmentSize { + ask = c.cfg.MaxSegmentSize + } + c.sendMu.Unlock() + + sent, err := c.sendFragment(off, ask) + if err != nil { + fmt.Printf("sendFragment error: %v\n", err) + continue + } + if !sent { + fmt.Println("sendFragment did not send packet (sent=false)") + continue + } + } +} diff --git a/mod/udp/src/server.go b/mod/udp/src/server.go index 090eb61d0..9b798ac44 100644 --- a/mod/udp/src/server.go +++ b/mod/udp/src/server.go @@ -112,7 +112,7 @@ func (s *Server) readLoop(ctx *astral.Context, localEndpoint *udp.Endpoint) { conn.inCh = make(chan *Packet, 128) s.conns[remoteKey] = conn go func() { - err := conn.startServerHandshake(ctx, pkt) + err := conn.StartServerHandshake(ctx, pkt) if err != nil { s.log.Errorv(1, "handshake error for %v: %v", addr, err) } From ce795ec7aabdd8709766839957fa3ce2eb37f983 Mon Sep 17 00:00:00 2001 From: Rekseto Date: Tue, 30 Sep 2025 21:38:34 +0200 Subject: [PATCH 05/13] mod/udp/src: major fixes and improvements around buffering and simplifications to handshake logic --- mod/udp/src/config.go | 4 +-- mod/udp/src/config_test.go | 4 +-- mod/udp/src/conn.go | 43 ++++++++++++++++++++++---- mod/udp/src/conn_handshake.go | 17 ++++++++--- mod/udp/src/dial.go | 4 ++- mod/udp/src/receive.go | 8 +++-- mod/udp/src/send.go | 57 ++++++++++++----------------------- mod/udp/src/server.go | 4 ++- 8 files changed, 84 insertions(+), 57 deletions(-) diff --git a/mod/udp/src/config.go b/mod/udp/src/config.go index 30df398f0..fad02def8 100644 --- a/mod/udp/src/config.go +++ b/mod/udp/src/config.go @@ -59,7 +59,7 @@ type Config struct { PublicEndpoints []string `yaml:"public_endpoints,omitempty"` DialTimeout time.Duration `yaml:"dial_timeout,omitempty"` // Timeout for dialing connections (default 1 minute) - FlowControl ReliableTransportConfig `yaml:"flow_control,omitempty"` // Flow control settings for UDP connections + TransportConfig ReliableTransportConfig `yaml:"transport_config,omitempty"` // Flow control settings for UDP connections } // ReliableTransportConfig holds configuration for individual UDP connections. @@ -157,7 +157,7 @@ func clampDur(v, lo, hi time.Duration) time.Duration { var defaultConfig = Config{ ListenPort: ListenPort, DialTimeout: time.Minute, - FlowControl: ReliableTransportConfig{ + TransportConfig: ReliableTransportConfig{ MaxSegmentSize: DefaultMSS, MaxWindowBytes: DefaultWindowBytes, RetransmissionInterval: DefaultRTO, diff --git a/mod/udp/src/config_test.go b/mod/udp/src/config_test.go index 173f34e21..6bd16ac4f 100644 --- a/mod/udp/src/config_test.go +++ b/mod/udp/src/config_test.go @@ -6,9 +6,9 @@ import ( ) func TestFlowControlConfigDefaults(t *testing.T) { - def := defaultConfig.FlowControl + def := defaultConfig.TransportConfig if def.MSS != DefaultMSS || def.WindowBytes != DefaultWindowBytes || def.RTO != DefaultRTO || def.RTOMax != DefaultRTOMax || def.RetryLimit != DefaultRetries || def.AckDelay != DefaultAckDelay || def.RecvBufBytes != DefaultRecvBufBytes || def.SendBufBytes != DefaultSendBufBytes { - t.Errorf("defaultConfig.FlowControl does not match expected defaults: %+v", def) + t.Errorf("defaultConfig.TransportConfig does not match expected defaults: %+v", def) } } diff --git a/mod/udp/src/conn.go b/mod/udp/src/conn.go index 1fdb68b0f..5963856a4 100644 --- a/mod/udp/src/conn.go +++ b/mod/udp/src/conn.go @@ -8,6 +8,7 @@ import ( "sync/atomic" "time" + "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/udp" "github.com/smallnest/ringbuffer" ) @@ -28,11 +29,10 @@ type Handshaker interface { } type Unacked struct { - pkt *Packet // Packet metadata (seq, len, ring offsets) + pkt *Packet // Packet metadata (seq, len) sentTime time.Time // Last sent time rtxCount int // Retransmit count - offset int // Offset in sendRB (handshake: -1) - length int // Length in sendRB / payload length (handshake: 0) + length int // Payload length isHandshake bool // True if this entry is for a handshake control packet } @@ -50,6 +50,7 @@ type Conn struct { udpConn *net.UDPConn // Underlying UDP socket localEndpoint *udp.Endpoint remoteEndpoint *udp.Endpoint + outbound bool // true if we initiated the connection // Configuration (reliability, flow control, etc.) cfg ReliableTransportConfig // All protocol parameters @@ -71,9 +72,9 @@ type Conn struct { inflight uint32 // Number of unacked packets // Send buffer and reliability - sendRB *ringbuffer.RingBuffer // Persistent send ring buffer + sendRB *ringbuffer.RingBuffer // Persistent send ring buffer (FIFO; bytes consumed at packetization) frag *BasicFragmenter // Fragmenter for packetization - unacked map[uint32]*Unacked // Map of unacked packets (seq -> Unacked) + unacked map[uint32]*Unacked // Map of unacked packets (seq -> Unacked); stores full packet copies for retransmission // Concurrency and coordination sendMu sync.Mutex // Protects all shared state @@ -198,12 +199,27 @@ func NewConn(cn *net.UDPConn, l, r *udp.Endpoint, cfg ReliableTransportConfig) ( rc.recvCond = sync.NewCond(&rc.recvMu) rc.recvOO = make(map[uint32]*Packet) - // start fused receive loop + // start fused receive loop immediately so handshake packets can be processed + // NOTE: senderLoop is started only after handshake succeeds (see onEstablished()) go rc.recvLoop() return rc, nil } +func (c *Conn) onEstablished() { + // Idempotent: only transition once + if c.inState(StateEstablished) || c.isClosed() { + return + } + // Initialize receive-side expected sequence to remote initial seq + 1 (account for SYN consuming one seq) + if c.initialSeqNumRemote != 0 && c.expected == 0 { + c.expected = c.initialSeqNumRemote + 1 + } + c.setState(StateEstablished) + // Future established-only initializations (keepalives, metrics, etc.) go here. + go c.senderLoop() +} + // HandleAckPacket processes ACK packets func (c *Conn) HandleAckPacket(packet *Packet) { ack := packet.Ack @@ -250,3 +266,18 @@ func (c *Conn) HandleControlPacket(packet *Packet) { } // ...handle other control flags as needed... } + +// Interface compliance for exonet.Conn +func (c *Conn) Outbound() bool { return c.outbound } +func (c *Conn) LocalEndpoint() exonet.Endpoint { + if c == nil { + return nil + } + return c.localEndpoint +} +func (c *Conn) RemoteEndpoint() exonet.Endpoint { + if c == nil { + return nil + } + return c.remoteEndpoint +} diff --git a/mod/udp/src/conn_handshake.go b/mod/udp/src/conn_handshake.go index 5b08882a7..5b94cbb9e 100644 --- a/mod/udp/src/conn_handshake.go +++ b/mod/udp/src/conn_handshake.go @@ -60,7 +60,8 @@ func (c *Conn) StartClientHandshake(ctx context.Context) error { if err := c.SendControlPacket(FlagACK, c.initialSeqNumLocal+1, c.initialSeqNumRemote+1); err != nil { return err } - c.setState(StateEstablished) + // Transition to established (state change + sender loop) via helper + c.onEstablished() // fused receive loop will now dispatch directly return nil } @@ -100,7 +101,7 @@ func (c *Conn) StartServerHandshake(ctx context.Context, synPkt *Packet) error { c.sendBase = c.initialSeqNumLocal + 1 c.nextSeqNum = c.initialSeqNumLocal + 1 c.sendMu.Unlock() - c.setState(StateEstablished) + c.onEstablished() // fused receive loop will now dispatch directly return nil } @@ -118,7 +119,11 @@ func (c *Conn) sendHandshakeControl(flags uint8, seq, ack uint32) error { if c.udpConn == nil { return udp.ErrConnClosed } - if _, err := c.udpConn.Write(b); err != nil { + // Use WriteToUDP to support server-side unconnected sockets and client connected sockets uniformly. + if c.remoteEndpoint == nil { + return fmt.Errorf("remote endpoint nil") + } + if _, err := c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()); err != nil { return err } c.sendMu.Lock() @@ -127,7 +132,6 @@ func (c *Conn) sendHandshakeControl(flags uint8, seq, ack uint32) error { pkt: pkt, sentTime: time.Now(), rtxCount: 0, - offset: -1, length: 0, isHandshake: true, } @@ -149,7 +153,10 @@ func (c *Conn) SendControlPacket(flags uint8, seq, ack uint32) error { if c.udpConn == nil { return udp.ErrConnClosed } - _, err = c.udpConn.Write(data) + if c.remoteEndpoint == nil { + return fmt.Errorf("remote endpoint nil") + } + _, err = c.udpConn.WriteToUDP(data, c.remoteEndpoint.UDPAddr()) if err != nil { return fmt.Errorf(`SendControlPacket failed to send control packet: %w`, err) } diff --git a/mod/udp/src/dial.go b/mod/udp/src/dial.go index 136383fa5..aa29d0da5 100644 --- a/mod/udp/src/dial.go +++ b/mod/udp/src/dial.go @@ -34,10 +34,12 @@ func (mod *Module) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.C } reliableConn, err := NewConn(udpConn, localEndpoint, remoteEndpoint, - mod.config.FlowControl) + mod.config.TransportConfig) if err != nil { return nil, err } + reliableConn.outbound = true // mark as outbound connection + return reliableConn, nil } diff --git a/mod/udp/src/receive.go b/mod/udp/src/receive.go index 34d234dab..97e4c7b26 100644 --- a/mod/udp/src/receive.go +++ b/mod/udp/src/receive.go @@ -28,10 +28,14 @@ func (c *Conn) recvLoop() { } continue } - // address filter - if addr.String() != c.remoteEndpoint.IP.String() { + // address filter (compare only IP; optional port check commented for NAT flexibility) + if !addr.IP.Equal(net.IP(c.remoteEndpoint.IP)) { + // TODO: debug log for unexpected source (guarded by verbosity flag) continue } + // If strict port matching is desired and stable, uncomment: + // if addr.Port != int(c.remoteEndpoint.Port) { continue } + if n < 13 { // minimum header size continue } diff --git a/mod/udp/src/send.go b/mod/udp/src/send.go index 867a1f2a8..350b66465 100644 --- a/mod/udp/src/send.go +++ b/mod/udp/src/send.go @@ -27,29 +27,12 @@ func (c *Conn) windowFull() bool { return len(c.unacked) >= c.cfg.MaxWindowPackets } -// nextSendOffset returns the offset in sendRB for the next fragment to send. -// If no packets are unacked, starts at sendBase. Otherwise, finds the highest offset of unacked packets. -func (c *Conn) nextSendOffset() int { - if len(c.unacked) == 0 { - return int(c.sendBase) - } - - maxOff := 0 - for _, u := range c.unacked { - if u.offset+u.length > maxOff { - maxOff = u.offset + u.length - } - } - return maxOff -} - -// fillFragBuf fills buf with ask bytes from sendRB at offset off. -// Returns an error if the buffer is too small or the read fails. -func (c *Conn) fillFragBuf(buf []byte, _ int, ask int) error { +// fillFragBuf reads exactly ask bytes from the head of sendRB into buf. +func (c *Conn) fillFragBuf(buf []byte, ask int) error { if ask > len(buf) { return fmt.Errorf("fragBuf too small: ask=%d, buf=%d", ask, len(buf)) } - readN, err := c.sendRB.Read(buf[:ask]) + readN, err := c.sendRB.Read(buf[:ask]) // destructive read (FIFO) if err != nil { return err } @@ -59,13 +42,21 @@ func (c *Conn) fillFragBuf(buf []byte, _ int, ask int) error { return nil } -// sendFragment fragments, marshals, and sends a packet from sendRB at offset off, up to ask bytes. -// Updates unacked and nextSeqNum. Returns true if a packet was sent, false otherwise. -func (c *Conn) sendFragment(off int, ask int) (bool, error) { +// sendFragment consumes up to ask bytes from the send ring buffer, builds a packet and sends it. +// Returns true if a packet was sent. +func (c *Conn) sendFragment(ask int) (bool, error) { c.sendMu.Lock() + if ask <= 0 || c.sendRB.Length() == 0 { + c.sendMu.Unlock() + return false, nil + } + if ask > int(c.sendRB.Length()) { // clamp to available + ask = int(c.sendRB.Length()) + } + fragBuf := make([]byte, ask) - if err := c.fillFragBuf(fragBuf, off, ask); err != nil { + if err := c.fillFragBuf(fragBuf, ask); err != nil { c.sendMu.Unlock() return false, err } @@ -88,7 +79,6 @@ func (c *Conn) sendFragment(off int, ask int) (bool, error) { pkt: pkt, sentTime: time.Now(), rtxCount: 0, - offset: off, length: pktLen, } c.nextSeqNum += uint32(pktLen) @@ -144,10 +134,8 @@ func (c *Conn) startRtxTimer() { } // senderLoop runs as a goroutine and is responsible for sending packets from the send buffer. -// It enforces flow control (packet window), fragments data, and sends packets. -// The loop blocks when the window is full or the buffer is empty, and wakes up on sendCond. -// Exits cleanly on connection close or state change, and notifies waiters. -// For PoC, does not implement advanced pacing, batching, or prioritization. +// FIFO model: bytes are consumed from sendRB as soon as they are packetized. Retransmissions +// use copies stored in unacked map. No random access over the ring is performed. func (c *Conn) senderLoop() { defer func() { c.sendCond.Broadcast() }() for { @@ -163,20 +151,13 @@ func (c *Conn) senderLoop() { return } } - off := c.nextSendOffset() - end := int(c.sendRB.Length()) - if off >= end { // nothing to send after all - c.sendMu.Unlock() - continue - } - rem := end - off - ask := rem + ask := int(c.sendRB.Length()) if ask > c.cfg.MaxSegmentSize { ask = c.cfg.MaxSegmentSize } c.sendMu.Unlock() - sent, err := c.sendFragment(off, ask) + sent, err := c.sendFragment(ask) if err != nil { fmt.Printf("sendFragment error: %v\n", err) continue diff --git a/mod/udp/src/server.go b/mod/udp/src/server.go index 9b798ac44..f9dcd92b7 100644 --- a/mod/udp/src/server.go +++ b/mod/udp/src/server.go @@ -102,13 +102,15 @@ func (s *Server) readLoop(ctx *astral.Context, localEndpoint *udp.Endpoint) { continue } - conn, err = NewConn(s.listener, localEndpoint, remoteEndpoint, s.Module.config.FlowControl) + conn, err = NewConn(s.listener, localEndpoint, remoteEndpoint, s.Module.config.TransportConfig) if err != nil { s.log.Errorv(1, "NewConn error for %v: %v", addr, err) s.mutex.Unlock() continue } + conn.outbound = false // mark as inbound connection + conn.inCh = make(chan *Packet, 128) s.conns[remoteKey] = conn go func() { From 8f4468aeae7542d1a63488e5ca5d3d242595353b Mon Sep 17 00:00:00 2001 From: Rekseto Date: Wed, 1 Oct 2025 14:05:15 +0200 Subject: [PATCH 06/13] mod/udp: major fixes and improvements around buffering and simplifications to handshake logic --- mod/udp/errors.go | 1 + mod/udp/rudp/config.go | 64 +++++++ mod/udp/rudp/config_test.go | 205 ++++++++++++++++++++++ mod/udp/{src => rudp}/conn.go | 207 +++++++++++++++++------ mod/udp/{src => rudp}/conn_handshake.go | 103 +++++++---- mod/udp/{src => rudp}/fragmenter.go | 2 +- mod/udp/{src => rudp}/fragmenter_test.go | 2 +- mod/udp/rudp/integration_test.go | 164 ++++++++++++++++++ mod/udp/rudp/listener.go | 153 +++++++++++++++++ mod/udp/{src => rudp}/packet.go | 2 +- mod/udp/{src => rudp}/packet_test.go | 2 +- mod/udp/{src => rudp}/receive.go | 20 ++- mod/udp/{src => rudp}/retransmissions.go | 9 +- mod/udp/{src => rudp}/send.go | 6 +- mod/udp/src/config.go | 169 +----------------- mod/udp/src/config_test.go | 78 --------- mod/udp/src/dial.go | 24 ++- mod/udp/src/server.go | 133 ++++----------- 18 files changed, 890 insertions(+), 454 deletions(-) create mode 100644 mod/udp/rudp/config.go create mode 100644 mod/udp/rudp/config_test.go rename mod/udp/{src => rudp}/conn.go (65%) rename mod/udp/{src => rudp}/conn_handshake.go (66%) rename mod/udp/{src => rudp}/fragmenter.go (99%) rename mod/udp/{src => rudp}/fragmenter_test.go (99%) create mode 100644 mod/udp/rudp/integration_test.go create mode 100644 mod/udp/rudp/listener.go rename mod/udp/{src => rudp}/packet.go (99%) rename mod/udp/{src => rudp}/packet_test.go (99%) rename mod/udp/{src => rudp}/receive.go (90%) rename mod/udp/{src => rudp}/retransmissions.go (95%) rename mod/udp/{src => rudp}/send.go (95%) delete mode 100644 mod/udp/src/config_test.go diff --git a/mod/udp/errors.go b/mod/udp/errors.go index 4b8924e12..e7f3dc17b 100644 --- a/mod/udp/errors.go +++ b/mod/udp/errors.go @@ -3,6 +3,7 @@ package udp import "errors" var ( + ErrListenerClosed = errors.New("listener closed") ErrRetransmissionLimitExceeded = errors.New( "retransmissions limit exceeded") ErrPacketTooShort = errors.New("packet too short") diff --git a/mod/udp/rudp/config.go b/mod/udp/rudp/config.go new file mode 100644 index 000000000..10151f978 --- /dev/null +++ b/mod/udp/rudp/config.go @@ -0,0 +1,64 @@ +package rudp + +import "time" + +// Transport default constants (exported for visibility in tests & integration) +const ( + DefaultMSS = 1200 - 13 // 1187 (1200 minus header) + DefaultWindowBytes = 16 * DefaultMSS + DefaultWndPkts = 32 + DefaultRTO = 500 * time.Millisecond + DefaultRTOMax = 4 * time.Second + DefaultRetries = 8 + DefaultAckDelay = 25 * time.Millisecond + DefaultRecvBufBytes = 1 << 20 + DefaultSendBufBytes = 1 << 20 +) + +// Config holds reliability / buffering parameters for the rudp transport. +type Config struct { + MaxSegmentSize int `yaml:"max_segment_size"` + MaxWindowBytes int `yaml:"max_window_bytes"` + MaxWindowPackets int `yaml:"max_window_packets"` + RetransmissionInterval time.Duration `yaml:"retransmission_interval"` + MaxRetransmissionInterval time.Duration `yaml:"max_retransmission_interval"` + RetransmissionLimit int `yaml:"retransmission_limit"` + AckDelay time.Duration `yaml:"ack_delay"` + RecvBufBytes int `yaml:"recv_buf_bytes"` + SendBufBytes int `yaml:"send_buf_bytes"` +} + +// Normalize applies defaults to zero-value fields (no clamping beyond basic sanity here). +func (c *Config) Normalize() { + if c.MaxSegmentSize == 0 { + c.MaxSegmentSize = DefaultMSS + } + if c.MaxWindowBytes == 0 { + c.MaxWindowBytes = DefaultWindowBytes + } + if c.MaxWindowPackets == 0 { + c.MaxWindowPackets = DefaultWndPkts + } + if c.RetransmissionInterval == 0 { + c.RetransmissionInterval = DefaultRTO + } + if c.MaxRetransmissionInterval == 0 { + c.MaxRetransmissionInterval = DefaultRTOMax + } + if c.RetransmissionLimit == 0 { + c.RetransmissionLimit = DefaultRetries + } + if c.AckDelay == 0 { + c.AckDelay = DefaultAckDelay + } + if c.RecvBufBytes == 0 { + c.RecvBufBytes = DefaultRecvBufBytes + } + if c.SendBufBytes == 0 { + c.SendBufBytes = DefaultSendBufBytes + } +} + +// - AckDelay: mirrors QUIC MAX_ACK_DELAY (RFC 9000 §13.2.1). +// - Buffer sizes: 1 MiB default, capped for safety, must be >= window. +// - All invariants enforced for safety and interoperability. diff --git a/mod/udp/rudp/config_test.go b/mod/udp/rudp/config_test.go new file mode 100644 index 000000000..17074f5e6 --- /dev/null +++ b/mod/udp/rudp/config_test.go @@ -0,0 +1,205 @@ +package rudp + +import ( + "testing" + "time" +) + +// TestNormalizeAppliesDefaults verifies that a zero-value Config gets all default values. +func TestNormalizeAppliesDefaults(t *testing.T) { + var c Config + c.Normalize() + + if c.MaxSegmentSize != DefaultMSS { + f(t, "MaxSegmentSize", c.MaxSegmentSize, DefaultMSS) + } + if c.MaxWindowBytes != DefaultWindowBytes { + f(t, "MaxWindowBytes", c.MaxWindowBytes, DefaultWindowBytes) + } + if c.MaxWindowPackets != DefaultWndPkts { + f(t, "MaxWindowPackets", c.MaxWindowPackets, DefaultWndPkts) + } + if c.RetransmissionInterval != DefaultRTO { + f(t, "RetransmissionInterval", c.RetransmissionInterval, DefaultRTO) + } + if c.MaxRetransmissionInterval != DefaultRTOMax { + f(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, DefaultRTOMax) + } + if c.RetransmissionLimit != DefaultRetries { + f(t, "RetransmissionLimit", c.RetransmissionLimit, DefaultRetries) + } + if c.AckDelay != DefaultAckDelay { + f(t, "AckDelay", c.AckDelay, DefaultAckDelay) + } + if c.RecvBufBytes != DefaultRecvBufBytes { + f(t, "RecvBufBytes", c.RecvBufBytes, DefaultRecvBufBytes) + } + if c.SendBufBytes != DefaultSendBufBytes { + f(t, "SendBufBytes", c.SendBufBytes, DefaultSendBufBytes) + } +} + +// TestNormalizePreservesNonZero verifies that non-zero fields are not overwritten. +func TestNormalizePreservesNonZero(t *testing.T) { + orig := Config{ + MaxSegmentSize: 999, + MaxWindowBytes: 123456, + MaxWindowPackets: 77, + RetransmissionInterval: 321 * time.Millisecond, + MaxRetransmissionInterval: 987 * time.Millisecond, + RetransmissionLimit: 42, + AckDelay: 11 * time.Millisecond, + RecvBufBytes: 222222, + SendBufBytes: 333333, + } + c := orig + c.Normalize() + + if c != orig { + // Compare field-by-field for clearer diagnostics. + if c.MaxSegmentSize != orig.MaxSegmentSize { + g(t, "MaxSegmentSize", c.MaxSegmentSize, orig.MaxSegmentSize) + } + if c.MaxWindowBytes != orig.MaxWindowBytes { + g(t, "MaxWindowBytes", c.MaxWindowBytes, orig.MaxWindowBytes) + } + if c.MaxWindowPackets != orig.MaxWindowPackets { + g(t, "MaxWindowPackets", c.MaxWindowPackets, orig.MaxWindowPackets) + } + if c.RetransmissionInterval != orig.RetransmissionInterval { + g(t, "RetransmissionInterval", c.RetransmissionInterval, orig.RetransmissionInterval) + } + if c.MaxRetransmissionInterval != orig.MaxRetransmissionInterval { + g(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, orig.MaxRetransmissionInterval) + } + if c.RetransmissionLimit != orig.RetransmissionLimit { + g(t, "RetransmissionLimit", c.RetransmissionLimit, orig.RetransmissionLimit) + } + if c.AckDelay != orig.AckDelay { + g(t, "AckDelay", c.AckDelay, orig.AckDelay) + } + if c.RecvBufBytes != orig.RecvBufBytes { + g(t, "RecvBufBytes", c.RecvBufBytes, orig.RecvBufBytes) + } + if c.SendBufBytes != orig.SendBufBytes { + g(t, "SendBufBytes", c.SendBufBytes, orig.SendBufBytes) + } + // Fail after reporting discrepancies. + if t.Failed() { + return + } + } +} + +// TestNormalizePartial ensures only zero fields get populated. +func TestNormalizePartial(t *testing.T) { + c := Config{ + MaxSegmentSize: 500, // keep + // others zero + AckDelay: 5 * time.Millisecond, // keep + } + c.Normalize() + + if c.MaxSegmentSize != 500 { + g(t, "MaxSegmentSize", c.MaxSegmentSize, 500) + } + if c.AckDelay != 5*time.Millisecond { + g(t, "AckDelay", c.AckDelay, 5*time.Millisecond) + } + + if c.MaxWindowBytes != DefaultWindowBytes { + f(t, "MaxWindowBytes", c.MaxWindowBytes, DefaultWindowBytes) + } + if c.MaxWindowPackets != DefaultWndPkts { + f(t, "MaxWindowPackets", c.MaxWindowPackets, DefaultWndPkts) + } + if c.RetransmissionInterval != DefaultRTO { + f(t, "RetransmissionInterval", c.RetransmissionInterval, DefaultRTO) + } + if c.MaxRetransmissionInterval != DefaultRTOMax { + f(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, DefaultRTOMax) + } + if c.RetransmissionLimit != DefaultRetries { + f(t, "RetransmissionLimit", c.RetransmissionLimit, DefaultRetries) + } + if c.RecvBufBytes != DefaultRecvBufBytes { + f(t, "RecvBufBytes", c.RecvBufBytes, DefaultRecvBufBytes) + } + if c.SendBufBytes != DefaultSendBufBytes { + f(t, "SendBufBytes", c.SendBufBytes, DefaultSendBufBytes) + } +} + +// TestNormalizeIdempotent ensures calling Normalize twice doesn't change values after first call. +func TestNormalizeIdempotent(t *testing.T) { + var c Config + c.Normalize() + first := c + c.Normalize() + if c != first { + g(t, "ConfigAfterSecondNormalize", c, first) + } +} + +// TestNormalizeNegativeValues ensures negative values are preserved (no implicit clamping yet). +func TestNormalizeNegativeValues(t *testing.T) { + c := Config{ + MaxSegmentSize: -1, + MaxWindowBytes: -2, + MaxWindowPackets: -3, + RetransmissionLimit: -4, + RecvBufBytes: -5, + SendBufBytes: -6, + } + // Durations negative as well + c.RetransmissionInterval = -10 * time.Millisecond + c.MaxRetransmissionInterval = -20 * time.Millisecond + c.AckDelay = -30 * time.Millisecond + + c.Normalize() + + if c.MaxSegmentSize != -1 { + g(t, "MaxSegmentSize", c.MaxSegmentSize, -1) + } + if c.MaxWindowBytes != -2 { + g(t, "MaxWindowBytes", c.MaxWindowBytes, -2) + } + if c.MaxWindowPackets != -3 { + g(t, "MaxWindowPackets", c.MaxWindowPackets, -3) + } + if c.RetransmissionLimit != -4 { + g(t, "RetransmissionLimit", c.RetransmissionLimit, -4) + } + if c.RecvBufBytes != -5 { + g(t, "RecvBufBytes", c.RecvBufBytes, -5) + } + if c.SendBufBytes != -6 { + g(t, "SendBufBytes", c.SendBufBytes, -6) + } + if c.RetransmissionInterval != -10*time.Millisecond { + g(t, "RetransmissionInterval", c.RetransmissionInterval, -10*time.Millisecond) + } + if c.MaxRetransmissionInterval != -20*time.Millisecond { + g(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, -20*time.Millisecond) + } + if c.AckDelay != -30*time.Millisecond { + g(t, "AckDelay", c.AckDelay, -30*time.Millisecond) + } +} + +// Helper failure formatters for brevity. +func f[T comparable](t *testing.T, field string, got, want T) { + if got != want { + // Using t.Fatalf to stop early in default-path tests where cascading errors add little value. + // Use %v for generic print. + //nolint:forbidigo // simple test diagnostic + t.Fatalf("%s mismatch: got=%v want=%v", field, got, want) + } +} + +func g[T comparable](t *testing.T, field string, got, want T) { + if got != want { + //nolint:forbidigo + t.Errorf("%s mismatch: got=%v want=%v", field, got, want) + } +} diff --git a/mod/udp/src/conn.go b/mod/udp/rudp/conn.go similarity index 65% rename from mod/udp/src/conn.go rename to mod/udp/rudp/conn.go index 5963856a4..7b179b822 100644 --- a/mod/udp/src/conn.go +++ b/mod/udp/rudp/conn.go @@ -1,33 +1,21 @@ // conn.go -package udp +package rudp import ( + "errors" + "fmt" "io" "net" "sync" "sync/atomic" "time" + "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/udp" "github.com/smallnest/ringbuffer" ) -// DatagramWriter is how Conn sends bytes to its peer. -type DatagramWriter interface { - WriteDatagram(b []byte) error -} - -// DatagramReceiver is how Conn *receives* parsed packets when it does not own a socket read loop. -// (For active conns, the recvLoop calls HandleDatagram itself.) -type DatagramReceiver interface { - HandleDatagram(raw []byte) // fast path: parse + process (ACK/data) -} - -type Handshaker interface { - Handshake() error -} - type Unacked struct { pkt *Packet // Packet metadata (seq, len) sentTime time.Time // Last sent time @@ -47,13 +35,15 @@ type Unacked struct { // - PoC limitations: no congestion control, no SACK, no adaptive pacing type Conn struct { // UDP socket and addressing - udpConn *net.UDPConn // Underlying UDP socket - localEndpoint *udp.Endpoint - remoteEndpoint *udp.Endpoint - outbound bool // true if we initiated the connection + udpConn *net.UDPConn // Underlying UDP socket + localEndpoint *udp.Endpoint + remoteEndpoint *udp.Endpoint + outbound bool // true if we initiated the connection + onEstablishedCb func(*Conn) + onClosedCb func(*Conn) // Configuration (reliability, flow control, etc.) - cfg ReliableTransportConfig // All protocol parameters + cfg Config // All protocol parameters // Connection state (atomic for lock-free reads) state uint32 // Current connection state (stores ConnState) @@ -97,6 +87,88 @@ type Conn struct { recvOO map[uint32]*Packet // stored packets with Seq > expected awaiting in-order delivery } +// unified handshake channel capacity (applies to inbound & outbound) +const handshakeQueueCap = 64 + +// NewConn constructs a connection around a UDP socket. +// Parameters: +// +// outbound - true for client initiated connection (owns socket & recv loop) +// firstPacket / ctx - required only for inbound (outbound=false) to drive server handshake +func NewConn(cn *net.UDPConn, l, r *udp.Endpoint, cfg Config, outbound bool, firstPacket *Packet, ctx *astral.Context) (*Conn, error) { + cfg.Normalize() + if cfg.MaxSegmentSize <= 0 { + return nil, udp.ErrZeroMSS + } + + if !outbound { + if firstPacket == nil || ctx == nil { + return nil, errors.New("inbound connection requires firstPacket and ctx") + } + } + + sendRBSize := cfg.MaxWindowBytes * 2 // allow for some retransmit slack + rb := ringbuffer.New(sendRBSize) + frag := NewBasicFragmenter(cfg.MaxSegmentSize) + + rc := &Conn{ + udpConn: cn, + localEndpoint: l, + remoteEndpoint: r, + cfg: cfg, + sendRB: rb, + frag: frag, + unacked: make(map[uint32]*Unacked), + ErrChan: make(chan error, 1), // Buffered to avoid blocking + inCh: make(chan *Packet, handshakeQueueCap), // handshake delivery channel + outbound: outbound, + } + rc.sendCond = sync.NewCond(&rc.sendMu) + rc.recvRB = ringbuffer.New(cfg.RecvBufBytes) + rc.recvCond = sync.NewCond(&rc.recvMu) + rc.recvOO = make(map[uint32]*Packet) + + if outbound { + // Do not start recvLoop here; StartClientHandshake will start it after handshake completes + } else { + go rc.recvLoop() + } + + if !outbound { + // Start server handshake asynchronously for inbound connections + go func() { + if err := rc.StartServerHandshake(ctx, firstPacket); err != nil { + rc.Close() + } + }() + } + + return rc, nil +} + +// OnEstablished registers a callback invoked exactly once when the connection transitions to Established. +func (c *Conn) OnEstablished(cb func(*Conn)) { + c.sendMu.Lock() + c.onEstablishedCb = cb + c.sendMu.Unlock() + // Fast path: if already established, invoke asynchronously + if c.inState(StateEstablished) && cb != nil { + go cb(c) + } +} + +// OnClosed registers a callback invoked exactly once after Close() releases resources. +// If the connection is already closed when registering, the callback is invoked asynchronously. +func (c *Conn) OnClosed(cb func(*Conn)) { + c.sendMu.Lock() + c.onClosedCb = cb + closed := c.isClosed() + c.sendMu.Unlock() + if closed && cb != nil { + go cb(c) + } +} + func (c *Conn) setState(state ConnState) { atomic.StoreUint32(&c.state, uint32(state)) } @@ -154,12 +226,15 @@ func (c *Conn) Close() error { c.sendCond.Broadcast() ch := c.inCh c.inCh = nil // detach channel to prevent further sends + closedCb := c.onClosedCb c.sendMu.Unlock() if ch != nil { close(ch) } - _ = c.udpConn.SetReadDeadline(time.Now()) + if c.outbound { // only needed to unblock recvLoop for outbound + _ = c.udpConn.SetReadDeadline(time.Now()) + } c.recvMu.Lock() if c.ackTimer != nil { c.ackTimer.Stop() @@ -169,41 +244,16 @@ func (c *Conn) Close() error { c.recvCond.Broadcast() } c.recvMu.Unlock() - return c.udpConn.Close() -} - -// NewConn constructs a connection around an already-connected UDP socket. -func NewConn(cn *net.UDPConn, l, r *udp.Endpoint, cfg ReliableTransportConfig) (*Conn, error) { - cfg.Normalize() - if cfg.MaxSegmentSize <= 0 { - return nil, udp.ErrZeroMSS + var err error + if c.outbound { + err = c.udpConn.Close() } - - sendRBSize := cfg.MaxWindowBytes * 2 // allow for some retransmit slack - rb := ringbuffer.New(sendRBSize) - frag := NewBasicFragmenter(cfg.MaxSegmentSize) - - rc := &Conn{ - udpConn: cn, - localEndpoint: l, - remoteEndpoint: r, - cfg: cfg, - sendRB: rb, - frag: frag, - unacked: make(map[uint32]*Unacked), - ErrChan: make(chan error, 1), // Buffered to avoid blocking - inCh: make(chan *Packet, 32), // handshake delivery channel + // Invoke close callback after resources released + if closedCb != nil { + closedCb(c) } - rc.sendCond = sync.NewCond(&rc.sendMu) - rc.recvRB = ringbuffer.New(cfg.RecvBufBytes) - rc.recvCond = sync.NewCond(&rc.recvMu) - rc.recvOO = make(map[uint32]*Packet) - - // start fused receive loop immediately so handshake packets can be processed - // NOTE: senderLoop is started only after handshake succeeds (see onEstablished()) - go rc.recvLoop() - return rc, nil + return err } func (c *Conn) onEstablished() { @@ -218,6 +268,11 @@ func (c *Conn) onEstablished() { c.setState(StateEstablished) // Future established-only initializations (keepalives, metrics, etc.) go here. go c.senderLoop() + // Invoke callback (outside locks) if set + cb := func() func(*Conn) { c.sendMu.Lock(); defer c.sendMu.Unlock(); return c.onEstablishedCb }() + if cb != nil { + cb(c) + } } // HandleAckPacket processes ACK packets @@ -281,3 +336,47 @@ func (c *Conn) RemoteEndpoint() exonet.Endpoint { } return c.remoteEndpoint } + +// ProcessPacket feeds a received packet into the connection (server-side demux path). +// If the connection is not yet established, the packet is queued for handshake processing. +func (c *Conn) ProcessPacket(pkt *Packet) { + if !c.inState(StateEstablished) { + c.sendMu.Lock() + ch := c.inCh + c.sendMu.Unlock() + if ch != nil { + select { + case ch <- pkt: + default: + } + } + return + } + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + return + } + if pkt.Flags&(FlagSYN|FlagFIN) != 0 { + c.HandleControlPacket(pkt) + return + } + c.handleDataPacket(pkt) +} + +// sendDatagram sends a raw packet buffer choosing the correct syscall based on +// connection role (outbound connections use Write on a connected socket; +// inbound connections use WriteToUDP specifying the remote endpoint). +func (c *Conn) sendDatagram(b []byte) (n int, err error) { + if c.udpConn == nil { + return n, udp.ErrConnClosed + } + if c.outbound { + _, err := c.udpConn.Write(b) + return n, err + } + if c.remoteEndpoint == nil { + return n, fmt.Errorf("remote endpoint nil") + } + written, err := c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()) + return written, err +} diff --git a/mod/udp/src/conn_handshake.go b/mod/udp/rudp/conn_handshake.go similarity index 66% rename from mod/udp/src/conn_handshake.go rename to mod/udp/rudp/conn_handshake.go index 5b94cbb9e..83311de78 100644 --- a/mod/udp/src/conn_handshake.go +++ b/mod/udp/rudp/conn_handshake.go @@ -1,9 +1,10 @@ -package udp +package rudp import ( "context" "crypto/rand" "fmt" + "net" "time" "github.com/cryptopunkscc/astrald/mod/udp" @@ -23,49 +24,70 @@ const ( ) func (c *Conn) StartClientHandshake(ctx context.Context) error { + return c.startClientHandshakeDirect(ctx) +} + +// startClientHandshakeDirect performs the 3-way handshake using direct socket reads. +func (c *Conn) startClientHandshakeDirect(ctx context.Context) error { seq, err := randUint32NZ() if err != nil { return fmt.Errorf("failed to generate initial sequence number: %w", err) } c.initialSeqNumLocal = seq - c.connID = c.initialSeqNumLocal + c.connID = seq c.setState(StateSynSent) - // build + send initial SYN and register in unacked for unified retransmission - if err := c.sendHandshakeControl(FlagSYN, c.initialSeqNumLocal, 0); err != nil { + if err := c.sendHandshakeControl(FlagSYN, seq, 0); err != nil { return err } + buf := make([]byte, 1500) + deadlineInterval := 300 * time.Millisecond + for { - select { - case <-ctx.Done(): + if ctx.Err() != nil { return udp.ErrHandshakeTimeout - case pkt := <-c.inCh: - if pkt.Flags&(FlagSYN|FlagACK) == (FlagSYN|FlagACK) && pkt.Ack == c.initialSeqNumLocal+1 && pkt.Seq != 0 { - // got valid SYN|ACK - c.initialSeqNumRemote = pkt.Seq - // remove our SYN from unacked and set sequence bases - c.sendMu.Lock() - delete(c.unacked, c.initialSeqNumLocal) - if len(c.unacked) == 0 && c.rtxTimer != nil { - c.rtxTimer.Stop() - c.rtxTimer = nil - } - c.ackedSeqNum = c.initialSeqNumLocal + 1 - c.sendBase = c.initialSeqNumLocal + 1 - c.nextSeqNum = c.initialSeqNumLocal + 1 - c.sendMu.Unlock() + } + _ = c.udpConn.SetReadDeadline(time.Now().Add(deadlineInterval)) + n, addr, err := c.udpConn.ReadFromUDP(buf) - // send final ACK (not tracked for retransmission) - if err := c.SendControlPacket(FlagACK, c.initialSeqNumLocal+1, c.initialSeqNumRemote+1); err != nil { - return err - } - // Transition to established (state change + sender loop) via helper - c.onEstablished() - // fused receive loop will now dispatch directly - return nil + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + continue + } + continue + } + // Filter by remote IP (ignore port mismatch in case of NAT rebinding; only IP match) + if c.remoteEndpoint != nil && !addr.IP.Equal(net.IP(c.remoteEndpoint.IP)) { + continue + } + if n < 13 { + continue + } + pkt := &Packet{} + if err := pkt.Unmarshal(buf[:n]); err != nil { + continue + } + + // Expect SYN|ACK + if pkt.Flags&(FlagSYN|FlagACK) == (FlagSYN|FlagACK) && pkt.Ack == seq+1 && pkt.Seq != 0 { + c.initialSeqNumRemote = pkt.Seq + // finalize local send base + c.sendMu.Lock() + c.ackedSeqNum = seq + 1 + c.sendBase = seq + 1 + c.nextSeqNum = seq + 1 + c.sendMu.Unlock() + // Send final ACK + if err := c.SendControlPacket(FlagACK, seq+1, c.initialSeqNumRemote+1); err != nil { + return err } + c.onEstablished() + // Start recvLoop after establishment for outbound + go c.recvLoop() + return nil } + // ignore other control/data until handshake completes } } @@ -80,7 +102,8 @@ func (c *Conn) StartServerHandshake(ctx context.Context, synPkt *Packet) error { c.setState(StateSynReceived) // send SYN|ACK and register for retransmission - if err := c.sendHandshakeControl(FlagSYN|FlagACK, c.initialSeqNumLocal, c.initialSeqNumRemote+1); err != nil { + err = c.sendHandshakeControl(FlagSYN|FlagACK, c.initialSeqNumLocal, c.initialSeqNumRemote+1) + if err != nil { return err } @@ -119,13 +142,17 @@ func (c *Conn) sendHandshakeControl(flags uint8, seq, ack uint32) error { if c.udpConn == nil { return udp.ErrConnClosed } - // Use WriteToUDP to support server-side unconnected sockets and client connected sockets uniformly. if c.remoteEndpoint == nil { return fmt.Errorf("remote endpoint nil") } - if _, err := c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()); err != nil { + + _, err = c.sendDatagram(b) + if err != nil { return err } + + // Register packet and decide if timer start is needed - do this OUTSIDE the lock + needTimer := false c.sendMu.Lock() if _, exists := c.unacked[seq]; !exists { // only register first time c.unacked[seq] = &Unacked{ @@ -136,10 +163,16 @@ func (c *Conn) sendHandshakeControl(flags uint8, seq, ack uint32) error { isHandshake: true, } if c.rtxTimer == nil { - c.startRtxTimer() + needTimer = true } } c.sendMu.Unlock() + + // Start timer AFTER releasing lock to avoid deadlock + if needTimer { + c.startRtxTimer() + } + return nil } @@ -156,11 +189,11 @@ func (c *Conn) SendControlPacket(flags uint8, seq, ack uint32) error { if c.remoteEndpoint == nil { return fmt.Errorf("remote endpoint nil") } - _, err = c.udpConn.WriteToUDP(data, c.remoteEndpoint.UDPAddr()) + _, err = c.sendDatagram(data) if err != nil { return fmt.Errorf(`SendControlPacket failed to send control packet: %w`, err) } - return err + return nil } func randUint32NZ() (uint32, error) { diff --git a/mod/udp/src/fragmenter.go b/mod/udp/rudp/fragmenter.go similarity index 99% rename from mod/udp/src/fragmenter.go rename to mod/udp/rudp/fragmenter.go index 3c15018cc..0cc8d427b 100644 --- a/mod/udp/src/fragmenter.go +++ b/mod/udp/rudp/fragmenter.go @@ -1,4 +1,4 @@ -package udp +package rudp // Fragmenter turns buffered bytes into wire packets and reproduces the exact // same boundaries for retransmission. diff --git a/mod/udp/src/fragmenter_test.go b/mod/udp/rudp/fragmenter_test.go similarity index 99% rename from mod/udp/src/fragmenter_test.go rename to mod/udp/rudp/fragmenter_test.go index 1d7a8bf1d..366acccdd 100644 --- a/mod/udp/src/fragmenter_test.go +++ b/mod/udp/rudp/fragmenter_test.go @@ -1,4 +1,4 @@ -package udp +package rudp import ( "testing" diff --git a/mod/udp/rudp/integration_test.go b/mod/udp/rudp/integration_test.go new file mode 100644 index 000000000..bc9555cfc --- /dev/null +++ b/mod/udp/rudp/integration_test.go @@ -0,0 +1,164 @@ +package rudp + +import ( + "context" + "net" + "testing" + "time" + + "github.com/cryptopunkscc/astrald/astral" + udpmod "github.com/cryptopunkscc/astrald/mod/udp" +) + +// TestListenerDialHelloWorld exercises a minimal end-to-end handshake and +// one-shot data transfer ("Hello World") between an outbound client Conn +// and an inbound server Conn accepted through Listener.Accept(). +func TestListenerDialHelloWorld(t *testing.T) { + baseCtx := astral.NewContext(context.Background()) + + // Start listener on an IPv4 loopback ephemeral port (force IPv4 to avoid ::/127.0.0.1 mismatch) + l, err := Listen(baseCtx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, Config{}, 2*time.Second) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer l.Close() + + serverAddr := l.Addr().(*net.UDPAddr) + // Force IPv4 127.0.0.1 target (avoid ::1 / unspecified ambiguity) + ipv4Dest := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverAddr.Port} + + // Channel to receive accepted server connection + acceptedCh := make(chan *Conn, 1) + + // Accept in background + go func() { + acceptCtx, cancel := baseCtx.WithTimeout(3 * time.Second) + defer cancel() + c, err := l.Accept(acceptCtx) + if err != nil { + return + } + acceptedCh <- c + }() + + // Dial UDP (raw) for outbound side + udpConn, err := net.DialUDP("udp4", nil, ipv4Dest) + if err != nil { + // fallback try udp if udp4 failed + udpConn, err = net.DialUDP("udp", nil, ipv4Dest) + } + if err != nil { + l.Close() + t.Fatalf("DialUDP failed: %v", err) + } + defer udpConn.Close() + + // Build endpoints + localEP, _ := udpmod.ParseEndpoint(udpConn.LocalAddr().String()) + remoteEP, _ := udpmod.ParseEndpoint(udpConn.RemoteAddr().String()) + + // Create outbound reliable Conn + outConn, err := NewConn(udpConn, localEP, remoteEP, Config{}, true, nil, baseCtx) + if err != nil { + t.Fatalf("NewConn outbound failed: %v", err) + } + t.Logf("client local=%v remote=%v", udpConn.LocalAddr(), udpConn.RemoteAddr()) + + // Run client handshake + hCtx, hCancel := baseCtx.WithTimeout(2 * time.Second) + defer hCancel() + if err := outConn.StartClientHandshake(hCtx); err != nil { + outConn.Close() + l.Close() + t.Fatalf("client handshake failed: %v", err) + } + t.Logf("client handshake complete") + + // Wait for server side acceptance + var serverConn *Conn + select { + case serverConn = <-acceptedCh: + if serverConn != nil { + t.Logf("server accepted remote=%v", serverConn.RemoteEndpoint()) + } + case <-time.After(3 * time.Second): + // handshake should have completed well before this + outConn.Close() + l.Close() + t.Fatalf("timeout waiting for Accept()") + } + if serverConn == nil { + outConn.Close() + l.Close() + t.Fatalf("nil serverConn returned") + } + defer serverConn.Close() + + // Send payload client->server + msg := []byte("Hello World") + if _, err := outConn.Write(msg); err != nil { + // ensure cleanup before failing + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("client write failed: %v", err) + } + t.Logf("client wrote payload len=%d", len(msg)) + + // Read at server side (no direct read deadline API; use goroutine + timeout) + readCh := make(chan struct{}) + var got []byte + var readErr error + go func() { + b := make([]byte, 64) + if n, err := serverConn.Read(b); err != nil { + readErr = err + } else { + got = append(got, b[:n]...) + } + close(readCh) + }() + select { + case <-readCh: + case <-time.After(10 * time.Second): + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("timeout waiting for server read") + } + if readErr != nil { + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("server read error: %v", readErr) + } + if string(got) != string(msg) { + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("unexpected server payload: got %q want %q", string(got), string(msg)) + } + t.Logf("server read payload: %s", string(got)) + + // (Optional) Echo back from server to client to validate reverse path + if _, err := serverConn.Write([]byte("ACK")); err == nil { + // attempt client read with timeout + clientReadCh := make(chan []byte, 1) + go func() { + b := make([]byte, 8) + if n, e := outConn.Read(b); e == nil { + clientReadCh <- b[:n] + } + close(clientReadCh) + }() + select { + case resp := <-clientReadCh: + if len(resp) > 0 && string(resp) != "ACK" { + // Non-fatal; log mismatch + // t.Logf("unexpected client echo: %q", string(resp)) + } + case <-time.After(2 * time.Second): + // ignore echo timeout (non-fatal for core test) + } + } +} diff --git a/mod/udp/rudp/listener.go b/mod/udp/rudp/listener.go new file mode 100644 index 000000000..8623e659e --- /dev/null +++ b/mod/udp/rudp/listener.go @@ -0,0 +1,153 @@ +package rudp + +import ( + "net" + "sync" + "sync/atomic" + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +const defaultAcceptBacklog = 32 +const defaultHandshakeTimeout = 5 * time.Second + +// Listener accepts inbound RUDP connections (server side) and returns only +// fully established connections via Accept(). +type Listener struct { + udpConn *net.UDPConn + cfg Config + baseCtx *astral.Context + mu sync.Mutex + conns map[string]*Conn // remoteKey -> Conn (includes handshaking ones) + acceptCh chan *Conn + closed atomic.Bool +} + +// Listen creates a new Listener bound to addr. handshakeTimeout==0 falls back to a small default. +func Listen(ctx *astral.Context, addr *net.UDPAddr, cfg Config, handshakeTimeout time.Duration) (*Listener, error) { + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + cfg.Normalize() + if handshakeTimeout <= 0 { + handshakeTimeout = defaultHandshakeTimeout + } + l := &Listener{ + udpConn: conn, + cfg: cfg, + baseCtx: ctx, + conns: make(map[string]*Conn), + acceptCh: make(chan *Conn, defaultAcceptBacklog), + } + go l.readLoop(handshakeTimeout) + return l, nil +} + +// Addr returns the underlying listening address. +func (l *Listener) Addr() net.Addr { return l.udpConn.LocalAddr() } + +// Accept blocks until an established connection is available, the context is canceled, or the listener is closed. +func (l *Listener) Accept(ctx *astral.Context) (*Conn, error) { + for { + if l.closed.Load() { + return nil, udp.ErrListenerClosed + } + select { + case c, ok := <-l.acceptCh: + if !ok { + return nil, udp.ErrListenerClosed + } + return c, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// Close shuts down the listener and all active connections. +func (l *Listener) Close() error { + if !l.closed.CompareAndSwap(false, true) { + return udp.ErrListenerClosed + } + // Closing the UDP socket unblocks readLoop. + _ = l.udpConn.Close() + l.mu.Lock() + for _, c := range l.conns { + c.Close() + } + l.conns = nil + l.mu.Unlock() + close(l.acceptCh) + return nil +} + +// readLoop performs demultiplexing and inbound connection setup. +func (l *Listener) readLoop(handshakeTimeout time.Duration) { + buf := make([]byte, 64*1024) + for { + if l.closed.Load() { + return + } + n, addr, err := l.udpConn.ReadFromUDP(buf) + if err != nil { + if l.closed.Load() { + return + } + continue + } + if n < 13 { // minimal header length + continue + } + pkt := &Packet{} + if err := pkt.Unmarshal(buf[:n]); err != nil { + continue + } + + remoteKey := addr.String() + l.mu.Lock() + conn := l.conns[remoteKey] + if conn == nil && pkt.Flags&FlagSYN != 0 { // new inbound attempt + remoteEP, perr := udp.ParseEndpoint(addr.String()) + if perr != nil { + l.mu.Unlock() + continue + } + localEP, _ := udp.ParseEndpoint(l.udpConn.LocalAddr().String()) + // Per-handshake timeout context derived from baseCtx + hCtx, cancel := l.baseCtx.WithTimeout(handshakeTimeout) + c, cerr := NewConn(l.udpConn, localEP, remoteEP, l.cfg, false, pkt, hCtx) + if cerr != nil { // immediate constructor error; cancel context + cancel() + l.mu.Unlock() + continue + } + c.OnEstablished(func(ec *Conn) { + if l.closed.Load() { + cancel() + return + } + select { + case l.acceptCh <- ec: + default: + } + cancel() + }) + c.OnClosed(func(ec *Conn) { + l.mu.Lock() + delete(l.conns, remoteKey) + l.mu.Unlock() + cancel() + }) + l.conns[remoteKey] = c + conn = c + } + l.mu.Unlock() + + if conn != nil { + conn.ProcessPacket(pkt) + } + } +} diff --git a/mod/udp/src/packet.go b/mod/udp/rudp/packet.go similarity index 99% rename from mod/udp/src/packet.go rename to mod/udp/rudp/packet.go index cb7d499b5..9d0de18bf 100644 --- a/mod/udp/src/packet.go +++ b/mod/udp/rudp/packet.go @@ -1,4 +1,4 @@ -package udp +package rudp import ( "bytes" diff --git a/mod/udp/src/packet_test.go b/mod/udp/rudp/packet_test.go similarity index 99% rename from mod/udp/src/packet_test.go rename to mod/udp/rudp/packet_test.go index 167337082..f4e2760e5 100644 --- a/mod/udp/src/packet_test.go +++ b/mod/udp/rudp/packet_test.go @@ -1,4 +1,4 @@ -package udp +package rudp import ( "bytes" diff --git a/mod/udp/src/receive.go b/mod/udp/rudp/receive.go similarity index 90% rename from mod/udp/src/receive.go rename to mod/udp/rudp/receive.go index 97e4c7b26..0663914c7 100644 --- a/mod/udp/src/receive.go +++ b/mod/udp/rudp/receive.go @@ -1,4 +1,4 @@ -package udp +package rudp import ( "net" @@ -30,11 +30,8 @@ func (c *Conn) recvLoop() { } // address filter (compare only IP; optional port check commented for NAT flexibility) if !addr.IP.Equal(net.IP(c.remoteEndpoint.IP)) { - // TODO: debug log for unexpected source (guarded by verbosity flag) continue } - // If strict port matching is desired and stable, uncomment: - // if addr.Port != int(c.remoteEndpoint.Port) { continue } if n < 13 { // minimum header size continue @@ -59,15 +56,28 @@ func (c *Conn) recvLoop() { } // Established: direct dispatch + // Handle data packets first (they may have ACK piggybacked) + if pkt.Len > 0 { + // Process piggyback ACK if present + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + } + // Process data payload + c.handleDataPacket(pkt) + continue + } + + // Pure ACK packets (no data payload) if pkt.Flags&FlagACK != 0 { c.HandleAckPacket(pkt) continue } + + // Pure control packets (SYN, FIN, etc.) if pkt.Flags&(FlagSYN|FlagFIN) != 0 { c.HandleControlPacket(pkt) continue } - c.handleDataPacket(pkt) } } diff --git a/mod/udp/src/retransmissions.go b/mod/udp/rudp/retransmissions.go similarity index 95% rename from mod/udp/src/retransmissions.go rename to mod/udp/rudp/retransmissions.go index 850d5c317..44701a196 100644 --- a/mod/udp/src/retransmissions.go +++ b/mod/udp/rudp/retransmissions.go @@ -1,4 +1,4 @@ -package udp +package rudp import ( "sort" @@ -80,8 +80,8 @@ func (c *Conn) sendPureACK() { if err != nil { return } - // best-effort send (no tracking) - _, _ = c.udpConn.Write(b) + // best-effort send via unified path + _, _ = c.sendDatagram(b) } // handleRetransmissionTimeoutLocked assumes sendMu is held and performs retransmissions. @@ -103,11 +103,12 @@ func (c *Conn) handleRetransmissionTimeoutLocked() (limitExceeded bool) { limitExceeded = true break } + // Update ACK field to latest cumulative ACK and retransmit u.pkt.Ack = c.ackedSeqNum b, err := u.pkt.Marshal() if err == nil { - _, _ = c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()) + _, _ = c.sendDatagram(b) } u.rtxCount++ u.sentTime = time.Now() diff --git a/mod/udp/src/send.go b/mod/udp/rudp/send.go similarity index 95% rename from mod/udp/src/send.go rename to mod/udp/rudp/send.go index 350b66465..80767ae5f 100644 --- a/mod/udp/src/send.go +++ b/mod/udp/rudp/send.go @@ -1,4 +1,4 @@ -package udp +package rudp import ( "fmt" @@ -91,7 +91,7 @@ func (c *Conn) sendFragment(ask int) (bool, error) { } // Network I/O outside lock - _, err = c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()) + _, err = c.sendDatagram(b) if err != nil { c.sendMu.Lock() if u, ok2 := c.unacked[seq]; ok2 && u.length == pktLen { @@ -159,11 +159,9 @@ func (c *Conn) senderLoop() { sent, err := c.sendFragment(ask) if err != nil { - fmt.Printf("sendFragment error: %v\n", err) continue } if !sent { - fmt.Println("sendFragment did not send packet (sent=false)") continue } } diff --git a/mod/udp/src/config.go b/mod/udp/src/config.go index fad02def8..b4bd5bbea 100644 --- a/mod/udp/src/config.go +++ b/mod/udp/src/config.go @@ -2,179 +2,18 @@ package udp import ( "time" -) - -// RFC-backed constants for src UDP config -const ( - ListenPort = 1791 - // QUIC requires endpoints to handle 1200-byte UDP datagrams without fragmentation (RFC 9000 §14.1) - - DefaultMSS = 1200 - 13 // 1187: 1200 minus our header - MinMSS = 512 // RFC 8085: avoid fragmentation, safe for most links - MaxMSS = 1400 // Keeps under 1500B MTU with IP/UDP/tunnel headroom (RFC 8085, RFC 791, RFC 8200) - - // WindowBytes conservative buffer, RFC 8085 - - DefaultWindowBytes = 16 * DefaultMSS - MinWindowBytes = MinMSS - MaxWindowBytes = 1 << 20 // 1 MiB - - // Retransmission timers: RFC 6298 - - DefaultRTO = 500 * time.Millisecond - DefaultRTOMax = 4 * time.Second - MinRTO = 10 * time.Millisecond // LAN-friendly floor - MaxRTOCeiling = 60 * time.Second // Avoid excessive backoff - - // Retries fail fast on persistent loss - DefaultRetries = 8 - MinRetries = 1 - MaxRetries = 20 - - // AckDelay QUIC MAX_ACK_DELAY (RFC 9000 §13.2.1) - - DefaultAckDelay = 25 * time.Millisecond - MinAckDelay = 0 - - // Buffer sizes - - DefaultRecvBufBytes = 1 << 20 // 1 MiB - MinRecvBufBytes = DefaultWindowBytes // Should be at least as large as window - MaxRecvBufBytes = 8 << 20 // 8 MiB - DefaultSendBufBytes = 1 << 20 - MinSendBufBytes = DefaultWindowBytes - MaxSendBufBytes = 8 << 20 + "github.com/cryptopunkscc/astrald/mod/udp/rudp" ) -const ( - DefaultWndPkts = 32 - MinWndPkts = 1 - MaxWndPkts = 256 -) - -// Config holds general settings for the UDP module. type Config struct { ListenPort int `yaml:"listen_port,omitempty"` // Port to listen on for incoming connections (default 1791) PublicEndpoints []string `yaml:"public_endpoints,omitempty"` - DialTimeout time.Duration `yaml:"dial_timeout,omitempty"` // Timeout for dialing connections (default 1 minute) - - TransportConfig ReliableTransportConfig `yaml:"transport_config,omitempty"` // Flow control settings for UDP connections -} - -// ReliableTransportConfig holds configuration for individual UDP connections. -type ReliableTransportConfig struct { - MaxSegmentSize int // Maximum Segment Size (default 1187) - MaxWindowBytes int // Send window size in bytes (default 16 * MaxSegmentSize) - MaxWindowPackets int // Max in-flight packets (packet-count window, default 32) - RetransmissionInterval time.Duration // Initial retransmission timeout (default 500ms) - MaxRetransmissionInterval time.Duration // Maximum retransmission timeout (default 4s) - RetransmissionLimit int // Maximum retransmission attempts (default 8) - IdleTimeout time.Duration // Connection idle timeout (default 60s) - AckDelay time.Duration // Delayed ACK timer (default 25ms) - RecvBufBytes int // Receive buffer size (default 1MB) - SendBufBytes int // Send buffer size (default 1MB) -} - -// Normalize sets sensible defaults for zero-values, clamps to safe ranges, and enforces invariants. -// See RFC 9000, RFC 8085, RFC 6298 for rationale. -func (c *ReliableTransportConfig) Normalize() { - c.SetDefaults() - c.clampValues() -} - -// SetDefaults initializes zero-values with sensible defaults. -func (c *ReliableTransportConfig) SetDefaults() { - if c.MaxSegmentSize == 0 { - c.MaxSegmentSize = DefaultMSS - } - if c.MaxWindowBytes == 0 { - c.MaxWindowBytes = DefaultWindowBytes - } - if c.MaxWindowPackets == 0 { - c.MaxWindowPackets = DefaultWndPkts - } - if c.RetransmissionInterval == 0 { - c.RetransmissionInterval = DefaultRTO - } - if c.MaxRetransmissionInterval == 0 { - c.MaxRetransmissionInterval = DefaultRTOMax - } - if c.RetransmissionLimit == 0 { - c.RetransmissionLimit = DefaultRetries - } - if c.AckDelay == 0 { - c.AckDelay = DefaultAckDelay - } - if c.RecvBufBytes == 0 { - c.RecvBufBytes = DefaultRecvBufBytes - } - if c.SendBufBytes == 0 { - c.SendBufBytes = DefaultSendBufBytes - } -} - -// NOTE: normally i would not introduce such function but when it comes to -// parameters of network protocols, -// i believe it is better to keep things within certain range of values ( -// all of which are stated at the top of this file) - -// clampValues ensures all fields are within safe ranges and enforces invariants. -func (c *ReliableTransportConfig) clampValues() { - c.MaxSegmentSize = clampInt(c.MaxSegmentSize, MinMSS, MaxMSS) - c.MaxWindowBytes = clampInt(c.MaxWindowBytes, c.MaxSegmentSize, MaxWindowBytes) - c.MaxWindowPackets = clampInt(c.MaxWindowPackets, MinWndPkts, MaxWndPkts) - c.RetransmissionInterval = clampDur(c.RetransmissionInterval, MinRTO, MaxRTOCeiling) - c.MaxRetransmissionInterval = clampDur(c.MaxRetransmissionInterval, c.RetransmissionInterval, MaxRTOCeiling) - c.RetransmissionLimit = clampInt(c.RetransmissionLimit, MinRetries, MaxRetries) - c.AckDelay = clampDur(c.AckDelay, MinAckDelay, c.RetransmissionInterval/2) - c.RecvBufBytes = clampInt(c.RecvBufBytes, MinRecvBufBytes, MaxRecvBufBytes) - c.SendBufBytes = clampInt(c.SendBufBytes, MinSendBufBytes, MaxSendBufBytes) -} - -// clampInt clamps an integer value to a specified range. -func clampInt(v, lo, hi int) int { - if v < lo { - return lo - } - if v > hi { - return hi - } - return v -} - -// clampDur clamps a time.Duration value to a specified range. -func clampDur(v, lo, hi time.Duration) time.Duration { - if v < lo { - return lo - } - if v > hi { - return hi - } - return v + DialTimeout time.Duration `yaml:"dial_timeout,omitempty"` // Timeout for dialing connections (default 1 minute) + TransportConfig rudp.Config `yaml:"transport_config,omitempty"` // Flow control settings for UDP connections } var defaultConfig = Config{ - ListenPort: ListenPort, DialTimeout: time.Minute, - TransportConfig: ReliableTransportConfig{ - MaxSegmentSize: DefaultMSS, - MaxWindowBytes: DefaultWindowBytes, - RetransmissionInterval: DefaultRTO, - MaxRetransmissionInterval: DefaultRTOMax, - RetransmissionLimit: DefaultRetries, - IdleTimeout: 60 * time.Second, // Default idle timeout of 1 minute - AckDelay: DefaultAckDelay, - RecvBufBytes: DefaultRecvBufBytes, - SendBufBytes: DefaultSendBufBytes, - }, + ListenPort: 1791, } - -// RFC rationale summary: -// -// - MSS: QUIC requires 1200B UDP datagrams (RFC 9000 §14.1), clamped to avoid fragmentation (RFC 8085). -// - WindowBytes: conservative buffer, RFC 8085, must be >= MSS. -// - RTO/RTOMax: TCP discipline (RFC 6298), pragmatic for UDP, exponential backoff. -// - AckDelay: mirrors QUIC MAX_ACK_DELAY (RFC 9000 §13.2.1). -// - Buffer sizes: 1 MiB default, capped for safety, must be >= window. -// - All invariants enforced for safety and interoperability. diff --git a/mod/udp/src/config_test.go b/mod/udp/src/config_test.go deleted file mode 100644 index 6bd16ac4f..000000000 --- a/mod/udp/src/config_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package udp - -import ( - "testing" - "time" -) - -func TestFlowControlConfigDefaults(t *testing.T) { - def := defaultConfig.TransportConfig - if def.MSS != DefaultMSS || def.WindowBytes != DefaultWindowBytes || def.RTO != DefaultRTO || def.RTOMax != DefaultRTOMax || def.RetryLimit != DefaultRetries || def.AckDelay != DefaultAckDelay || def.RecvBufBytes != DefaultRecvBufBytes || def.SendBufBytes != DefaultSendBufBytes { - t.Errorf("defaultConfig.TransportConfig does not match expected defaults: %+v", def) - } -} - -func TestFlowControlConfigClamp(t *testing.T) { - tests := []struct { - name string - input ReliableTransportConfig - expected ReliableTransportConfig - }{ - { - name: "Values below range are clamped", - input: ReliableTransportConfig{ - MSS: 100, - WindowBytes: 100, - RTO: 5 * time.Millisecond, - RTOMax: 5 * time.Millisecond, - RetryLimit: 0, - AckDelay: -1 * time.Millisecond, - RecvBufBytes: 100, - SendBufBytes: 100, - }, - expected: ReliableTransportConfig{ - MSS: MinMSS, - WindowBytes: MinMSS, - RTO: MinRTO, - RTOMax: MinRTO, - RetryLimit: MinRetries, - AckDelay: MinAckDelay, - RecvBufBytes: MinRecvBufBytes, - SendBufBytes: MinSendBufBytes, - }, - }, - { - name: "Values above range are clamped", - input: ReliableTransportConfig{ - MSS: 2000, - WindowBytes: 2 << 20, - RTO: 70 * time.Second, - RTOMax: 70 * time.Second, - RetryLimit: 50, - AckDelay: 1 * time.Second, - RecvBufBytes: 16 << 20, - SendBufBytes: 16 << 20, - }, - expected: ReliableTransportConfig{ - MSS: MaxMSS, - WindowBytes: MaxWindowBytes, - RTO: MaxRTOCeiling, - RTOMax: MaxRTOCeiling, - RetryLimit: MaxRetries, - AckDelay: 1 * time.Second, // Correct expected value - RecvBufBytes: MaxRecvBufBytes, - SendBufBytes: MaxSendBufBytes, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - input := test.input - input.clampValues() - if input != test.expected { - t.Errorf("clampValues() failed.\nGot: %+v\nExpected: %+v", input, test.expected) - } - }) - } -} diff --git a/mod/udp/src/dial.go b/mod/udp/src/dial.go index aa29d0da5..15fc615f3 100644 --- a/mod/udp/src/dial.go +++ b/mod/udp/src/dial.go @@ -6,20 +6,28 @@ import ( "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/mod/exonet" "github.com/cryptopunkscc/astrald/mod/udp" + "github.com/cryptopunkscc/astrald/mod/udp/rudp" ) var _ exonet.Dialer = &Module{} +// Dial establishes a reliable (rudp) connection and returns it only after the +// RUDP client handshake succeeds. Timeout / cancellation behavior: +// - If the caller's context has a deadline, it governs both dial and handshake. +// - Otherwise net.Dialer.Timeout (DialTimeout in config, if >0) limits only the dial phase. +// - The handshake then reuses the original context (may block if no deadline provided). func (mod *Module) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.Conn, error) { switch endpoint.Network() { case "udp": - // Supported network default: return nil, exonet.ErrUnsupportedNetwork } - // Use net.Dialer for dialing UDP connections - dialer := net.Dialer{Timeout: mod.config.DialTimeout} + dialer := net.Dialer{} + if mod.config.DialTimeout > 0 { + dialer.Timeout = mod.config.DialTimeout + } + conn, err := dialer.DialContext(ctx, "udp", endpoint.Address()) if err != nil { return nil, err @@ -33,13 +41,17 @@ func (mod *Module) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.C return nil, exonet.ErrUnsupportedNetwork } - reliableConn, err := NewConn(udpConn, localEndpoint, remoteEndpoint, - mod.config.TransportConfig) + reliableConn, err := rudp.NewConn(udpConn, localEndpoint, remoteEndpoint, + mod.config.TransportConfig, true, nil, ctx) if err != nil { return nil, err } - reliableConn.outbound = true // mark as outbound connection + err = reliableConn.StartClientHandshake(ctx) + if err != nil { + reliableConn.Close() + return nil, err + } return reliableConn, nil } diff --git a/mod/udp/src/server.go b/mod/udp/src/server.go index f9dcd92b7..3d1055774 100644 --- a/mod/udp/src/server.go +++ b/mod/udp/src/server.go @@ -2,133 +2,68 @@ package udp import ( "net" - "sync" + "time" "github.com/cryptopunkscc/astrald/astral" - "github.com/cryptopunkscc/astrald/mod/udp" + "github.com/cryptopunkscc/astrald/mod/udp/rudp" ) -// Server implements src UDP listener with connection demultiplexing +// Server implements UDP listening with connection acceptance via rudp.Listener type Server struct { *Module - listener *net.UDPConn - conns map[string]*Conn // Remote address → connection map - mutex sync.Mutex // Protects access to conns - acceptCh chan *Conn // Channel for accepted connections - stopCh chan struct{} // Channel to signal server shutdown - wg sync.WaitGroup // WaitGroup for managing goroutines + rListener *rudp.Listener + acceptCh chan *rudp.Conn } // NewServer creates a new src UDP server func NewServer(module *Module) *Server { return &Server{ Module: module, - conns: make(map[string]*Conn), - acceptCh: make(chan *Conn, 16), - stopCh: make(chan struct{}), + acceptCh: make(chan *rudp.Conn, 16), } } // Run starts the server and listens for incoming connections func (s *Server) Run(ctx *astral.Context) error { - listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: s.config.ListenPort}) - if err != nil { - s.log.Errorv(0, "failed to start server: %v", err) - return err + addr := &net.UDPAddr{Port: s.config.ListenPort} + hto := s.config.DialTimeout + if hto <= 0 { + hto = 5 * time.Second } - s.listener = listener - s.log.Info("started server at %v", listener.LocalAddr()) - defer s.log.Info("stopped server at %v", listener.LocalAddr()) - - localEndpoint, err := udp.ParseEndpoint(listener.LocalAddr(). - String()) + rListener, err := rudp.Listen(ctx, addr, s.Module.config.TransportConfig, hto) if err != nil { - s.log.Errorv(1, "error parsing local endpoint: %v", err) + s.log.Errorv(0, "failed to start rudp listener: %v", err) return err } - - s.wg.Add(1) - go s.readLoop(ctx, localEndpoint) + s.rListener = rListener + s.log.Info("started server at %v", rListener.Addr()) + defer s.log.Info("stopped server at %v", rListener.Addr()) + + // Accept loop + go func() { + acceptCtx := ctx + for { + conn, err := rListener.Accept(acceptCtx) + if err != nil { + return + } + select { + case s.acceptCh <- conn: + default: + // drop if application not consuming fast enough + } + } + }() <-ctx.Done() - s.Close() - return nil + return s.Close() } // Close gracefully shuts down the server func (s *Server) Close() error { - close(s.stopCh) - s.mutex.Lock() - for _, conn := range s.conns { - conn.Close() + if s.rListener != nil { + _ = s.rListener.Close() } - s.mutex.Unlock() - - s.listener.Close() - s.wg.Wait() return nil } - -func (s *Server) readLoop(ctx *astral.Context, localEndpoint *udp.Endpoint) { - defer s.wg.Done() - - buf := make([]byte, 64*1024) // TODO: Max packet size? - for { - n, addr, err := s.listener.ReadFromUDP(buf) - if err != nil { - select { - case <-s.stopCh: - return // Graceful shutdown - default: - s.log.Errorv(1, "read error: %v", err) - continue - } - } - - pkt := &Packet{} - if err := pkt.Unmarshal(buf[:n]); err != nil { - s.log.Errorv(1, "packet unmarshal error from %v: %v", addr, err) - continue // drop malformed - } - - remoteKey := addr.String() - s.mutex.Lock() - conn, foundConn := s.conns[remoteKey] - if !foundConn && pkt.Flags&FlagSYN != 0 { - remoteEndpoint, err := udp.ParseEndpoint(addr.String()) - if err != nil { - s.log.Errorv(1, "ParseEndpoint error for %v: %v", addr, err) - continue - } - - conn, err = NewConn(s.listener, localEndpoint, remoteEndpoint, s.Module.config.TransportConfig) - if err != nil { - s.log.Errorv(1, "NewConn error for %v: %v", addr, err) - s.mutex.Unlock() - continue - } - - conn.outbound = false // mark as inbound connection - - conn.inCh = make(chan *Packet, 128) - s.conns[remoteKey] = conn - go func() { - err := conn.StartServerHandshake(ctx, pkt) - if err != nil { - s.log.Errorv(1, "handshake error for %v: %v", addr, err) - } - }() - } - s.mutex.Unlock() - - if conn != nil { - select { - case conn.inCh <- pkt: - // success - default: - s.log.Errorv(1, "inCh full for %v, dropping packet", addr) - } - } - } -} From 46f9dcdb3d91ec124d26756ae68a21537954aa31 Mon Sep 17 00:00:00 2001 From: Rekseto Date: Wed, 1 Oct 2025 16:30:43 +0200 Subject: [PATCH 07/13] mod/udp: streamline packet sending and error handling in reliable UDP --- mod/udp/rudp/send.go | 133 +++++++++++++++++++++---------------------- 1 file changed, 65 insertions(+), 68 deletions(-) diff --git a/mod/udp/rudp/send.go b/mod/udp/rudp/send.go index 80767ae5f..d9ed57716 100644 --- a/mod/udp/rudp/send.go +++ b/mod/udp/rudp/send.go @@ -1,7 +1,6 @@ package rudp import ( - "fmt" "time" ) @@ -27,88 +26,90 @@ func (c *Conn) windowFull() bool { return len(c.unacked) >= c.cfg.MaxWindowPackets } -// fillFragBuf reads exactly ask bytes from the head of sendRB into buf. -func (c *Conn) fillFragBuf(buf []byte, ask int) error { - if ask > len(buf) { - return fmt.Errorf("fragBuf too small: ask=%d, buf=%d", ask, len(buf)) +// planFragmentLocked decides how many bytes to send (<= ask) and drains them into a fresh buffer. +// Caller MUST hold sendMu. Returns nil buffer if nothing to send. +func (c *Conn) planFragmentLocked(ask int) (seq uint32, buf []byte, n int) { + if ask <= 0 || c.sendRB.Length() == 0 { + return 0, nil, 0 } - readN, err := c.sendRB.Read(buf[:ask]) // destructive read (FIFO) - if err != nil { - return err + if ask > int(c.sendRB.Length()) { + ask = int(c.sendRB.Length()) } - if readN != ask { - return fmt.Errorf("partial read from sendRB: expected %d, got %d", ask, readN) + if ask > c.cfg.MaxSegmentSize { + ask = c.cfg.MaxSegmentSize } - return nil + fragBuf := make([]byte, ask) + readN, _ := c.sendRB.Read(fragBuf) + if readN == 0 { + return 0, nil, 0 + } + return c.nextSeqNum, fragBuf[:readN], readN } -// sendFragment consumes up to ask bytes from the send ring buffer, builds a packet and sends it. -// Returns true if a packet was sent. -func (c *Conn) sendFragment(ask int) (bool, error) { - c.sendMu.Lock() - - if ask <= 0 || c.sendRB.Length() == 0 { - c.sendMu.Unlock() - return false, nil +// buildPacket converts raw payload into a Packet and marshals it. +func (c *Conn) buildPacket(seq uint32, payload []byte) (*Packet, []byte, error) { + pkt := &Packet{Seq: seq, Ack: c.ackedSeqNum, Flags: FlagACK, Len: uint16(len(payload)), Payload: payload} + b, err := pkt.Marshal() + if err != nil { + return nil, nil, err } - if ask > int(c.sendRB.Length()) { // clamp to available - ask = int(c.sendRB.Length()) + return pkt, b, nil +} + +// commitPacketLocked registers the packet as unacked and advances sequence numbers. Caller holds sendMu. +func (c *Conn) commitPacketLocked(pkt *Packet) (startTimer bool) { + seq := pkt.Seq + c.unacked[seq] = &Unacked{pkt: pkt, sentTime: time.Now(), rtxCount: 0, length: int(pkt.Len)} + c.nextSeqNum += uint32(pkt.Len) + return len(c.unacked) == 1 +} + +// armRetransmitTimer starts retransmission timer if needed (no lock held). +func (c *Conn) armRetransmitTimer(need bool) { + if need { + c.startRtxTimer() } +} - fragBuf := make([]byte, ask) - if err := c.fillFragBuf(fragBuf, ask); err != nil { - c.sendMu.Unlock() - return false, err +// rollbackPacketLocked removes an unacked packet on send failure. Caller does NOT hold lock upon entry. +func (c *Conn) rollbackPacketLocked(seq uint32, length int) { + c.sendMu.Lock() + if u, ok := c.unacked[seq]; ok && u.length == length { + delete(c.unacked, seq) + if c.nextSeqNum == seq+uint32(length) { + c.nextSeqNum = seq + } + if len(c.unacked) == 0 && c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + c.sendCond.Broadcast() } + c.sendMu.Unlock() +} - pkt, pktLen, ok := c.frag.MakeNew(c.nextSeqNum, ask, &ByteStreamBuffer{data: fragBuf[:ask]}) - if !ok || pktLen == 0 { +// sendFragment consumes up to ask bytes, builds a packet and sends it. +func (c *Conn) sendFragment(ask int) (bool, error) { + c.sendMu.Lock() + seq, payload, plen := c.planFragmentLocked(ask) + if plen == 0 { c.sendMu.Unlock() return false, nil } - pkt.Ack = c.ackedSeqNum - - b, err := pkt.Marshal() + pkt, raw, err := c.buildPacket(seq, payload) if err != nil { c.sendMu.Unlock() return false, err } - - seq := c.nextSeqNum - c.unacked[seq] = &Unacked{ - pkt: pkt, - sentTime: time.Now(), - rtxCount: 0, - length: pktLen, - } - c.nextSeqNum += uint32(pktLen) - startTimer := len(c.unacked) == 1 - c.sendCond.Broadcast() // wake writers: space freed by consumption + startTimer := c.commitPacketLocked(pkt) + c.sendCond.Broadcast() c.sendMu.Unlock() - if startTimer { - c.startRtxTimer() - } - - // Network I/O outside lock - _, err = c.sendDatagram(b) - if err != nil { - c.sendMu.Lock() - if u, ok2 := c.unacked[seq]; ok2 && u.length == pktLen { - delete(c.unacked, seq) - if c.nextSeqNum == seq+uint32(pktLen) { // rewind if no later sends - c.nextSeqNum = seq - } - if len(c.unacked) == 0 && c.rtxTimer != nil { - c.rtxTimer.Stop() - c.rtxTimer = nil - } - c.sendCond.Broadcast() // notify writers rollback restored space - } - c.sendMu.Unlock() + c.armRetransmitTimer(startTimer) + if _, err := c.sendDatagram(raw); err != nil { + c.rollbackPacketLocked(seq, int(pkt.Len)) return false, err } - return true, nil } @@ -117,7 +118,7 @@ func (c *Conn) startRtxTimer() { c.sendMu.Lock() if c.rtxTimer != nil { c.sendMu.Unlock() - return // already running + return } interval := c.cfg.RetransmissionInterval c.rtxTimer = time.AfterFunc(interval, func() { @@ -157,11 +158,7 @@ func (c *Conn) senderLoop() { } c.sendMu.Unlock() - sent, err := c.sendFragment(ask) - if err != nil { - continue - } - if !sent { + if sent, _ := c.sendFragment(ask); !sent { continue } } From eddc36809bdf01e4045bb65d271a848ad7b7d75a Mon Sep 17 00:00:00 2001 From: Rekseto Date: Wed, 1 Oct 2025 16:54:13 +0200 Subject: [PATCH 08/13] mod/udp: remove rollback logic for unacked packets in send failure handling --- mod/udp/rudp/send.go | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/mod/udp/rudp/send.go b/mod/udp/rudp/send.go index d9ed57716..88b7963c0 100644 --- a/mod/udp/rudp/send.go +++ b/mod/udp/rudp/send.go @@ -71,23 +71,6 @@ func (c *Conn) armRetransmitTimer(need bool) { } } -// rollbackPacketLocked removes an unacked packet on send failure. Caller does NOT hold lock upon entry. -func (c *Conn) rollbackPacketLocked(seq uint32, length int) { - c.sendMu.Lock() - if u, ok := c.unacked[seq]; ok && u.length == length { - delete(c.unacked, seq) - if c.nextSeqNum == seq+uint32(length) { - c.nextSeqNum = seq - } - if len(c.unacked) == 0 && c.rtxTimer != nil { - c.rtxTimer.Stop() - c.rtxTimer = nil - } - c.sendCond.Broadcast() - } - c.sendMu.Unlock() -} - // sendFragment consumes up to ask bytes, builds a packet and sends it. func (c *Conn) sendFragment(ask int) (bool, error) { c.sendMu.Lock() @@ -107,8 +90,8 @@ func (c *Conn) sendFragment(ask int) (bool, error) { c.armRetransmitTimer(startTimer) if _, err := c.sendDatagram(raw); err != nil { - c.rollbackPacketLocked(seq, int(pkt.Len)) - return false, err + // Keep packet in unacked for normal retransmission (no fast retransmit tweak). + return true, nil } return true, nil } From b238e50cab4894e15ca28a3820300c2d1d91472c Mon Sep 17 00:00:00 2001 From: Rekseto Date: Wed, 1 Oct 2025 17:16:06 +0200 Subject: [PATCH 09/13] mod/udp: enhance connection handling and error reporting in reliable UDP --- mod/udp/errors.go | 3 ++ mod/udp/rudp/conn.go | 20 +++++++++-- mod/udp/rudp/retransmissions.go | 21 +++++++++++ mod/udp/rudp/send.go | 64 +++++++++++++++------------------ 4 files changed, 69 insertions(+), 39 deletions(-) diff --git a/mod/udp/errors.go b/mod/udp/errors.go index e7f3dc17b..c64bc159e 100644 --- a/mod/udp/errors.go +++ b/mod/udp/errors.go @@ -6,6 +6,9 @@ var ( ErrListenerClosed = errors.New("listener closed") ErrRetransmissionLimitExceeded = errors.New( "retransmissions limit exceeded") + + // ErrDataLost is emitted on close if there was still buffered or unacked data. + ErrDataLost = errors.New("unsent data lost on close") ErrPacketTooShort = errors.New("packet too short") ErrConnClosed = errors.New("connection closed") ErrInvalidPayloadLength = errors.New("invalid payload length") diff --git a/mod/udp/rudp/conn.go b/mod/udp/rudp/conn.go index 7b179b822..c050165fd 100644 --- a/mod/udp/rudp/conn.go +++ b/mod/udp/rudp/conn.go @@ -207,6 +207,14 @@ func (c *Conn) Read(p []byte) (n int, err error) { // Write enqueues data into the send ring buffer. Implementation in send.go. func (c *Conn) Write(p []byte) (n int, err error) { + if c.isClosed() { + return 0, udp.ErrConnClosed + } + + if !c.inState(StateEstablished) { + return 0, udp.ErrConnectionNotEstablished + } + return c.writeSend(p) } @@ -216,19 +224,25 @@ func (c *Conn) Close() error { c.sendMu.Unlock() return nil } + pendingData := c.sendRB != nil && c.sendRB.Length() > 0 + pendingUnacked := len(c.unacked) > 0 atomic.StoreUint32(&c.closedFlag, 1) - // stop retransmission timer if running if c.rtxTimer != nil { c.rtxTimer.Stop() c.rtxTimer = nil } - // wake any waiters (writers, senderLoop) c.sendCond.Broadcast() ch := c.inCh - c.inCh = nil // detach channel to prevent further sends + c.inCh = nil closedCb := c.onClosedCb c.sendMu.Unlock() + if pendingData || pendingUnacked { + select { + case c.ErrChan <- udp.ErrDataLost: + default: + } + } if ch != nil { close(ch) } diff --git a/mod/udp/rudp/retransmissions.go b/mod/udp/rudp/retransmissions.go index 44701a196..4b3c76952 100644 --- a/mod/udp/rudp/retransmissions.go +++ b/mod/udp/rudp/retransmissions.go @@ -7,6 +7,27 @@ import ( "github.com/cryptopunkscc/astrald/mod/udp" ) +// startRtxTimer arms the retransmission timer if not already running +func (c *Conn) startRtxTimer() { + c.sendMu.Lock() + if c.rtxTimer != nil { + c.sendMu.Unlock() + return + } + interval := c.cfg.RetransmissionInterval + c.rtxTimer = time.AfterFunc(interval, func() { + c.handleRetransmissionTimeout() + c.sendMu.Lock() + if len(c.unacked) > 0 && !c.isClosed() { + c.rtxTimer.Reset(interval) + } else { + c.rtxTimer = nil + } + c.sendMu.Unlock() + }) + c.sendMu.Unlock() +} + // queueAckLocked marks that an ACK should be (re)sent. Caller must hold recvMu. func (c *Conn) queueAckLocked() { c.ackPending = true } diff --git a/mod/udp/rudp/send.go b/mod/udp/rudp/send.go index 88b7963c0..2dacf9277 100644 --- a/mod/udp/rudp/send.go +++ b/mod/udp/rudp/send.go @@ -90,31 +90,31 @@ func (c *Conn) sendFragment(ask int) (bool, error) { c.armRetransmitTimer(startTimer) if _, err := c.sendDatagram(raw); err != nil { - // Keep packet in unacked for normal retransmission (no fast retransmit tweak). - return true, nil + return true, nil // retained for retransmission } return true, nil } -// startRtxTimer arms the retransmission timer if not already running -func (c *Conn) startRtxTimer() { +// acquireNextSendSize blocks until there is data to send and window is not full, +// or returns (0,false) if the connection is closed or not established anymore. +// It releases the lock before returning. +func (c *Conn) acquireNextSendSize() (int, bool) { c.sendMu.Lock() - if c.rtxTimer != nil { - c.sendMu.Unlock() - return - } - interval := c.cfg.RetransmissionInterval - c.rtxTimer = time.AfterFunc(interval, func() { - c.handleRetransmissionTimeout() - c.sendMu.Lock() - if len(c.unacked) > 0 && !c.isClosed() { - c.rtxTimer.Reset(interval) - } else { - c.rtxTimer = nil + for { + if c.isClosed() || !c.inState(StateEstablished) { + c.sendMu.Unlock() + return 0, false } - c.sendMu.Unlock() - }) - c.sendMu.Unlock() + if c.sendRB.Length() > 0 && !c.windowFull() { + ask := int(c.sendRB.Length()) + if ask > c.cfg.MaxSegmentSize { + ask = c.cfg.MaxSegmentSize + } + c.sendMu.Unlock() + return ask, true + } + c.sendCond.Wait() + } } // senderLoop runs as a goroutine and is responsible for sending packets from the send buffer. @@ -123,26 +123,18 @@ func (c *Conn) startRtxTimer() { func (c *Conn) senderLoop() { defer func() { c.sendCond.Broadcast() }() for { - c.sendMu.Lock() - if c.isClosed() || !c.inState(StateEstablished) { - c.sendMu.Unlock() + ask, ok := c.acquireNextSendSize() + if !ok { // closed or not established return } - for c.sendRB.Length() == 0 || c.windowFull() { - c.sendCond.Wait() - if c.isClosed() || !c.inState(StateEstablished) { - c.sendMu.Unlock() - return + _, err := c.sendFragment(ask) + if err != nil { + select { + case c.ErrChan <- err: + default: } - } - ask := int(c.sendRB.Length()) - if ask > c.cfg.MaxSegmentSize { - ask = c.cfg.MaxSegmentSize - } - c.sendMu.Unlock() - - if sent, _ := c.sendFragment(ask); !sent { - continue + c.Close() + return } } } From f9db70eb2e7787b95e92a700e484850cd665bdb6 Mon Sep 17 00:00:00 2001 From: Rekseto Date: Wed, 1 Oct 2025 17:52:08 +0200 Subject: [PATCH 10/13] mod/udp: refine ACK handling and streamline sender loop in reliable UDP --- mod/udp/rudp/conn.go | 26 ++----- mod/udp/rudp/conn_handshake.go | 4 +- mod/udp/rudp/send.go | 130 +++++++++++++-------------------- 3 files changed, 57 insertions(+), 103 deletions(-) diff --git a/mod/udp/rudp/conn.go b/mod/udp/rudp/conn.go index c050165fd..6013b0a4f 100644 --- a/mod/udp/rudp/conn.go +++ b/mod/udp/rudp/conn.go @@ -55,12 +55,8 @@ type Conn struct { initialSeqNumRemote uint32 nextSeqNum uint32 connID uint32 // Connection ID - sendBase uint32 // Oldest unacked sequence (ACK floor) - ackedSeqNum uint32 // Highest cumulative ACK seen + ackedSeqNum uint32 // Highest cumulative ACK observed for our sent data (also piggybacked when sending) expected uint32 // Next expected sequence number (receive side) - - inflight uint32 // Number of unacked packets - // Send buffer and reliability sendRB *ringbuffer.RingBuffer // Persistent send ring buffer (FIFO; bytes consumed at packetization) frag *BasicFragmenter // Fragmenter for packetization @@ -77,12 +73,11 @@ type Conn struct { ErrChan chan error // Channel for connection-level errors (e.g., retransmission failure) // Inbound buffering & ACK state - recvRB *ringbuffer.RingBuffer - recvMu sync.Mutex - recvCond *sync.Cond - ackTimer *time.Timer - ackPending bool - lastAckSent uint32 + recvRB *ringbuffer.RingBuffer + recvMu sync.Mutex + recvCond *sync.Cond + ackTimer *time.Timer + ackPending bool // Out-of-order buffer (keyed by sequence number of first byte) recvOO map[uint32]*Packet // stored packets with Seq > expected awaiting in-order delivery } @@ -297,26 +292,19 @@ func (c *Conn) HandleAckPacket(packet *Packet) { if ack > c.ackedSeqNum { c.ackedSeqNum = ack } - if ack > c.sendBase { - c.sendBase = ack - } + // cumulative ACK floor implicitly reflected by ackedSeqNum // Remove fully acked packets (keyed by seq) for s, u := range c.unacked { if u.isHandshake { - // Handshake control (SYN / SYN|ACK) conceptually consumes 1 sequence number. - // Require ack > s (i.e., ack == s+1) to delete, avoiding premature removal - // if an unexpected ack echo with ack==s arrives. if ack > s { // expected ack == s+1 delete(c.unacked, s) } continue } - // Data packet: remove when cumulative ack covers entire payload. if s+uint32(u.length) <= ack { delete(c.unacked, s) } } - // Stop retransmission timer if no unacked packets remain if len(c.unacked) == 0 && c.rtxTimer != nil { c.rtxTimer.Stop() c.rtxTimer = nil diff --git a/mod/udp/rudp/conn_handshake.go b/mod/udp/rudp/conn_handshake.go index 83311de78..def2179ba 100644 --- a/mod/udp/rudp/conn_handshake.go +++ b/mod/udp/rudp/conn_handshake.go @@ -72,10 +72,9 @@ func (c *Conn) startClientHandshakeDirect(ctx context.Context) error { // Expect SYN|ACK if pkt.Flags&(FlagSYN|FlagACK) == (FlagSYN|FlagACK) && pkt.Ack == seq+1 && pkt.Seq != 0 { c.initialSeqNumRemote = pkt.Seq - // finalize local send base + // finalize local cumulative ack/next sequence c.sendMu.Lock() c.ackedSeqNum = seq + 1 - c.sendBase = seq + 1 c.nextSeqNum = seq + 1 c.sendMu.Unlock() // Send final ACK @@ -121,7 +120,6 @@ func (c *Conn) StartServerHandshake(ctx context.Context, synPkt *Packet) error { c.rtxTimer = nil } c.ackedSeqNum = c.initialSeqNumLocal + 1 - c.sendBase = c.initialSeqNumLocal + 1 c.nextSeqNum = c.initialSeqNumLocal + 1 c.sendMu.Unlock() c.onEstablished() diff --git a/mod/udp/rudp/send.go b/mod/udp/rudp/send.go index 2dacf9277..e8d866875 100644 --- a/mod/udp/rudp/send.go +++ b/mod/udp/rudp/send.go @@ -17,7 +17,8 @@ func (c *Conn) writeSend(p []byte) (n int, err error) { if err != nil { return 0, err } - c.sendCond.Broadcast() + // Only the senderLoop needs to be awakened; writers waiting for space are not helped by a write. + c.sendCond.Signal() return writeLen, nil } @@ -26,36 +27,6 @@ func (c *Conn) windowFull() bool { return len(c.unacked) >= c.cfg.MaxWindowPackets } -// planFragmentLocked decides how many bytes to send (<= ask) and drains them into a fresh buffer. -// Caller MUST hold sendMu. Returns nil buffer if nothing to send. -func (c *Conn) planFragmentLocked(ask int) (seq uint32, buf []byte, n int) { - if ask <= 0 || c.sendRB.Length() == 0 { - return 0, nil, 0 - } - if ask > int(c.sendRB.Length()) { - ask = int(c.sendRB.Length()) - } - if ask > c.cfg.MaxSegmentSize { - ask = c.cfg.MaxSegmentSize - } - fragBuf := make([]byte, ask) - readN, _ := c.sendRB.Read(fragBuf) - if readN == 0 { - return 0, nil, 0 - } - return c.nextSeqNum, fragBuf[:readN], readN -} - -// buildPacket converts raw payload into a Packet and marshals it. -func (c *Conn) buildPacket(seq uint32, payload []byte) (*Packet, []byte, error) { - pkt := &Packet{Seq: seq, Ack: c.ackedSeqNum, Flags: FlagACK, Len: uint16(len(payload)), Payload: payload} - b, err := pkt.Marshal() - if err != nil { - return nil, nil, err - } - return pkt, b, nil -} - // commitPacketLocked registers the packet as unacked and advances sequence numbers. Caller holds sendMu. func (c *Conn) commitPacketLocked(pkt *Packet) (startTimer bool) { seq := pkt.Seq @@ -71,63 +42,55 @@ func (c *Conn) armRetransmitTimer(need bool) { } } -// sendFragment consumes up to ask bytes, builds a packet and sends it. -func (c *Conn) sendFragment(ask int) (bool, error) { - c.sendMu.Lock() - seq, payload, plen := c.planFragmentLocked(ask) - if plen == 0 { - c.sendMu.Unlock() - return false, nil - } - pkt, raw, err := c.buildPacket(seq, payload) - if err != nil { - c.sendMu.Unlock() - return false, err - } - startTimer := c.commitPacketLocked(pkt) - c.sendCond.Broadcast() - c.sendMu.Unlock() - - c.armRetransmitTimer(startTimer) - if _, err := c.sendDatagram(raw); err != nil { - return true, nil // retained for retransmission - } - return true, nil -} - -// acquireNextSendSize blocks until there is data to send and window is not full, -// or returns (0,false) if the connection is closed or not established anymore. -// It releases the lock before returning. -func (c *Conn) acquireNextSendSize() (int, bool) { - c.sendMu.Lock() - for { - if c.isClosed() || !c.inState(StateEstablished) { - c.sendMu.Unlock() - return 0, false - } - if c.sendRB.Length() > 0 && !c.windowFull() { - ask := int(c.sendRB.Length()) - if ask > c.cfg.MaxSegmentSize { - ask = c.cfg.MaxSegmentSize - } - c.sendMu.Unlock() - return ask, true - } - c.sendCond.Wait() - } -} - // senderLoop runs as a goroutine and is responsible for sending packets from the send buffer. +// Inlined fragmentation & packet build: we avoid an extra lock/unlock by performing +// (select bytes -> copy -> build packet -> commit) under one critical section, then +// releasing the lock before marshaling and sending. // FIFO model: bytes are consumed from sendRB as soon as they are packetized. Retransmissions // use copies stored in unacked map. No random access over the ring is performed. func (c *Conn) senderLoop() { defer func() { c.sendCond.Broadcast() }() for { - ask, ok := c.acquireNextSendSize() - if !ok { // closed or not established - return + c.sendMu.Lock() + for { + if c.isClosed() || !c.inState(StateEstablished) { + c.sendMu.Unlock() + return + } + if c.sendRB.Length() > 0 && !c.windowFull() { + break + } + c.sendCond.Wait() + } + + // Decide fragment size and drain payload while still holding the lock. + ask := int(c.sendRB.Length()) + if ask > c.cfg.MaxSegmentSize { + ask = c.cfg.MaxSegmentSize + } + // Allocate buffer and read from ring (destructive read). + payload := make([]byte, ask) + readN, _ := c.sendRB.Read(payload) + if readN == 0 { // nothing actually read; loop and re-evaluate predicates + c.sendMu.Unlock() + continue + } + if readN != ask { // shrink payload if ring gave us fewer bytes + payload = payload[:readN] } - _, err := c.sendFragment(ask) + + seq := c.nextSeqNum + pkt := &Packet{Seq: seq, Ack: c.ackedSeqNum, Flags: FlagACK, Len: uint16(len(payload)), Payload: payload} + startTimer := c.commitPacketLocked(pkt) + // We freed ring space; signal at least one waiting writer or (rarely) another waiter. + c.sendCond.Signal() + c.sendMu.Unlock() + + // Manage retransmission timer outside lock + c.armRetransmitTimer(startTimer) + + // Marshal & send outside lock to minimize critical section time. + raw, err := pkt.Marshal() if err != nil { select { case c.ErrChan <- err: @@ -136,5 +99,10 @@ func (c *Conn) senderLoop() { c.Close() return } + if _, err := c.sendDatagram(raw); err != nil { + // Suppress error (packet retained in unacked for retransmission). Continue loop. + continue + } + // next iteration } } From 8f5c95fba86691a4e1ee2b34d383bf0898befc8e Mon Sep 17 00:00:00 2001 From: Rekseto Date: Thu, 2 Oct 2025 04:41:28 +0200 Subject: [PATCH 11/13] mod/udp: add benchmarks for RUDP transfer performance and enhance configuration defaults --- mod/udp/rudp/README.md | 202 +++++++++++++++++++++++++++++++ mod/udp/rudp/bench_test.go | 199 ++++++++++++++++++++++++++++++ mod/udp/rudp/config.go | 19 ++- mod/udp/rudp/config_test.go | 17 +-- mod/udp/rudp/conn.go | 34 ++++-- mod/udp/rudp/fragmenter.go | 47 ++----- mod/udp/rudp/fragmenter_test.go | 27 +++-- mod/udp/rudp/integration_test.go | 135 +++++++++++++++++++++ mod/udp/rudp/retransmissions.go | 8 ++ 9 files changed, 600 insertions(+), 88 deletions(-) create mode 100644 mod/udp/rudp/README.md create mode 100644 mod/udp/rudp/bench_test.go diff --git a/mod/udp/rudp/README.md b/mod/udp/rudp/README.md new file mode 100644 index 000000000..e5d552e8d --- /dev/null +++ b/mod/udp/rudp/README.md @@ -0,0 +1,202 @@ +# RUDP (Proof of Concept) + +Reliable UDP with minimal mechanisms: fixed-interval retransmissions, cumulative ACKs, simple packet window, basic 3‑way handshake. + +## Current Status Summary + +Works today: +- 3‑way handshake (SYN → SYN|ACK → ACK) with random non‑zero ISNs. +- Cumulative reliability (no SACK) using single connection-level retransmission timer. +- Fixed RTO (no RTT sampling, no backoff) + global resend of all unacked packets each tick. +- Flow control: sender-side packet count window (`MaxWindowPackets` considered conversion for MaxWindowBytes for better performance), MSS-bounded fragments. +- Fragmentation: simple slice of up to MSS bytes per packet (no coalescing / segmentation logic beyond cap). +- Delayed ACK: single configurable `AckDelay` timer (pure ACK or piggyback on outgoing data). +- Out-of-order receive buffering with gap fill & drain (`recvOO` map). +- Ring buffers: send (bytes) + recv (bytes). Blocking `Write` until space. +- Error surfacing: `ErrChan` (retransmission limit, data loss on close, marshal errors). +- Listener accepts only established connections; server handshake runs per inbound SYN. + +Not implemented / not present: +- FIN/close protocol (flags exist; logic is TODO / unused). +- Half-close, TIME-WAIT, RST, keepalive/ping. +- Congestion control, pacing, rate limiting. +- Adaptive RTO (no RTT measurement), exponential backoff, per-packet timers. +- Selective ACK, fast retransmit, loss detection heuristics beyond timeout. +- Use of advertised receive window (`Win` field) by sender (currently ignored). +- MSS/Path MTU discovery (fixed MSS from config only). +- Byte-based dynamic window / BDP autotuning. +- Checksum beyond UDP (no payload integrity apart from lower layer). +- Security (no encryption, authentication, replay protection). +- Stats/metrics/telemetry surfaces. +- Config field `SendBufBytes` (unused; send buffer size derived from window*MSS*2). +- Backpressure signaling aside from blocking `Write`. + +Platform assumptions: standard Go net.UDP; loopback benchmark only; no OS autotuning hooks. + +## Architecture Overview + +Data path (outbound): +Application → `Conn.Write` → send ring buffer → `senderLoop` → fragment (≤ MSS) → packet build (Seq/Ack/Flags) → UDP send → network. + +Receive path: UDP recv → `recvLoop` → parse → (handshake vs established) → data: ACK handling + reordering + ring write → `Conn.Read` → Application. + +ACK path: receiver marks `ackPending`; delayed (timer) or immediate; pure ACK or piggyback in data packet (`FlagACK`). + +Mermaid sequence (simplified steady-state): +```mermaid +sequenceDiagram + participant AppW as App (Writer) + participant Conn as Conn + participant Snd as senderLoop + participant UDP as UDP Socket + participant Rcv as recvLoop + participant AppR as App (Reader) + + AppW->>Conn: Write(bytes) + Conn->>Conn: Enqueue bytes (sendRB) + Conn->>Snd: Signal cond + loop While data & window + Snd->>Snd: Fragment ≤ MSS + Snd->>Snd: Build Packet(Seq,Ack,ACK) + Snd->>UDP: sendDatagram() + end + UDP-->>Rcv: Datagram + Rcv->>Rcv: Unmarshal Packet + Rcv->>Conn: HandleAck (if piggyback) + Rcv->>Conn: handleDataPacket + Conn->>Conn: Reorder / drain recvOO + Conn->>Conn: queueAckLocked() + Conn->>UDP: (maybe) Pure ACK (delayed) + AppR->>Conn: Read() + Conn-->>AppR: Bytes +``` + +## Reliability Mechanisms (Existing) +- ACKs: Cumulative only. Data packets always carry `Ack`. Pure ACK packets (`FlagACK`, `Len=0`) sent on delay timer or immediately (duplicates, out-of-order, buffer full). +- Retransmission: Single fixed-interval connection-level timer; every tick retransmits ALL unacked packets (handshake + data). No RTT measurement, no per-packet timers, no exponential backoff. Limit enforced by `RetransmissionLimit` per packet; exceeding closes connection and emits error. +- Window: Sender-side count of in-flight packets (`len(unacked) < MaxWindowPackets`). Receiver advertised `Win` (bytes) set only on pure ACK; unused by sender. +- Reordering: Out-of-order packets stored in `recvOO` keyed by starting `Seq`; drained when gap closes. No explicit cap beyond recv buffer space. +- Fragmentation / MSS: Each packet contains up to `MaxSegmentSize` bytes; no multi-packet segmentation logic aside from slicing to MSS. +- Checksum/validation: Only length/header sanity checks; no additional checksum. +- Pacing / rate limiting: Not implemented. +- Close semantics: `Close()` does NOT send FIN; it halts timers, signals waiters, may emit `ErrDataLost` if unsent or unacked data remained; no half-close. + +## Configuration & Defaults +(From `config.go` after `Normalize()`) +- `MaxSegmentSize` (MSS): `DefaultMSS` (1200 - 13 = 1187 bytes payload). +- `MaxWindowPackets`: 1024. +- `RetransmissionInterval` (RTO fixed): 200ms. +- `MaxRetransmissionInterval`: 4s (NOT USED currently). +- `RetransmissionLimit`: 8 (per packet, including handshake packets). +- `AckDelay`: 5ms (0 => immediate ACK). +- `RecvBufBytes`: 16 MiB. +- `SendBufBytes`: 16 MiB (NOT USED; actual send buffer = window*MSS*2). + +## Wire Format (As Implemented) +Header (13 bytes): +- `Seq` (4): first byte sequence number of this segment. +- `Ack` (4): cumulative ACK (all bytes < Ack received). +- `Flags` (1): bitmask (SYN=1, ACK=2, FIN=4). +- `Win` (2): advertised free receive window in BYTES (only set on pure ACK; 0 on data packets currently). +- `Len` (2): payload length. +- `Payload` (Len bytes). + +Usage: +- Data: `Len>0`, always sets `FlagACK`, `Win`=0. +- Pure ACK: `Len=0`, `FlagACK`, sets `Win` to clamped free bytes (≤ 0xFFFF). +- Handshake: SYN (no payload), SYN|ACK, final ACK. +- FIN: Flag defined but not produced/consumed. + +## State & Timers (As Implemented) +States defined: `Closed, Listen, SynSent, SynReceived, Established, FinSent, FinReceived, TimeWait`. +Active transitions exercised: Closed → SynSent → Established (client); Closed → SynReceived → Established (server). FIN-related states unused. +Timers: +- Retransmission timer: single `time.AfterFunc` repeating at fixed `RetransmissionInterval` while any unacked packets exist. +- Delayed ACK timer: per-connection `ackTimer` (fires once per pending ACK batch) with delay `AckDelay`. +Cancellation: +- Retransmission timer stopped when `unacked` empty or connection closed. +- ACK timer canceled (stop) on close or when ACK sent. +No other timers (no keepalive, no linger/TIME-WAIT). + +## Buffers & Memory +- Send ring: size = `MaxWindowPackets * MaxSegmentSize * 2` bytes; holds queued outbound data not yet packetized. Bytes removed upon packetization (not upon ACK); retransmissions use stored packet copies in `unacked` map. +- Receive ring: size = `RecvBufBytes`; holds in-order assembled payload bytes for `Read`. +- Out-of-order map: one entry per future gap start (key = packet.Seq). Evicted on drainage. +- Packet copies: each in-flight packet (data + handshake) retained until cumulatively ACKed; payload slice owned exclusively by `unacked` entry. +- Zero-copy: Not present (copies on UDP read into new slice, into recv ring, then into caller buffer on `Read`). +- GC considerations: One allocation per outbound packet payload; per in-order + out-of-order inbound packet payload; ring buffers reuse internal slabs. + +## Benchmarks & Expected Localhost Behavior +Benchmark: `BenchmarkRUDPTransfer10MiB` (one 10 MiB unidirectional transfer client→server; handshake excluded from timed portion). +Running: +```bash +go test -run ^$ -bench BenchmarkRUDPTransfer10MiB -benchtime=1x ./mod/udp/rudp +RUDP_BENCH_PROGRESS=1 go test -run ^$ -bench BenchmarkRUDPTransfer10MiB -benchtime=1x ./mod/udp/rudp +``` +Reports: +- MiB/s throughput (write phase only). +- Allocations, bytes per op. +Typical throughput: Not documented in code (dependent on fixed window * MSS / RTO / ACK delay and host performance). + +## Missing Pieces +- No FIN exchange / graceful close / half-close. +- No congestion control or pacing. +- No RTT measurement / adaptive RTO / backoff (fixed interval only). +- `MaxRetransmissionInterval` unused. +- No selective ACK (SACK) / fast retransmit / loss detection beyond timeout. +- No use of advertised `Win` (flow control purely packet-count based). +- No path MTU discovery / dynamic MSS. +- No checksum beyond UDP; silent data corruption undetected at transport layer. +- No encryption / authentication. +- No keepalive / idle timeout / ping. +- No connection IDs beyond initial Seq reuse; no NAT rebinding handling beyond IP match. +- `SendBufBytes` config unused. +- Backpressure only via blocking writes; no error signaling before block. +- +## Future Improvements +### Near-term Safety & Performance +- Adaptive byte-based window & BDP autotune. +- RTT sampling + RFC 6298 RTO calculation + exponential backoff. +- Packet pacing / rate shaping. +- Time-threshold (fast) loss detection + limited fast retransmit. +- Adaptive ACK delay (heuristics based on flight / reordering). + +### Protocol Completeness +- Proper FIN handshake, half-close, and TIME-WAIT handling. +- Explicit connection IDs distinct from ISNs. +- RST/abort semantics. +- Keepalive / ping frames. + +### Robustness & Reliability +- Path MTU discovery (PLPMTUD) + MSS adjustment. +- Selective ACK (SACK) & out-of-order range tracking. +- Reordering window limits / discard policies. +- Detailed stats / telemetry (RTT, loss, retransmits, goodput). +- Error surfaces for flow control violations / protocol anomalies. + +### Tooling & Testing +- Loss / delay / duplication chaos test harness. +- Cross-platform CI with network emulation. +- Long-running soak / stress tests (memory & stability). +- Benchmark variants (bidirectional, varying window/MSS). + +## Benchmarks (uTP Comparison) + +| Implementation | Iterations | ns/op (mean) | Throughput (MB/s) | B/op | allocs/op | +|---|---:|---:|---:|---:|---:| +| **rudp** (`github.com/cryptopunkscc/astrald/mod/udp/rudp`) | 48 | 71,213,610 | 147.24 | 23,393,100 | 94,552 | +| **uTP** (`github.com/anacrolix/utp`) | 48 | 24,602,392 | 426.21 | 82,243,573 | 93,825 | + +### Notes + +Numbers come from the respective benchmark outputs on loopback (localhost), 10 MiB payload; handshake/setup excluded from the timed region. + +uTP shows ~3× higher throughput on this workload. Allocation counts are comparable; rudp allocates fewer bytes/op but runs slower overall. + +This rudp is a proof of concept and has not been profiler-guided or optimized yet. + +### Potential Improvements to Close the Gap + +Adding fast retransmits / time-threshold loss detection and switching from a packet-count window (MaxWindowPackets) to a byte-based window (MaxWindowBytes) with BDP-aware autotuning will most likely bring performance closer to uTP on high-throughput paths. + +Additional expected wins: adaptive ACK delay (immediate on reordering/gap), pacing, and better RTO sizing \ No newline at end of file diff --git a/mod/udp/rudp/bench_test.go b/mod/udp/rudp/bench_test.go new file mode 100644 index 000000000..ae4e7431e --- /dev/null +++ b/mod/udp/rudp/bench_test.go @@ -0,0 +1,199 @@ +package rudp_test + +import ( + "context" + "net" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/cryptopunkscc/astrald/astral" + udpmod "github.com/cryptopunkscc/astrald/mod/udp" + rudp "github.com/cryptopunkscc/astrald/mod/udp/rudp" +) + +var benchDebug = os.Getenv("RUDP_BENCH_DEBUG") != "" +var benchProgress = os.Getenv("RUDP_BENCH_PROGRESS") != "" + +// BenchmarkRUDPTransfer10MiB performs a single unidirectional transfer of ~10 MiB +// from client -> server over a single RUDP connection on loopback. Handshake and +// setup are excluded from the timed section. The benchmark runs only one data +// transfer regardless of b.N (subsequent iterations return early) to bound total +// execution time while still reporting throughput via b.SetBytes. +func BenchmarkRUDPTransfer10MiB(b *testing.B) { + if testing.Short() { + b.Skip("skip in short mode") + } + + const totalBytes = 10 * 1024 * 1024 // 10 MiB per iteration + baseChunkSize := 64 * 1024 + b.SetBytes(totalBytes) + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Stop the timer during setup/teardown so only the write phase is measured. + b.StopTimer() + + b.Logf("[iter %d] setup starting", i) + baseCtx := astral.NewContext(context.Background()) + + cfg := rudp.Config{} + cfg.Normalize() // simplest model: use normalized defaults + + l, err := rudp.Listen(baseCtx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, cfg, 2*time.Second) + if err != nil { + b.Fatalf("listen: %v", err) + } + b.Logf("[iter %d] listener=%v", i, l.Addr()) + + serverAddr := l.Addr().(*net.UDPAddr) + ipv4Dest := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverAddr.Port} + + acceptedCh := make(chan *rudp.Conn, 1) + acceptErrCh := make(chan error, 1) + go func() { + acceptCtx, cancel := baseCtx.WithTimeout(3 * time.Second) + defer cancel() + c, aerr := l.Accept(acceptCtx) + if aerr != nil { + acceptErrCh <- aerr + return + } + acceptedCh <- c + }() + + udpConn, err := net.DialUDP("udp4", nil, ipv4Dest) + if err != nil { + udpConn, err = net.DialUDP("udp", nil, ipv4Dest) + } + if err != nil { + l.Close() + b.Fatalf("dial: %v", err) + } + b.Logf("[iter %d] client local=%v remote=%v", i, udpConn.LocalAddr(), udpConn.RemoteAddr()) + + localEP, _ := udpmod.ParseEndpoint(udpConn.LocalAddr().String()) + remoteEP, _ := udpmod.ParseEndpoint(udpConn.RemoteAddr().String()) + outConn, err := rudp.NewConn(udpConn, localEP, remoteEP, cfg, true, nil, baseCtx) + if err != nil { + udpConn.Close() + l.Close() + b.Fatalf("newconn: %v", err) + } + + hCtx, hCancel := baseCtx.WithTimeout(2 * time.Second) + if err := outConn.StartClientHandshake(hCtx); err != nil { + hCancel() + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("handshake: %v", err) + } + hCancel() + b.Logf("[iter %d] handshake complete local=%v remote=%v", i, udpConn.LocalAddr(), udpConn.RemoteAddr()) + + var serverConn *rudp.Conn + select { + case serverConn = <-acceptedCh: + case aerr := <-acceptErrCh: + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("accept: %v", aerr) + case <-time.After(3 * time.Second): + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("accept timeout") + } + if serverConn == nil { + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("nil server conn") + } + b.Logf("[iter %d] server accepted remote=%v", i, serverConn.RemoteEndpoint()) + + // Determine a safe chunk size that will not block on first write: + // send ring buffer size = MaxWindowBytes * 2. Choose <= MaxWindowBytes /2 for margin. + safeChunk := cfg.MaxWindowPackets * cfg.MaxSegmentSize / 2 + if safeChunk < cfg.MaxSegmentSize { + safeChunk = cfg.MaxSegmentSize + } + if baseChunkSize > safeChunk { + if benchDebug { + b.Logf("[iter %d] adjusting chunk size from %d to %d ("+ + "MaxWindowBytes=%d)", i, baseChunkSize, safeChunk, cfg.MaxWindowPackets) + } + } + chunkSize := safeChunk + + // Prepare transfer + chunk := make([]byte, chunkSize) + for j := range chunk { + chunk[j] = byte(j) + } + var received int64 + readDone := make(chan struct{}) + var readErr atomic.Value + go func() { + buf := make([]byte, 128*1024) + for atomic.LoadInt64(&received) < int64(totalBytes) { + n, err := serverConn.Read(buf) + if err != nil { + readErr.Store(err) + break + } + if n > 0 { + atomic.AddInt64(&received, int64(n)) + } + } + close(readDone) + }() + + b.Logf("[iter %d] starting timed transfer", i) + writeStart := time.Now() + // Start the timer only for the write phase (do not reset cumulative timer) + b.StartTimer() + sent := 0 + nextMilestone := 1 * 1024 * 1024 + for sent < totalBytes { + toWrite := chunkSize + if rem := totalBytes - sent; rem < toWrite { + toWrite = rem + } + if _, err := outConn.Write(chunk[:toWrite]); err != nil { + b.Fatalf("write err after %d bytes: %v", sent, err) + } + sent += toWrite + if benchProgress && sent >= nextMilestone { + b.Logf("[iter %d] progress sent=%d / %d (%.2f%%)", i, sent, totalBytes, 100*float64(sent)/float64(totalBytes)) + nextMilestone += 1 * 1024 * 1024 + } + } + b.StopTimer() + elapsedWrite := time.Since(writeStart) + + select { + case <-readDone: + case <-time.After(5 * time.Second): + b.Fatalf("read timeout received=%d", atomic.LoadInt64(&received)) + } + if v := readErr.Load(); v != nil { + b.Fatalf("server read error: %v", v.(error)) + } + if got := atomic.LoadInt64(&received); got != int64(totalBytes) { + b.Fatalf("recv mismatch got=%d want=%d", got, totalBytes) + } + + mbps := (float64(totalBytes) / (1024 * 1024)) / elapsedWrite.Seconds() + b.Logf("[iter %d] transfer complete bytes=%d duration=%v throughput=%.2f MiB/s", i, totalBytes, elapsedWrite, mbps) + + serverConn.Close() + outConn.Close() + udpConn.Close() + l.Close() + b.Logf("[iter %d] cleanup done", i) + } +} diff --git a/mod/udp/rudp/config.go b/mod/udp/rudp/config.go index 10151f978..42e6931f0 100644 --- a/mod/udp/rudp/config.go +++ b/mod/udp/rudp/config.go @@ -5,20 +5,18 @@ import "time" // Transport default constants (exported for visibility in tests & integration) const ( DefaultMSS = 1200 - 13 // 1187 (1200 minus header) - DefaultWindowBytes = 16 * DefaultMSS - DefaultWndPkts = 32 - DefaultRTO = 500 * time.Millisecond + DefaultWndPkts = 1024 + DefaultRTO = 200 * time.Millisecond DefaultRTOMax = 4 * time.Second DefaultRetries = 8 - DefaultAckDelay = 25 * time.Millisecond - DefaultRecvBufBytes = 1 << 20 - DefaultSendBufBytes = 1 << 20 + DefaultAckDelay = 5 * time.Millisecond + DefaultRecvBufBytes = 16 << 20 + DefaultSendBufBytes = 16 << 20 ) // Config holds reliability / buffering parameters for the rudp transport. type Config struct { MaxSegmentSize int `yaml:"max_segment_size"` - MaxWindowBytes int `yaml:"max_window_bytes"` MaxWindowPackets int `yaml:"max_window_packets"` RetransmissionInterval time.Duration `yaml:"retransmission_interval"` MaxRetransmissionInterval time.Duration `yaml:"max_retransmission_interval"` @@ -33,9 +31,6 @@ func (c *Config) Normalize() { if c.MaxSegmentSize == 0 { c.MaxSegmentSize = DefaultMSS } - if c.MaxWindowBytes == 0 { - c.MaxWindowBytes = DefaultWindowBytes - } if c.MaxWindowPackets == 0 { c.MaxWindowPackets = DefaultWndPkts } @@ -60,5 +55,5 @@ func (c *Config) Normalize() { } // - AckDelay: mirrors QUIC MAX_ACK_DELAY (RFC 9000 §13.2.1). -// - Buffer sizes: 1 MiB default, capped for safety, must be >= window. -// - All invariants enforced for safety and interoperability. +// - Buffer sizes: 1 MiB default, capped for safety, must be >= aggregate window. +// - Aggregate window bytes can be derived as MaxWindowPackets * MaxSegmentSize. diff --git a/mod/udp/rudp/config_test.go b/mod/udp/rudp/config_test.go index 17074f5e6..5e0b415d5 100644 --- a/mod/udp/rudp/config_test.go +++ b/mod/udp/rudp/config_test.go @@ -13,9 +13,7 @@ func TestNormalizeAppliesDefaults(t *testing.T) { if c.MaxSegmentSize != DefaultMSS { f(t, "MaxSegmentSize", c.MaxSegmentSize, DefaultMSS) } - if c.MaxWindowBytes != DefaultWindowBytes { - f(t, "MaxWindowBytes", c.MaxWindowBytes, DefaultWindowBytes) - } + if c.MaxWindowPackets != DefaultWndPkts { f(t, "MaxWindowPackets", c.MaxWindowPackets, DefaultWndPkts) } @@ -43,7 +41,6 @@ func TestNormalizeAppliesDefaults(t *testing.T) { func TestNormalizePreservesNonZero(t *testing.T) { orig := Config{ MaxSegmentSize: 999, - MaxWindowBytes: 123456, MaxWindowPackets: 77, RetransmissionInterval: 321 * time.Millisecond, MaxRetransmissionInterval: 987 * time.Millisecond, @@ -60,9 +57,7 @@ func TestNormalizePreservesNonZero(t *testing.T) { if c.MaxSegmentSize != orig.MaxSegmentSize { g(t, "MaxSegmentSize", c.MaxSegmentSize, orig.MaxSegmentSize) } - if c.MaxWindowBytes != orig.MaxWindowBytes { - g(t, "MaxWindowBytes", c.MaxWindowBytes, orig.MaxWindowBytes) - } + if c.MaxWindowPackets != orig.MaxWindowPackets { g(t, "MaxWindowPackets", c.MaxWindowPackets, orig.MaxWindowPackets) } @@ -107,9 +102,6 @@ func TestNormalizePartial(t *testing.T) { g(t, "AckDelay", c.AckDelay, 5*time.Millisecond) } - if c.MaxWindowBytes != DefaultWindowBytes { - f(t, "MaxWindowBytes", c.MaxWindowBytes, DefaultWindowBytes) - } if c.MaxWindowPackets != DefaultWndPkts { f(t, "MaxWindowPackets", c.MaxWindowPackets, DefaultWndPkts) } @@ -145,7 +137,6 @@ func TestNormalizeIdempotent(t *testing.T) { func TestNormalizeNegativeValues(t *testing.T) { c := Config{ MaxSegmentSize: -1, - MaxWindowBytes: -2, MaxWindowPackets: -3, RetransmissionLimit: -4, RecvBufBytes: -5, @@ -161,9 +152,7 @@ func TestNormalizeNegativeValues(t *testing.T) { if c.MaxSegmentSize != -1 { g(t, "MaxSegmentSize", c.MaxSegmentSize, -1) } - if c.MaxWindowBytes != -2 { - g(t, "MaxWindowBytes", c.MaxWindowBytes, -2) - } + if c.MaxWindowPackets != -3 { g(t, "MaxWindowPackets", c.MaxWindowPackets, -3) } diff --git a/mod/udp/rudp/conn.go b/mod/udp/rudp/conn.go index 6013b0a4f..d29186566 100644 --- a/mod/udp/rudp/conn.go +++ b/mod/udp/rudp/conn.go @@ -16,14 +16,6 @@ import ( "github.com/smallnest/ringbuffer" ) -type Unacked struct { - pkt *Packet // Packet metadata (seq, len) - sentTime time.Time // Last sent time - rtxCount int // Retransmit count - length int // Payload length - isHandshake bool // True if this entry is for a handshake control packet -} - // Conn represents a reliable UDP connection. // Implements reliability, flow control, retransmissions, and error notification. // Key mechanisms: @@ -80,6 +72,7 @@ type Conn struct { ackPending bool // Out-of-order buffer (keyed by sequence number of first byte) recvOO map[uint32]*Packet // stored packets with Seq > expected awaiting in-order delivery + } // unified handshake channel capacity (applies to inbound & outbound) @@ -102,7 +95,8 @@ func NewConn(cn *net.UDPConn, l, r *udp.Endpoint, cfg Config, outbound bool, fir } } - sendRBSize := cfg.MaxWindowBytes * 2 // allow for some retransmit slack + // sendRBSize in BYTES: MaxWindowPackets * MaxSegmentSize * 2 for retransmit slack + sendRBSize := cfg.MaxWindowPackets * cfg.MaxSegmentSize * 2 rb := ringbuffer.New(sendRBSize) frag := NewBasicFragmenter(cfg.MaxSegmentSize) @@ -324,14 +318,18 @@ func (c *Conn) HandleControlPacket(packet *Packet) { // ...handle other control flags as needed... } -// Interface compliance for exonet.Conn +// Outbound interface compliance for exonet.Conn func (c *Conn) Outbound() bool { return c.outbound } + +// LocalEndpoint returns the local UDP endpoint of the connection. func (c *Conn) LocalEndpoint() exonet.Endpoint { if c == nil { return nil } return c.localEndpoint } + +// RemoteEndpoint returns the remote UDP endpoint of the connection. func (c *Conn) RemoteEndpoint() exonet.Endpoint { if c == nil { return nil @@ -354,15 +352,29 @@ func (c *Conn) ProcessPacket(pkt *Packet) { } return } + + // Data packet (may include piggyback ACK) + if pkt.Len > 0 { + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + } + + c.handleDataPacket(pkt) + return + } + + // Pure ACK if pkt.Flags&FlagACK != 0 { c.HandleAckPacket(pkt) return } + // Control (SYN/FIN/etc.) if pkt.Flags&(FlagSYN|FlagFIN) != 0 { c.HandleControlPacket(pkt) return } - c.handleDataPacket(pkt) + + // Ignore anything else } // sendDatagram sends a raw packet buffer choosing the correct syscall based on diff --git a/mod/udp/rudp/fragmenter.go b/mod/udp/rudp/fragmenter.go index 0cc8d427b..2b822e36c 100644 --- a/mod/udp/rudp/fragmenter.go +++ b/mod/udp/rudp/fragmenter.go @@ -1,16 +1,18 @@ package rudp +import "bytes" + // Fragmenter turns buffered bytes into wire packets and reproduces the exact // same boundaries for retransmission. type Fragmenter interface { // MakeNew decides payload size and builds a new Packet at nextSeq. // 'allowed' is the sender's remaining window in bytes. // Returns (packet, payloadLen, ok). ok=false if it chooses not to send (e.g., Nagle). - MakeNew(nextSeq uint32, allowed int, buf SendBuffer) (*Packet, int, bool) + MakeNew(nextSeq uint32, allowed int, buf *bytes.Buffer) (*Packet, int, bool) } // BasicFragmenter is a simple implementation of Fragmenter that splits a -// SendBuffer into Packets of at most MSS size. +// buffer into Packets of at most MSS size. type BasicFragmenter struct { MSS int } @@ -22,7 +24,8 @@ func NewBasicFragmenter(mss int) *BasicFragmenter { } // MakeNew implements the Fragmenter interface for BasicFragmenter. -func (f *BasicFragmenter) MakeNew(nextSeq uint32, allowed int, buf SendBuffer) (*Packet, int, bool) { +func (f *BasicFragmenter) MakeNew(nextSeq uint32, allowed int, + buf *bytes.Buffer) (*Packet, int, bool) { if f.MSS <= 0 { return nil, 0, false } @@ -41,7 +44,7 @@ func (f *BasicFragmenter) MakeNew(nextSeq uint32, allowed int, buf SendBuffer) ( maxLen = buf.Len() } - payload := buf.Peek(maxLen) + payload := buf.Bytes()[:maxLen] packet := &Packet{ Seq: nextSeq, Len: uint16(len(payload)), @@ -50,39 +53,3 @@ func (f *BasicFragmenter) MakeNew(nextSeq uint32, allowed int, buf SendBuffer) ( } return packet, len(payload), true } - -// Minimal SendBuffer interface for fragmentation -// Provides length and peek access to buffered data -// Implementations can use a byte slice, ring buffer, etc. -type SendBuffer interface { - Len() int - Peek(n int) []byte -} - -// ByteStreamBuffer is a minimal implementation of SendBuffer for fragmentation. -// It represents a contiguous stream of bytes, suitable for segmentation. -type ByteStreamBuffer struct { - data []byte -} - -func NewByteStreamBuffer(data []byte) *ByteStreamBuffer { - return &ByteStreamBuffer{data: data} -} - -func (b *ByteStreamBuffer) Len() int { - return len(b.data) -} - -func (b *ByteStreamBuffer) Peek(n int) []byte { - if n > len(b.data) { - n = len(b.data) - } - return b.data[:n] -} - -func (b *ByteStreamBuffer) Advance(n int) { - if n > len(b.data) { - n = len(b.data) - } - b.data = b.data[n:] -} diff --git a/mod/udp/rudp/fragmenter_test.go b/mod/udp/rudp/fragmenter_test.go index 366acccdd..5ec3c55f2 100644 --- a/mod/udp/rudp/fragmenter_test.go +++ b/mod/udp/rudp/fragmenter_test.go @@ -1,6 +1,7 @@ package rudp import ( + "bytes" "testing" ) @@ -11,8 +12,7 @@ func TestBasicFragmenter_SingleFragment(t *testing.T) { for i := range data { data[i] = byte(i) } - buf := &ByteStreamBuffer{data: data} - packet, packetLen, ok := frag.MakeNew(0, mss, buf) + packet, packetLen, ok := frag.MakeNew(0, mss, bytes.NewBuffer(data)) if !ok { t.Fatalf("expected ok=true, got false") } @@ -42,7 +42,7 @@ func TestBasicFragmenter_MultipleFragments(t *testing.T) { for i := range data { data[i] = byte(i) } - buf := &ByteStreamBuffer{data: data} + buf := bytes.NewBuffer(data) nextSeq := uint32(0) total := 0 @@ -63,7 +63,12 @@ func TestBasicFragmenter_MultipleFragments(t *testing.T) { t.Errorf("payload mismatch at %d: got %d, want %d", int(nextSeq)+i, packet.Payload[i], byte(int(nextSeq)+i)) } } - buf.Advance(packetLength) + n := packetLength + if n > buf.Len() { + n = buf.Len() + } // clamp for safety + _ = buf.Next(n) + nextSeq += uint32(packetLength) total += packetLength fragments++ @@ -79,7 +84,7 @@ func TestBasicFragmenter_MultipleFragments(t *testing.T) { func TestBasicFragmenter_ZeroLen(t *testing.T) { mss := 50 frag := NewBasicFragmenter(mss) - buf := &ByteStreamBuffer{data: nil} + buf := bytes.NewBuffer(nil) packet, packetLength, ok := frag.MakeNew(0, mss, buf) if ok { t.Errorf("expected ok=false for zero-len buffer") @@ -100,7 +105,7 @@ func TestBasicFragmenter_AllowedLessThanMSS(t *testing.T) { for i := range data { data[i] = byte(i) } - buf := &ByteStreamBuffer{data: data} + buf := bytes.NewBuffer(data) packet, packetLength, ok := frag.MakeNew(0, allowed, buf) if !ok { t.Fatalf("expected ok=true, got false") @@ -121,7 +126,7 @@ func TestBasicFragmenter_AllowedLessThanBuffer(t *testing.T) { for i := range data { data[i] = byte(i) } - buf := &ByteStreamBuffer{data: data} + buf := bytes.NewBuffer(data) packet, packetLength, ok := frag.MakeNew(0, allowed, buf) if !ok { t.Fatalf("expected ok=true, got false") @@ -142,7 +147,7 @@ func TestBasicFragmenter_BufferSmallerThanAllowedAndMSS(t *testing.T) { for i := range data { data[i] = byte(i) } - buf := &ByteStreamBuffer{data: data} + buf := bytes.NewBuffer(data) packet, packetLength, ok := frag.MakeNew(0, allowed, buf) if !ok { t.Fatalf("expected ok=true, got false") @@ -159,7 +164,7 @@ func TestBasicFragmenter_NegativeOrZeroAllowed(t *testing.T) { mss := 100 frag := NewBasicFragmenter(mss) data := make([]byte, 50) - buf := &ByteStreamBuffer{data: data} + buf := bytes.NewBuffer(data) packet, packetLength, ok := frag.MakeNew(0, 0, buf) if ok || packet != nil || packetLength != 0 { t.Errorf("expected no packet for allowed=0") @@ -174,7 +179,7 @@ func TestBasicFragmenter_ZeroMSS(t *testing.T) { mss := 0 frag := NewBasicFragmenter(mss) data := make([]byte, 50) - buf := &ByteStreamBuffer{data: data} + buf := bytes.NewBuffer(data) packet, packetLength, ok := frag.MakeNew(0, 100, buf) if ok || packet != nil || packetLength != 0 { t.Errorf("expected no packet for MSS=0") @@ -185,7 +190,7 @@ func TestBasicFragmenter_FlagsSet(t *testing.T) { mss := 100 frag := NewBasicFragmenter(mss) data := make([]byte, 50) - buf := &ByteStreamBuffer{data: data} + buf := bytes.NewBuffer(data) packet, _, ok := frag.MakeNew(0, 100, buf) if !ok || packet == nil { t.Fatalf("expected valid packet") diff --git a/mod/udp/rudp/integration_test.go b/mod/udp/rudp/integration_test.go index bc9555cfc..094f0b695 100644 --- a/mod/udp/rudp/integration_test.go +++ b/mod/udp/rudp/integration_test.go @@ -3,6 +3,7 @@ package rudp import ( "context" "net" + "sort" "testing" "time" @@ -162,3 +163,137 @@ func TestListenerDialHelloWorld(t *testing.T) { } } } + +// TestDiagFirst32Packets performs a focused diagnostic of the first ~32 data packets +// to inspect sequence alignment between sender and receiver. It writes exactly +// 32 * MSS bytes (fragmented into MSS-sized packets) and then logs the receiver's +// expected sequence, number of in-order bytes buffered, out-of-order queue size, +// and sample sequence gaps. This is a non-fatal diagnostic (will t.Skip on environments +// where it cannot bind or handshake cleanly). +func TestDiagFirst32Packets(t *testing.T) { + // Keep this test quick. + baseCtx := astral.NewContext(context.Background()) + + // Custom config to keep things deterministic and fast. + cfg := Config{ + MaxSegmentSize: DefaultMSS, // 1187 + MaxWindowPackets: 128, // allow >32 easily + AckDelay: time.Microsecond, // effectively immediate + RecvBufBytes: 1 << 20, + SendBufBytes: 1 << 20, + } + cfg.Normalize() + + l, err := Listen(baseCtx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, cfg, 2*time.Second) + if err != nil { + t.Skipf("listener setup failed (skip diag): %v", err) + } + defer l.Close() + + serverAddr := l.Addr().(*net.UDPAddr) + ipv4Dest := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverAddr.Port} + + acceptedCh := make(chan *Conn, 1) + go func() { + acceptCtx, cancel := baseCtx.WithTimeout(2 * time.Second) + defer cancel() + c, aerr := l.Accept(acceptCtx) + if aerr == nil { + acceptedCh <- c + } + }() + + udpConn, err := net.DialUDP("udp4", nil, ipv4Dest) + if err != nil { + udpConn, err = net.DialUDP("udp", nil, ipv4Dest) + } + if err != nil { + l.Close() + t.Skipf("dial failed (skip diag): %v", err) + } + defer udpConn.Close() + + localEP, _ := udpmod.ParseEndpoint(udpConn.LocalAddr().String()) + remoteEP, _ := udpmod.ParseEndpoint(udpConn.RemoteAddr().String()) + + outConn, err := NewConn(udpConn, localEP, remoteEP, cfg, true, nil, baseCtx) + if err != nil { + t.Skipf("NewConn outbound failed (skip diag): %v", err) + } + defer outConn.Close() + + hCtx, hCancel := baseCtx.WithTimeout(2 * time.Second) + if err := outConn.StartClientHandshake(hCtx); err != nil { + hCancel() + t.Skipf("handshake failed (skip diag): %v", err) + } + hCancel() + + var serverConn *Conn + select { + case serverConn = <-acceptedCh: + case <-time.After(2 * time.Second): + t.Skip("server accept timeout (skip diag)") + } + if serverConn == nil { + t.Skip("nil serverConn (skip diag)") + } + defer serverConn.Close() + + // Prepare exactly 32 * MSS bytes. + packets := 32 + bytesToSend := packets * cfg.MaxSegmentSize + payload := make([]byte, bytesToSend) + for i := range payload { + payload[i] = byte(i) + } + + written := 0 + start := time.Now() + for written < bytesToSend { + n, werr := outConn.Write(payload[written:]) + if werr != nil { + t.Fatalf("write error after %d bytes: %v", written, werr) + } + written += n + } + elapsedWrite := time.Since(start) + // Allow receiver a short window to process. + time.Sleep(50 * time.Millisecond) + + // Snapshot receiver internal state. + serverConn.recvMu.Lock() + expected := serverConn.expected + recvLen := serverConn.recvRB.Length() + oosz := len(serverConn.recvOO) + // Collect first few out-of-order keys + keys := make([]uint32, 0, oosz) + for k := range serverConn.recvOO { + keys = append(keys, k) + } + serverConn.recvMu.Unlock() + if len(keys) > 0 { + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + } + if len(keys) > 16 { + keys = keys[:16] + } + + // Log sender side nextSeq / acked + serverConn.sendMu.Lock() + acked := serverConn.ackedSeqNum + next := serverConn.nextSeqNum + serverConn.sendMu.Unlock() + + clientNext := outConn.nextSeqNum + clientAcked := outConn.ackedSeqNum + + t.Logf("diag: wrote=%d bytes (32*MSS=%d) writeElapsed=%v", written, bytesToSend, elapsedWrite) + t.Logf("diag: server expected=%d recvLen=%d recvOO=%d (firstOO=%v)", expected, recvLen, oosz, keys) + t.Logf("diag: server acked=%d next=%d | client acked=%d next=%d", acked, next, clientAcked, clientNext) + + // Basic assertion: at least one in-order advancement OR explicit log for gap. + if recvLen == 0 { + t.Logf("diag: WARNING no in-order data buffered; likely initial gap (expected not reached)") + } +} diff --git a/mod/udp/rudp/retransmissions.go b/mod/udp/rudp/retransmissions.go index 4b3c76952..043a55797 100644 --- a/mod/udp/rudp/retransmissions.go +++ b/mod/udp/rudp/retransmissions.go @@ -7,6 +7,14 @@ import ( "github.com/cryptopunkscc/astrald/mod/udp" ) +type Unacked struct { + pkt *Packet // Packet metadata (seq, len) + sentTime time.Time // Last sent time + rtxCount int // Retransmit count + length int // Payload length + isHandshake bool // True if this entry is for a handshake control packet +} + // startRtxTimer arms the retransmission timer if not already running func (c *Conn) startRtxTimer() { c.sendMu.Lock() From 0f0ad80492273a18d5a8f2d64a469c3d7b1d5ea5 Mon Sep 17 00:00:00 2001 From: Rekseto Date: Thu, 2 Oct 2025 17:45:51 +0200 Subject: [PATCH 12/13] cmd:/astrald: register udp module --- cmd/astrald/mods.go | 1 + mod/udp/conn.go | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) delete mode 100644 mod/udp/conn.go diff --git a/cmd/astrald/mods.go b/cmd/astrald/mods.go index a799e2c24..ec2ece4e4 100644 --- a/cmd/astrald/mods.go +++ b/cmd/astrald/mods.go @@ -22,5 +22,6 @@ import ( _ "github.com/cryptopunkscc/astrald/mod/shell/src" _ "github.com/cryptopunkscc/astrald/mod/tcp/src" _ "github.com/cryptopunkscc/astrald/mod/tor/src" + _ "github.com/cryptopunkscc/astrald/mod/udp/src" _ "github.com/cryptopunkscc/astrald/mod/user/src" ) diff --git a/mod/udp/conn.go b/mod/udp/conn.go deleted file mode 100644 index df1150adc..000000000 --- a/mod/udp/conn.go +++ /dev/null @@ -1,4 +0,0 @@ -package udp - -type ReliableUdpConn interface { -} From d56d3c4871698d437d4276b3263a3713a4386e04 Mon Sep 17 00:00:00 2001 From: Rekseto Date: Fri, 3 Oct 2025 14:42:26 +0200 Subject: [PATCH 13/13] mod/udp: update ObjectType to reflect correct module for IP address --- mod/udp/ip.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mod/udp/ip.go b/mod/udp/ip.go index aa680a0ae..bb45f9671 100644 --- a/mod/udp/ip.go +++ b/mod/udp/ip.go @@ -19,7 +19,7 @@ func ParseIP(s string) (IP, error) { // astral func (IP) ObjectType() string { - return "mod.tcp.ip_address" + return "mod.udp.ip_address" } func (ip IP) WriteTo(w io.Writer) (n int64, err error) {