diff --git a/cmd/astrald/mods.go b/cmd/astrald/mods.go index a799e2c2..ec2ece4e 100644 --- a/cmd/astrald/mods.go +++ b/cmd/astrald/mods.go @@ -22,5 +22,6 @@ import ( _ "github.com/cryptopunkscc/astrald/mod/shell/src" _ "github.com/cryptopunkscc/astrald/mod/tcp/src" _ "github.com/cryptopunkscc/astrald/mod/tor/src" + _ "github.com/cryptopunkscc/astrald/mod/udp/src" _ "github.com/cryptopunkscc/astrald/mod/user/src" ) diff --git a/go.mod b/go.mod index ca8c98bc..1e202657 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904 // indirect golang.org/x/sys v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.24.1 // indirect diff --git a/go.sum b/go.sum index 870a14f4..313787fd 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904 h1:OoG1xZV7CXnP2/Udl1ybEgTEds9XXA3NHWg+OR3c/a8= +github.com/smallnest/ringbuffer v0.0.0-20250317021400-0da97b586904/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wailsapp/mimetype v1.4.1 h1:pQN9ycO7uo4vsUUuPeHEYoUkLVkaRntMnHJxVwYhwHs= diff --git a/mod/nodes/src/op_streams.go b/mod/nodes/src/op_streams.go index 88172e68..2fec34e2 100644 --- a/mod/nodes/src/op_streams.go +++ b/mod/nodes/src/op_streams.go @@ -1,10 +1,11 @@ package nodes import ( + "slices" + "github.com/cryptopunkscc/astrald/astral" "github.com/cryptopunkscc/astrald/mod/nodes" "github.com/cryptopunkscc/astrald/mod/shell" - "slices" ) type opStreamsArgs struct { diff --git a/mod/udp/README.md b/mod/udp/README.md new file mode 100644 index 00000000..2bfda88a --- /dev/null +++ b/mod/udp/README.md @@ -0,0 +1,28 @@ +# UDP Module + +This module implements UDP-based communication for the Astrald platform, enabling fast, connectionless data transfer between nodes. + +## Current Proof of Concept (PoC) State +- Basic UDP packet send/receive functionality +- Simple endpoint resolution +- Fragmentation and reassembly of large packets +- Minimal handshake and connection logic +- Configuration via `config.go` + +## Key Components +- **Dial:** Initiates outbound UDP connections. Integrates with Astral's context (`astral.Context`), endpoint abstraction (`exonet.Endpoint`), and parses endpoints using UDP utilities. Returns a reliable connection for Astral's exonet system. Similar to TCP's Dial, but uses UDP-specific logic and interacts with Astral's node and identity systems. +- **ResolveEndpoints:** Resolves available UDP endpoints for a node. Uses Astral's identity system (`astral.Identity`) to verify node identity and returns endpoints as `exonet.Endpoint` via Astral's signal utilities (`sig.ArrayToChan`). Connects with Astral's node and endpoint management. +- **Loader:** Loads the UDP module into Astral. Connects with Astral's node (`astral.Node`), asset management (`core/assets`), and logging (`astral/log.Logger`). Loads configuration from assets, parses public endpoints, and registers the module with Astral's core module system for lifecycle management. +- **Unpack:** Handles packet reassembly for Astral's exonet system. Uses UDP endpoint parsing and error handling, returning endpoints for Astral's network abstraction. Connects with Astral's error and endpoint utilities. +- **Server:** Listens for incoming UDP connections. Uses Astral's context and logging, manages connections and server lifecycle, and integrates with Astral's configuration and endpoint parsing. Registers endpoints and connections with Astral's node and router systems. +- +## Possible Improvements for Production Readiness +- Advanced flow control mechanisms (dynamic window sizing, congestion control) +- Fast retransmissions and selective acknowledgments (SACK) +- Adaptive retransmission timers (RTT estimation) +- Performance optimizations (buffering, batching) +- Support for NAT traversal and hole punching +- Comprehensive test coverage +- Documentation and usage examples + +This module is currently a proof of concept and not recommended for production use without further development. diff --git a/mod/udp/endpoint.go b/mod/udp/endpoint.go new file mode 100644 index 00000000..d8629cb2 --- /dev/null +++ b/mod/udp/endpoint.go @@ -0,0 +1,154 @@ +package udp + +import ( + "bytes" + "errors" + "io" + "net" + "strconv" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/term" + "github.com/cryptopunkscc/astrald/mod/exonet" +) + +var _ exonet.Endpoint = &Endpoint{} +var _ astral.Object = &Endpoint{} + +// NOTE: this is same for UDP and TCP - consider moving to a common package + +// Endpoint is an astral.Object that holds information about a UDP endpoint, +// i.e. an IP address and a port. +// Supports JSON and text. +type Endpoint struct { + IP IP + Port astral.Uint16 +} + +func (e *Endpoint) ObjectType() string { + return "mod.udp.endpoint" +} + +func (e Endpoint) WriteTo(w io.Writer) (n int64, err error) { + return astral.Struct(e).WriteTo(w) +} + +func (e *Endpoint) ReadFrom(r io.Reader) (n int64, err error) { + return astral.Struct(e).ReadFrom(r) +} + +// exonet.Endpoint + +func (e *Endpoint) Address() string { + return net.JoinHostPort(e.IP.String(), strconv.Itoa(int(e.Port))) +} + +func (e *Endpoint) Network() string { + return "udp" +} + +// HostString returns the IP address as a string +func (e *Endpoint) HostString() string { + return e.IP.String() +} + +// PortNumber returns the port number as an int +func (e *Endpoint) PortNumber() int { + return int(e.Port) +} + +func (e *Endpoint) Pack() []byte { + var b = &bytes.Buffer{} + if _, err := e.WriteTo(b); err != nil { + return nil + } + return b.Bytes() +} + +// Text marshaling + +func (e Endpoint) MarshalText() (text []byte, err error) { + return []byte(e.Address()), nil +} + +func (e *Endpoint) UnmarshalText(text []byte) error { + h, p, err := net.SplitHostPort(string(text)) + if err != nil { + return err + } + + ip, err := ParseIP(h) + if err != nil { + return err + } + + port, err := strconv.Atoi(p) + if err != nil { + return err + } + + // check if port fits in 16 bits + if (port >> 16) > 0 { + return errors.New("port out of range") + } + + e.IP = ip + e.Port = astral.Uint16(port) + + return nil +} + +// ... + +func (e *Endpoint) String() string { + return e.Address() +} + +func (e *Endpoint) IsZero() bool { + return e == nil || e.IP == nil +} + +func ParseEndpoint(s string) (*Endpoint, error) { + hostStr, portStr, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + + ip, err := ParseIP(hostStr) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + + // check if port fits in 16 bits + if (port >> 16) > 0 { + return nil, errors.New("port out of range") + } + + return &Endpoint{ + IP: ip, + Port: astral.Uint16(port), + }, nil +} + +func (e *Endpoint) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{ + IP: net.ParseIP(e.IP.String()), + Port: int(e.Port), + } +} + +func init() { + _ = astral.DefaultBlueprints.Add(&Endpoint{}) + + term.SetTranslateFunc(func(o *Endpoint) astral.Object { + return &term.ColorString{ + Color: term.HighlightColor, + Text: astral.String32(o.String()), + } + }) +} diff --git a/mod/udp/errors.go b/mod/udp/errors.go new file mode 100644 index 00000000..c64bc159 --- /dev/null +++ b/mod/udp/errors.go @@ -0,0 +1,19 @@ +package udp + +import "errors" + +var ( + ErrListenerClosed = errors.New("listener closed") + ErrRetransmissionLimitExceeded = errors.New( + "retransmissions limit exceeded") + + // ErrDataLost is emitted on close if there was still buffered or unacked data. + ErrDataLost = errors.New("unsent data lost on close") + ErrPacketTooShort = errors.New("packet too short") + ErrConnClosed = errors.New("connection closed") + ErrInvalidPayloadLength = errors.New("invalid payload length") + ErrZeroMSS = errors.New("invalid MSS") + ErrMalformedPacket = errors.New("malformed packet") + ErrHandshakeTimeout = errors.New("handshake timeout") + ErrConnectionNotEstablished = errors.New("connection not established") +) diff --git a/mod/udp/ip.go b/mod/udp/ip.go new file mode 100644 index 00000000..bb45f967 --- /dev/null +++ b/mod/udp/ip.go @@ -0,0 +1,97 @@ +package udp + +import ( + "encoding/json" + "errors" + "io" + "net" + + "github.com/cryptopunkscc/astrald/astral" +) + +// NOTE: this is same for tcp / udp modules, consider moving to a common package +type IP net.IP + +func ParseIP(s string) (IP, error) { + return IP(net.ParseIP(s)), nil +} + +// astral + +func (IP) ObjectType() string { + return "mod.udp.ip_address" +} + +func (ip IP) WriteTo(w io.Writer) (n int64, err error) { + if ip.IsIPv4() { + return astral.Bytes8(net.IP(ip).To4()).WriteTo(w) + } + + return astral.Bytes8(ip).WriteTo(w) +} + +func (ip *IP) ReadFrom(r io.Reader) (n int64, err error) { + return (*astral.Bytes8)(ip).ReadFrom(r) +} + +// json + +func (ip IP) MarshalJSON() ([]byte, error) { + return json.Marshal(ip.String()) +} + +func (ip *IP) UnmarshalJSON(b []byte) error { + var str string + err := json.Unmarshal(b, &str) + if err != nil { + return nil + } + + parsed := IP(net.ParseIP(str)) + + if parsed == nil { + return errors.New("invalid IP") + } + + *ip = parsed + return nil +} + +// text + +func (ip IP) MarshalText() (text []byte, err error) { + return []byte(ip.String()), nil +} + +func (ip *IP) UnmarshalText(text []byte) error { + parsed := IP(net.ParseIP(string(text))) + if parsed == nil { + return errors.New("invalid IP") + } + *ip = parsed + return nil +} + +// ... + +func (ip IP) IsIPv4() bool { + return net.IP(ip).To4() != nil +} + +func (ip IP) IsIPv6() bool { + return net.IP(ip).To16() != nil +} + +func (ip IP) IsLoopback() bool { return net.IP(ip).IsLoopback() } + +func (ip IP) IsGlobalUnicast() bool { return net.IP(ip).IsGlobalUnicast() } + +func (ip IP) IsPrivate() bool { return net.IP(ip).IsPrivate() } + +func (ip IP) String() string { + return net.IP(ip).String() +} + +func init() { + _ = astral.DefaultBlueprints.Add(&IP{}) +} diff --git a/mod/udp/module.go b/mod/udp/module.go new file mode 100644 index 00000000..a646e0d6 --- /dev/null +++ b/mod/udp/module.go @@ -0,0 +1,14 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/mod/exonet" +) + +const ModuleName = "udp" + +type Module interface { + exonet.Dialer + exonet.Unpacker + exonet.Parser + ListenPort() int +} diff --git a/mod/udp/rudp/README.md b/mod/udp/rudp/README.md new file mode 100644 index 00000000..e5d552e8 --- /dev/null +++ b/mod/udp/rudp/README.md @@ -0,0 +1,202 @@ +# RUDP (Proof of Concept) + +Reliable UDP with minimal mechanisms: fixed-interval retransmissions, cumulative ACKs, simple packet window, basic 3‑way handshake. + +## Current Status Summary + +Works today: +- 3‑way handshake (SYN → SYN|ACK → ACK) with random non‑zero ISNs. +- Cumulative reliability (no SACK) using single connection-level retransmission timer. +- Fixed RTO (no RTT sampling, no backoff) + global resend of all unacked packets each tick. +- Flow control: sender-side packet count window (`MaxWindowPackets` considered conversion for MaxWindowBytes for better performance), MSS-bounded fragments. +- Fragmentation: simple slice of up to MSS bytes per packet (no coalescing / segmentation logic beyond cap). +- Delayed ACK: single configurable `AckDelay` timer (pure ACK or piggyback on outgoing data). +- Out-of-order receive buffering with gap fill & drain (`recvOO` map). +- Ring buffers: send (bytes) + recv (bytes). Blocking `Write` until space. +- Error surfacing: `ErrChan` (retransmission limit, data loss on close, marshal errors). +- Listener accepts only established connections; server handshake runs per inbound SYN. + +Not implemented / not present: +- FIN/close protocol (flags exist; logic is TODO / unused). +- Half-close, TIME-WAIT, RST, keepalive/ping. +- Congestion control, pacing, rate limiting. +- Adaptive RTO (no RTT measurement), exponential backoff, per-packet timers. +- Selective ACK, fast retransmit, loss detection heuristics beyond timeout. +- Use of advertised receive window (`Win` field) by sender (currently ignored). +- MSS/Path MTU discovery (fixed MSS from config only). +- Byte-based dynamic window / BDP autotuning. +- Checksum beyond UDP (no payload integrity apart from lower layer). +- Security (no encryption, authentication, replay protection). +- Stats/metrics/telemetry surfaces. +- Config field `SendBufBytes` (unused; send buffer size derived from window*MSS*2). +- Backpressure signaling aside from blocking `Write`. + +Platform assumptions: standard Go net.UDP; loopback benchmark only; no OS autotuning hooks. + +## Architecture Overview + +Data path (outbound): +Application → `Conn.Write` → send ring buffer → `senderLoop` → fragment (≤ MSS) → packet build (Seq/Ack/Flags) → UDP send → network. + +Receive path: UDP recv → `recvLoop` → parse → (handshake vs established) → data: ACK handling + reordering + ring write → `Conn.Read` → Application. + +ACK path: receiver marks `ackPending`; delayed (timer) or immediate; pure ACK or piggyback in data packet (`FlagACK`). + +Mermaid sequence (simplified steady-state): +```mermaid +sequenceDiagram + participant AppW as App (Writer) + participant Conn as Conn + participant Snd as senderLoop + participant UDP as UDP Socket + participant Rcv as recvLoop + participant AppR as App (Reader) + + AppW->>Conn: Write(bytes) + Conn->>Conn: Enqueue bytes (sendRB) + Conn->>Snd: Signal cond + loop While data & window + Snd->>Snd: Fragment ≤ MSS + Snd->>Snd: Build Packet(Seq,Ack,ACK) + Snd->>UDP: sendDatagram() + end + UDP-->>Rcv: Datagram + Rcv->>Rcv: Unmarshal Packet + Rcv->>Conn: HandleAck (if piggyback) + Rcv->>Conn: handleDataPacket + Conn->>Conn: Reorder / drain recvOO + Conn->>Conn: queueAckLocked() + Conn->>UDP: (maybe) Pure ACK (delayed) + AppR->>Conn: Read() + Conn-->>AppR: Bytes +``` + +## Reliability Mechanisms (Existing) +- ACKs: Cumulative only. Data packets always carry `Ack`. Pure ACK packets (`FlagACK`, `Len=0`) sent on delay timer or immediately (duplicates, out-of-order, buffer full). +- Retransmission: Single fixed-interval connection-level timer; every tick retransmits ALL unacked packets (handshake + data). No RTT measurement, no per-packet timers, no exponential backoff. Limit enforced by `RetransmissionLimit` per packet; exceeding closes connection and emits error. +- Window: Sender-side count of in-flight packets (`len(unacked) < MaxWindowPackets`). Receiver advertised `Win` (bytes) set only on pure ACK; unused by sender. +- Reordering: Out-of-order packets stored in `recvOO` keyed by starting `Seq`; drained when gap closes. No explicit cap beyond recv buffer space. +- Fragmentation / MSS: Each packet contains up to `MaxSegmentSize` bytes; no multi-packet segmentation logic aside from slicing to MSS. +- Checksum/validation: Only length/header sanity checks; no additional checksum. +- Pacing / rate limiting: Not implemented. +- Close semantics: `Close()` does NOT send FIN; it halts timers, signals waiters, may emit `ErrDataLost` if unsent or unacked data remained; no half-close. + +## Configuration & Defaults +(From `config.go` after `Normalize()`) +- `MaxSegmentSize` (MSS): `DefaultMSS` (1200 - 13 = 1187 bytes payload). +- `MaxWindowPackets`: 1024. +- `RetransmissionInterval` (RTO fixed): 200ms. +- `MaxRetransmissionInterval`: 4s (NOT USED currently). +- `RetransmissionLimit`: 8 (per packet, including handshake packets). +- `AckDelay`: 5ms (0 => immediate ACK). +- `RecvBufBytes`: 16 MiB. +- `SendBufBytes`: 16 MiB (NOT USED; actual send buffer = window*MSS*2). + +## Wire Format (As Implemented) +Header (13 bytes): +- `Seq` (4): first byte sequence number of this segment. +- `Ack` (4): cumulative ACK (all bytes < Ack received). +- `Flags` (1): bitmask (SYN=1, ACK=2, FIN=4). +- `Win` (2): advertised free receive window in BYTES (only set on pure ACK; 0 on data packets currently). +- `Len` (2): payload length. +- `Payload` (Len bytes). + +Usage: +- Data: `Len>0`, always sets `FlagACK`, `Win`=0. +- Pure ACK: `Len=0`, `FlagACK`, sets `Win` to clamped free bytes (≤ 0xFFFF). +- Handshake: SYN (no payload), SYN|ACK, final ACK. +- FIN: Flag defined but not produced/consumed. + +## State & Timers (As Implemented) +States defined: `Closed, Listen, SynSent, SynReceived, Established, FinSent, FinReceived, TimeWait`. +Active transitions exercised: Closed → SynSent → Established (client); Closed → SynReceived → Established (server). FIN-related states unused. +Timers: +- Retransmission timer: single `time.AfterFunc` repeating at fixed `RetransmissionInterval` while any unacked packets exist. +- Delayed ACK timer: per-connection `ackTimer` (fires once per pending ACK batch) with delay `AckDelay`. +Cancellation: +- Retransmission timer stopped when `unacked` empty or connection closed. +- ACK timer canceled (stop) on close or when ACK sent. +No other timers (no keepalive, no linger/TIME-WAIT). + +## Buffers & Memory +- Send ring: size = `MaxWindowPackets * MaxSegmentSize * 2` bytes; holds queued outbound data not yet packetized. Bytes removed upon packetization (not upon ACK); retransmissions use stored packet copies in `unacked` map. +- Receive ring: size = `RecvBufBytes`; holds in-order assembled payload bytes for `Read`. +- Out-of-order map: one entry per future gap start (key = packet.Seq). Evicted on drainage. +- Packet copies: each in-flight packet (data + handshake) retained until cumulatively ACKed; payload slice owned exclusively by `unacked` entry. +- Zero-copy: Not present (copies on UDP read into new slice, into recv ring, then into caller buffer on `Read`). +- GC considerations: One allocation per outbound packet payload; per in-order + out-of-order inbound packet payload; ring buffers reuse internal slabs. + +## Benchmarks & Expected Localhost Behavior +Benchmark: `BenchmarkRUDPTransfer10MiB` (one 10 MiB unidirectional transfer client→server; handshake excluded from timed portion). +Running: +```bash +go test -run ^$ -bench BenchmarkRUDPTransfer10MiB -benchtime=1x ./mod/udp/rudp +RUDP_BENCH_PROGRESS=1 go test -run ^$ -bench BenchmarkRUDPTransfer10MiB -benchtime=1x ./mod/udp/rudp +``` +Reports: +- MiB/s throughput (write phase only). +- Allocations, bytes per op. +Typical throughput: Not documented in code (dependent on fixed window * MSS / RTO / ACK delay and host performance). + +## Missing Pieces +- No FIN exchange / graceful close / half-close. +- No congestion control or pacing. +- No RTT measurement / adaptive RTO / backoff (fixed interval only). +- `MaxRetransmissionInterval` unused. +- No selective ACK (SACK) / fast retransmit / loss detection beyond timeout. +- No use of advertised `Win` (flow control purely packet-count based). +- No path MTU discovery / dynamic MSS. +- No checksum beyond UDP; silent data corruption undetected at transport layer. +- No encryption / authentication. +- No keepalive / idle timeout / ping. +- No connection IDs beyond initial Seq reuse; no NAT rebinding handling beyond IP match. +- `SendBufBytes` config unused. +- Backpressure only via blocking writes; no error signaling before block. +- +## Future Improvements +### Near-term Safety & Performance +- Adaptive byte-based window & BDP autotune. +- RTT sampling + RFC 6298 RTO calculation + exponential backoff. +- Packet pacing / rate shaping. +- Time-threshold (fast) loss detection + limited fast retransmit. +- Adaptive ACK delay (heuristics based on flight / reordering). + +### Protocol Completeness +- Proper FIN handshake, half-close, and TIME-WAIT handling. +- Explicit connection IDs distinct from ISNs. +- RST/abort semantics. +- Keepalive / ping frames. + +### Robustness & Reliability +- Path MTU discovery (PLPMTUD) + MSS adjustment. +- Selective ACK (SACK) & out-of-order range tracking. +- Reordering window limits / discard policies. +- Detailed stats / telemetry (RTT, loss, retransmits, goodput). +- Error surfaces for flow control violations / protocol anomalies. + +### Tooling & Testing +- Loss / delay / duplication chaos test harness. +- Cross-platform CI with network emulation. +- Long-running soak / stress tests (memory & stability). +- Benchmark variants (bidirectional, varying window/MSS). + +## Benchmarks (uTP Comparison) + +| Implementation | Iterations | ns/op (mean) | Throughput (MB/s) | B/op | allocs/op | +|---|---:|---:|---:|---:|---:| +| **rudp** (`github.com/cryptopunkscc/astrald/mod/udp/rudp`) | 48 | 71,213,610 | 147.24 | 23,393,100 | 94,552 | +| **uTP** (`github.com/anacrolix/utp`) | 48 | 24,602,392 | 426.21 | 82,243,573 | 93,825 | + +### Notes + +Numbers come from the respective benchmark outputs on loopback (localhost), 10 MiB payload; handshake/setup excluded from the timed region. + +uTP shows ~3× higher throughput on this workload. Allocation counts are comparable; rudp allocates fewer bytes/op but runs slower overall. + +This rudp is a proof of concept and has not been profiler-guided or optimized yet. + +### Potential Improvements to Close the Gap + +Adding fast retransmits / time-threshold loss detection and switching from a packet-count window (MaxWindowPackets) to a byte-based window (MaxWindowBytes) with BDP-aware autotuning will most likely bring performance closer to uTP on high-throughput paths. + +Additional expected wins: adaptive ACK delay (immediate on reordering/gap), pacing, and better RTO sizing \ No newline at end of file diff --git a/mod/udp/rudp/bench_test.go b/mod/udp/rudp/bench_test.go new file mode 100644 index 00000000..ae4e7431 --- /dev/null +++ b/mod/udp/rudp/bench_test.go @@ -0,0 +1,199 @@ +package rudp_test + +import ( + "context" + "net" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/cryptopunkscc/astrald/astral" + udpmod "github.com/cryptopunkscc/astrald/mod/udp" + rudp "github.com/cryptopunkscc/astrald/mod/udp/rudp" +) + +var benchDebug = os.Getenv("RUDP_BENCH_DEBUG") != "" +var benchProgress = os.Getenv("RUDP_BENCH_PROGRESS") != "" + +// BenchmarkRUDPTransfer10MiB performs a single unidirectional transfer of ~10 MiB +// from client -> server over a single RUDP connection on loopback. Handshake and +// setup are excluded from the timed section. The benchmark runs only one data +// transfer regardless of b.N (subsequent iterations return early) to bound total +// execution time while still reporting throughput via b.SetBytes. +func BenchmarkRUDPTransfer10MiB(b *testing.B) { + if testing.Short() { + b.Skip("skip in short mode") + } + + const totalBytes = 10 * 1024 * 1024 // 10 MiB per iteration + baseChunkSize := 64 * 1024 + b.SetBytes(totalBytes) + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Stop the timer during setup/teardown so only the write phase is measured. + b.StopTimer() + + b.Logf("[iter %d] setup starting", i) + baseCtx := astral.NewContext(context.Background()) + + cfg := rudp.Config{} + cfg.Normalize() // simplest model: use normalized defaults + + l, err := rudp.Listen(baseCtx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, cfg, 2*time.Second) + if err != nil { + b.Fatalf("listen: %v", err) + } + b.Logf("[iter %d] listener=%v", i, l.Addr()) + + serverAddr := l.Addr().(*net.UDPAddr) + ipv4Dest := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverAddr.Port} + + acceptedCh := make(chan *rudp.Conn, 1) + acceptErrCh := make(chan error, 1) + go func() { + acceptCtx, cancel := baseCtx.WithTimeout(3 * time.Second) + defer cancel() + c, aerr := l.Accept(acceptCtx) + if aerr != nil { + acceptErrCh <- aerr + return + } + acceptedCh <- c + }() + + udpConn, err := net.DialUDP("udp4", nil, ipv4Dest) + if err != nil { + udpConn, err = net.DialUDP("udp", nil, ipv4Dest) + } + if err != nil { + l.Close() + b.Fatalf("dial: %v", err) + } + b.Logf("[iter %d] client local=%v remote=%v", i, udpConn.LocalAddr(), udpConn.RemoteAddr()) + + localEP, _ := udpmod.ParseEndpoint(udpConn.LocalAddr().String()) + remoteEP, _ := udpmod.ParseEndpoint(udpConn.RemoteAddr().String()) + outConn, err := rudp.NewConn(udpConn, localEP, remoteEP, cfg, true, nil, baseCtx) + if err != nil { + udpConn.Close() + l.Close() + b.Fatalf("newconn: %v", err) + } + + hCtx, hCancel := baseCtx.WithTimeout(2 * time.Second) + if err := outConn.StartClientHandshake(hCtx); err != nil { + hCancel() + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("handshake: %v", err) + } + hCancel() + b.Logf("[iter %d] handshake complete local=%v remote=%v", i, udpConn.LocalAddr(), udpConn.RemoteAddr()) + + var serverConn *rudp.Conn + select { + case serverConn = <-acceptedCh: + case aerr := <-acceptErrCh: + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("accept: %v", aerr) + case <-time.After(3 * time.Second): + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("accept timeout") + } + if serverConn == nil { + outConn.Close() + udpConn.Close() + l.Close() + b.Fatalf("nil server conn") + } + b.Logf("[iter %d] server accepted remote=%v", i, serverConn.RemoteEndpoint()) + + // Determine a safe chunk size that will not block on first write: + // send ring buffer size = MaxWindowBytes * 2. Choose <= MaxWindowBytes /2 for margin. + safeChunk := cfg.MaxWindowPackets * cfg.MaxSegmentSize / 2 + if safeChunk < cfg.MaxSegmentSize { + safeChunk = cfg.MaxSegmentSize + } + if baseChunkSize > safeChunk { + if benchDebug { + b.Logf("[iter %d] adjusting chunk size from %d to %d ("+ + "MaxWindowBytes=%d)", i, baseChunkSize, safeChunk, cfg.MaxWindowPackets) + } + } + chunkSize := safeChunk + + // Prepare transfer + chunk := make([]byte, chunkSize) + for j := range chunk { + chunk[j] = byte(j) + } + var received int64 + readDone := make(chan struct{}) + var readErr atomic.Value + go func() { + buf := make([]byte, 128*1024) + for atomic.LoadInt64(&received) < int64(totalBytes) { + n, err := serverConn.Read(buf) + if err != nil { + readErr.Store(err) + break + } + if n > 0 { + atomic.AddInt64(&received, int64(n)) + } + } + close(readDone) + }() + + b.Logf("[iter %d] starting timed transfer", i) + writeStart := time.Now() + // Start the timer only for the write phase (do not reset cumulative timer) + b.StartTimer() + sent := 0 + nextMilestone := 1 * 1024 * 1024 + for sent < totalBytes { + toWrite := chunkSize + if rem := totalBytes - sent; rem < toWrite { + toWrite = rem + } + if _, err := outConn.Write(chunk[:toWrite]); err != nil { + b.Fatalf("write err after %d bytes: %v", sent, err) + } + sent += toWrite + if benchProgress && sent >= nextMilestone { + b.Logf("[iter %d] progress sent=%d / %d (%.2f%%)", i, sent, totalBytes, 100*float64(sent)/float64(totalBytes)) + nextMilestone += 1 * 1024 * 1024 + } + } + b.StopTimer() + elapsedWrite := time.Since(writeStart) + + select { + case <-readDone: + case <-time.After(5 * time.Second): + b.Fatalf("read timeout received=%d", atomic.LoadInt64(&received)) + } + if v := readErr.Load(); v != nil { + b.Fatalf("server read error: %v", v.(error)) + } + if got := atomic.LoadInt64(&received); got != int64(totalBytes) { + b.Fatalf("recv mismatch got=%d want=%d", got, totalBytes) + } + + mbps := (float64(totalBytes) / (1024 * 1024)) / elapsedWrite.Seconds() + b.Logf("[iter %d] transfer complete bytes=%d duration=%v throughput=%.2f MiB/s", i, totalBytes, elapsedWrite, mbps) + + serverConn.Close() + outConn.Close() + udpConn.Close() + l.Close() + b.Logf("[iter %d] cleanup done", i) + } +} diff --git a/mod/udp/rudp/config.go b/mod/udp/rudp/config.go new file mode 100644 index 00000000..42e6931f --- /dev/null +++ b/mod/udp/rudp/config.go @@ -0,0 +1,59 @@ +package rudp + +import "time" + +// Transport default constants (exported for visibility in tests & integration) +const ( + DefaultMSS = 1200 - 13 // 1187 (1200 minus header) + DefaultWndPkts = 1024 + DefaultRTO = 200 * time.Millisecond + DefaultRTOMax = 4 * time.Second + DefaultRetries = 8 + DefaultAckDelay = 5 * time.Millisecond + DefaultRecvBufBytes = 16 << 20 + DefaultSendBufBytes = 16 << 20 +) + +// Config holds reliability / buffering parameters for the rudp transport. +type Config struct { + MaxSegmentSize int `yaml:"max_segment_size"` + MaxWindowPackets int `yaml:"max_window_packets"` + RetransmissionInterval time.Duration `yaml:"retransmission_interval"` + MaxRetransmissionInterval time.Duration `yaml:"max_retransmission_interval"` + RetransmissionLimit int `yaml:"retransmission_limit"` + AckDelay time.Duration `yaml:"ack_delay"` + RecvBufBytes int `yaml:"recv_buf_bytes"` + SendBufBytes int `yaml:"send_buf_bytes"` +} + +// Normalize applies defaults to zero-value fields (no clamping beyond basic sanity here). +func (c *Config) Normalize() { + if c.MaxSegmentSize == 0 { + c.MaxSegmentSize = DefaultMSS + } + if c.MaxWindowPackets == 0 { + c.MaxWindowPackets = DefaultWndPkts + } + if c.RetransmissionInterval == 0 { + c.RetransmissionInterval = DefaultRTO + } + if c.MaxRetransmissionInterval == 0 { + c.MaxRetransmissionInterval = DefaultRTOMax + } + if c.RetransmissionLimit == 0 { + c.RetransmissionLimit = DefaultRetries + } + if c.AckDelay == 0 { + c.AckDelay = DefaultAckDelay + } + if c.RecvBufBytes == 0 { + c.RecvBufBytes = DefaultRecvBufBytes + } + if c.SendBufBytes == 0 { + c.SendBufBytes = DefaultSendBufBytes + } +} + +// - AckDelay: mirrors QUIC MAX_ACK_DELAY (RFC 9000 §13.2.1). +// - Buffer sizes: 1 MiB default, capped for safety, must be >= aggregate window. +// - Aggregate window bytes can be derived as MaxWindowPackets * MaxSegmentSize. diff --git a/mod/udp/rudp/config_test.go b/mod/udp/rudp/config_test.go new file mode 100644 index 00000000..5e0b415d --- /dev/null +++ b/mod/udp/rudp/config_test.go @@ -0,0 +1,194 @@ +package rudp + +import ( + "testing" + "time" +) + +// TestNormalizeAppliesDefaults verifies that a zero-value Config gets all default values. +func TestNormalizeAppliesDefaults(t *testing.T) { + var c Config + c.Normalize() + + if c.MaxSegmentSize != DefaultMSS { + f(t, "MaxSegmentSize", c.MaxSegmentSize, DefaultMSS) + } + + if c.MaxWindowPackets != DefaultWndPkts { + f(t, "MaxWindowPackets", c.MaxWindowPackets, DefaultWndPkts) + } + if c.RetransmissionInterval != DefaultRTO { + f(t, "RetransmissionInterval", c.RetransmissionInterval, DefaultRTO) + } + if c.MaxRetransmissionInterval != DefaultRTOMax { + f(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, DefaultRTOMax) + } + if c.RetransmissionLimit != DefaultRetries { + f(t, "RetransmissionLimit", c.RetransmissionLimit, DefaultRetries) + } + if c.AckDelay != DefaultAckDelay { + f(t, "AckDelay", c.AckDelay, DefaultAckDelay) + } + if c.RecvBufBytes != DefaultRecvBufBytes { + f(t, "RecvBufBytes", c.RecvBufBytes, DefaultRecvBufBytes) + } + if c.SendBufBytes != DefaultSendBufBytes { + f(t, "SendBufBytes", c.SendBufBytes, DefaultSendBufBytes) + } +} + +// TestNormalizePreservesNonZero verifies that non-zero fields are not overwritten. +func TestNormalizePreservesNonZero(t *testing.T) { + orig := Config{ + MaxSegmentSize: 999, + MaxWindowPackets: 77, + RetransmissionInterval: 321 * time.Millisecond, + MaxRetransmissionInterval: 987 * time.Millisecond, + RetransmissionLimit: 42, + AckDelay: 11 * time.Millisecond, + RecvBufBytes: 222222, + SendBufBytes: 333333, + } + c := orig + c.Normalize() + + if c != orig { + // Compare field-by-field for clearer diagnostics. + if c.MaxSegmentSize != orig.MaxSegmentSize { + g(t, "MaxSegmentSize", c.MaxSegmentSize, orig.MaxSegmentSize) + } + + if c.MaxWindowPackets != orig.MaxWindowPackets { + g(t, "MaxWindowPackets", c.MaxWindowPackets, orig.MaxWindowPackets) + } + if c.RetransmissionInterval != orig.RetransmissionInterval { + g(t, "RetransmissionInterval", c.RetransmissionInterval, orig.RetransmissionInterval) + } + if c.MaxRetransmissionInterval != orig.MaxRetransmissionInterval { + g(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, orig.MaxRetransmissionInterval) + } + if c.RetransmissionLimit != orig.RetransmissionLimit { + g(t, "RetransmissionLimit", c.RetransmissionLimit, orig.RetransmissionLimit) + } + if c.AckDelay != orig.AckDelay { + g(t, "AckDelay", c.AckDelay, orig.AckDelay) + } + if c.RecvBufBytes != orig.RecvBufBytes { + g(t, "RecvBufBytes", c.RecvBufBytes, orig.RecvBufBytes) + } + if c.SendBufBytes != orig.SendBufBytes { + g(t, "SendBufBytes", c.SendBufBytes, orig.SendBufBytes) + } + // Fail after reporting discrepancies. + if t.Failed() { + return + } + } +} + +// TestNormalizePartial ensures only zero fields get populated. +func TestNormalizePartial(t *testing.T) { + c := Config{ + MaxSegmentSize: 500, // keep + // others zero + AckDelay: 5 * time.Millisecond, // keep + } + c.Normalize() + + if c.MaxSegmentSize != 500 { + g(t, "MaxSegmentSize", c.MaxSegmentSize, 500) + } + if c.AckDelay != 5*time.Millisecond { + g(t, "AckDelay", c.AckDelay, 5*time.Millisecond) + } + + if c.MaxWindowPackets != DefaultWndPkts { + f(t, "MaxWindowPackets", c.MaxWindowPackets, DefaultWndPkts) + } + if c.RetransmissionInterval != DefaultRTO { + f(t, "RetransmissionInterval", c.RetransmissionInterval, DefaultRTO) + } + if c.MaxRetransmissionInterval != DefaultRTOMax { + f(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, DefaultRTOMax) + } + if c.RetransmissionLimit != DefaultRetries { + f(t, "RetransmissionLimit", c.RetransmissionLimit, DefaultRetries) + } + if c.RecvBufBytes != DefaultRecvBufBytes { + f(t, "RecvBufBytes", c.RecvBufBytes, DefaultRecvBufBytes) + } + if c.SendBufBytes != DefaultSendBufBytes { + f(t, "SendBufBytes", c.SendBufBytes, DefaultSendBufBytes) + } +} + +// TestNormalizeIdempotent ensures calling Normalize twice doesn't change values after first call. +func TestNormalizeIdempotent(t *testing.T) { + var c Config + c.Normalize() + first := c + c.Normalize() + if c != first { + g(t, "ConfigAfterSecondNormalize", c, first) + } +} + +// TestNormalizeNegativeValues ensures negative values are preserved (no implicit clamping yet). +func TestNormalizeNegativeValues(t *testing.T) { + c := Config{ + MaxSegmentSize: -1, + MaxWindowPackets: -3, + RetransmissionLimit: -4, + RecvBufBytes: -5, + SendBufBytes: -6, + } + // Durations negative as well + c.RetransmissionInterval = -10 * time.Millisecond + c.MaxRetransmissionInterval = -20 * time.Millisecond + c.AckDelay = -30 * time.Millisecond + + c.Normalize() + + if c.MaxSegmentSize != -1 { + g(t, "MaxSegmentSize", c.MaxSegmentSize, -1) + } + + if c.MaxWindowPackets != -3 { + g(t, "MaxWindowPackets", c.MaxWindowPackets, -3) + } + if c.RetransmissionLimit != -4 { + g(t, "RetransmissionLimit", c.RetransmissionLimit, -4) + } + if c.RecvBufBytes != -5 { + g(t, "RecvBufBytes", c.RecvBufBytes, -5) + } + if c.SendBufBytes != -6 { + g(t, "SendBufBytes", c.SendBufBytes, -6) + } + if c.RetransmissionInterval != -10*time.Millisecond { + g(t, "RetransmissionInterval", c.RetransmissionInterval, -10*time.Millisecond) + } + if c.MaxRetransmissionInterval != -20*time.Millisecond { + g(t, "MaxRetransmissionInterval", c.MaxRetransmissionInterval, -20*time.Millisecond) + } + if c.AckDelay != -30*time.Millisecond { + g(t, "AckDelay", c.AckDelay, -30*time.Millisecond) + } +} + +// Helper failure formatters for brevity. +func f[T comparable](t *testing.T, field string, got, want T) { + if got != want { + // Using t.Fatalf to stop early in default-path tests where cascading errors add little value. + // Use %v for generic print. + //nolint:forbidigo // simple test diagnostic + t.Fatalf("%s mismatch: got=%v want=%v", field, got, want) + } +} + +func g[T comparable](t *testing.T, field string, got, want T) { + if got != want { + //nolint:forbidigo + t.Errorf("%s mismatch: got=%v want=%v", field, got, want) + } +} diff --git a/mod/udp/rudp/conn.go b/mod/udp/rudp/conn.go new file mode 100644 index 00000000..d2918656 --- /dev/null +++ b/mod/udp/rudp/conn.go @@ -0,0 +1,396 @@ +// conn.go +package rudp + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" + "github.com/smallnest/ringbuffer" +) + +// Conn represents a reliable UDP connection. +// Implements reliability, flow control, retransmissions, and error notification. +// Key mechanisms: +// - Reliable delivery using retransmission timer and fast retransmit +// - Flow control using packet window (MaxWindowPackets) +// - Concurrency safety via sendMu and sendCond +// - Error notification via ErrChan (application can monitor for connection-level errors) +// - Centralized resource cleanup via Close() +// - PoC limitations: no congestion control, no SACK, no adaptive pacing +type Conn struct { + // UDP socket and addressing + udpConn *net.UDPConn // Underlying UDP socket + localEndpoint *udp.Endpoint + remoteEndpoint *udp.Endpoint + outbound bool // true if we initiated the connection + onEstablishedCb func(*Conn) + onClosedCb func(*Conn) + + // Configuration (reliability, flow control, etc.) + cfg Config // All protocol parameters + + // Connection state (atomic for lock-free reads) + state uint32 // Current connection state (stores ConnState) + inCh chan *Packet // Incoming packet channel + closedFlag uint32 // 0=open, 1=closed (atomic) + + // Sequence numbers and send state + initialSeqNumLocal uint32 + initialSeqNumRemote uint32 + nextSeqNum uint32 + connID uint32 // Connection ID + ackedSeqNum uint32 // Highest cumulative ACK observed for our sent data (also piggybacked when sending) + expected uint32 // Next expected sequence number (receive side) + // Send buffer and reliability + sendRB *ringbuffer.RingBuffer // Persistent send ring buffer (FIFO; bytes consumed at packetization) + frag *BasicFragmenter // Fragmenter for packetization + unacked map[uint32]*Unacked // Map of unacked packets (seq -> Unacked); stores full packet copies for retransmission + + // Concurrency and coordination + sendMu sync.Mutex // Protects all shared state + sendCond *sync.Cond // Condition variable for sender coordination + + // Retransmission timer + rtxTimer *time.Timer // Fixed retransmission timer (PoC only) + + // Error notification + ErrChan chan error // Channel for connection-level errors (e.g., retransmission failure) + + // Inbound buffering & ACK state + recvRB *ringbuffer.RingBuffer + recvMu sync.Mutex + recvCond *sync.Cond + ackTimer *time.Timer + ackPending bool + // Out-of-order buffer (keyed by sequence number of first byte) + recvOO map[uint32]*Packet // stored packets with Seq > expected awaiting in-order delivery + +} + +// unified handshake channel capacity (applies to inbound & outbound) +const handshakeQueueCap = 64 + +// NewConn constructs a connection around a UDP socket. +// Parameters: +// +// outbound - true for client initiated connection (owns socket & recv loop) +// firstPacket / ctx - required only for inbound (outbound=false) to drive server handshake +func NewConn(cn *net.UDPConn, l, r *udp.Endpoint, cfg Config, outbound bool, firstPacket *Packet, ctx *astral.Context) (*Conn, error) { + cfg.Normalize() + if cfg.MaxSegmentSize <= 0 { + return nil, udp.ErrZeroMSS + } + + if !outbound { + if firstPacket == nil || ctx == nil { + return nil, errors.New("inbound connection requires firstPacket and ctx") + } + } + + // sendRBSize in BYTES: MaxWindowPackets * MaxSegmentSize * 2 for retransmit slack + sendRBSize := cfg.MaxWindowPackets * cfg.MaxSegmentSize * 2 + rb := ringbuffer.New(sendRBSize) + frag := NewBasicFragmenter(cfg.MaxSegmentSize) + + rc := &Conn{ + udpConn: cn, + localEndpoint: l, + remoteEndpoint: r, + cfg: cfg, + sendRB: rb, + frag: frag, + unacked: make(map[uint32]*Unacked), + ErrChan: make(chan error, 1), // Buffered to avoid blocking + inCh: make(chan *Packet, handshakeQueueCap), // handshake delivery channel + outbound: outbound, + } + rc.sendCond = sync.NewCond(&rc.sendMu) + rc.recvRB = ringbuffer.New(cfg.RecvBufBytes) + rc.recvCond = sync.NewCond(&rc.recvMu) + rc.recvOO = make(map[uint32]*Packet) + + if outbound { + // Do not start recvLoop here; StartClientHandshake will start it after handshake completes + } else { + go rc.recvLoop() + } + + if !outbound { + // Start server handshake asynchronously for inbound connections + go func() { + if err := rc.StartServerHandshake(ctx, firstPacket); err != nil { + rc.Close() + } + }() + } + + return rc, nil +} + +// OnEstablished registers a callback invoked exactly once when the connection transitions to Established. +func (c *Conn) OnEstablished(cb func(*Conn)) { + c.sendMu.Lock() + c.onEstablishedCb = cb + c.sendMu.Unlock() + // Fast path: if already established, invoke asynchronously + if c.inState(StateEstablished) && cb != nil { + go cb(c) + } +} + +// OnClosed registers a callback invoked exactly once after Close() releases resources. +// If the connection is already closed when registering, the callback is invoked asynchronously. +func (c *Conn) OnClosed(cb func(*Conn)) { + c.sendMu.Lock() + c.onClosedCb = cb + closed := c.isClosed() + c.sendMu.Unlock() + if closed && cb != nil { + go cb(c) + } +} + +func (c *Conn) setState(state ConnState) { + atomic.StoreUint32(&c.state, uint32(state)) +} + +func (c *Conn) inState(state ConnState) bool { + return atomic.LoadUint32(&c.state) == uint32(state) +} + +func (c *Conn) isClosed() bool { return atomic.LoadUint32(&c.closedFlag) != 0 } + +func (c *Conn) Read(p []byte) (n int, err error) { + if !c.inState(StateEstablished) && !c.isClosed() { + return 0, udp.ErrConnectionNotEstablished + } + c.recvMu.Lock() + for c.recvRB != nil && c.recvRB.Length() == 0 && !c.isClosed() { + c.recvCond.Wait() + } + if c.recvRB == nil || (c.recvRB.Length() == 0 && c.isClosed()) { + c.recvMu.Unlock() + return 0, io.EOF + } + want := len(p) + if rl := int(c.recvRB.Length()); want > rl { + want = rl + } + if want == 0 { + c.recvMu.Unlock() + return 0, nil + } + // Read directly into caller's buffer (no temp allocation) + m, _ := c.recvRB.Read(p[:want]) + c.recvMu.Unlock() + return m, nil +} + +// Write enqueues data into the send ring buffer. Implementation in send.go. +func (c *Conn) Write(p []byte) (n int, err error) { + if c.isClosed() { + return 0, udp.ErrConnClosed + } + + if !c.inState(StateEstablished) { + return 0, udp.ErrConnectionNotEstablished + } + + return c.writeSend(p) +} + +func (c *Conn) Close() error { + c.sendMu.Lock() + if c.isClosed() { // already closed + c.sendMu.Unlock() + return nil + } + pendingData := c.sendRB != nil && c.sendRB.Length() > 0 + pendingUnacked := len(c.unacked) > 0 + atomic.StoreUint32(&c.closedFlag, 1) + if c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + c.sendCond.Broadcast() + ch := c.inCh + c.inCh = nil + closedCb := c.onClosedCb + c.sendMu.Unlock() + + if pendingData || pendingUnacked { + select { + case c.ErrChan <- udp.ErrDataLost: + default: + } + } + if ch != nil { + close(ch) + } + if c.outbound { // only needed to unblock recvLoop for outbound + _ = c.udpConn.SetReadDeadline(time.Now()) + } + c.recvMu.Lock() + if c.ackTimer != nil { + c.ackTimer.Stop() + c.ackTimer = nil + } + if c.recvCond != nil { + c.recvCond.Broadcast() + } + c.recvMu.Unlock() + var err error + if c.outbound { + err = c.udpConn.Close() + } + // Invoke close callback after resources released + if closedCb != nil { + closedCb(c) + } + + return err +} + +func (c *Conn) onEstablished() { + // Idempotent: only transition once + if c.inState(StateEstablished) || c.isClosed() { + return + } + // Initialize receive-side expected sequence to remote initial seq + 1 (account for SYN consuming one seq) + if c.initialSeqNumRemote != 0 && c.expected == 0 { + c.expected = c.initialSeqNumRemote + 1 + } + c.setState(StateEstablished) + // Future established-only initializations (keepalives, metrics, etc.) go here. + go c.senderLoop() + // Invoke callback (outside locks) if set + cb := func() func(*Conn) { c.sendMu.Lock(); defer c.sendMu.Unlock(); return c.onEstablishedCb }() + if cb != nil { + cb(c) + } +} + +// HandleAckPacket processes ACK packets +func (c *Conn) HandleAckPacket(packet *Packet) { + ack := packet.Ack + c.sendMu.Lock() + defer c.sendMu.Unlock() + if ack > c.ackedSeqNum { + c.ackedSeqNum = ack + } + // cumulative ACK floor implicitly reflected by ackedSeqNum + // Remove fully acked packets (keyed by seq) + for s, u := range c.unacked { + if u.isHandshake { + if ack > s { // expected ack == s+1 + delete(c.unacked, s) + } + continue + } + if s+uint32(u.length) <= ack { + delete(c.unacked, s) + } + } + if len(c.unacked) == 0 && c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + c.sendCond.Broadcast() +} + +// HandleControlPacket processes SYN, FIN, and other control packets +func (c *Conn) HandleControlPacket(packet *Packet) { + // Example: handle SYN, FIN, or other control logic + if packet.Flags&FlagSYN != 0 { + // TODO: ...handle SYN logic... + } + if packet.Flags&FlagFIN != 0 { + // TODO: ...handle FIN logic... + } + // ...handle other control flags as needed... +} + +// Outbound interface compliance for exonet.Conn +func (c *Conn) Outbound() bool { return c.outbound } + +// LocalEndpoint returns the local UDP endpoint of the connection. +func (c *Conn) LocalEndpoint() exonet.Endpoint { + if c == nil { + return nil + } + return c.localEndpoint +} + +// RemoteEndpoint returns the remote UDP endpoint of the connection. +func (c *Conn) RemoteEndpoint() exonet.Endpoint { + if c == nil { + return nil + } + return c.remoteEndpoint +} + +// ProcessPacket feeds a received packet into the connection (server-side demux path). +// If the connection is not yet established, the packet is queued for handshake processing. +func (c *Conn) ProcessPacket(pkt *Packet) { + if !c.inState(StateEstablished) { + c.sendMu.Lock() + ch := c.inCh + c.sendMu.Unlock() + if ch != nil { + select { + case ch <- pkt: + default: + } + } + return + } + + // Data packet (may include piggyback ACK) + if pkt.Len > 0 { + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + } + + c.handleDataPacket(pkt) + return + } + + // Pure ACK + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + return + } + // Control (SYN/FIN/etc.) + if pkt.Flags&(FlagSYN|FlagFIN) != 0 { + c.HandleControlPacket(pkt) + return + } + + // Ignore anything else +} + +// sendDatagram sends a raw packet buffer choosing the correct syscall based on +// connection role (outbound connections use Write on a connected socket; +// inbound connections use WriteToUDP specifying the remote endpoint). +func (c *Conn) sendDatagram(b []byte) (n int, err error) { + if c.udpConn == nil { + return n, udp.ErrConnClosed + } + if c.outbound { + _, err := c.udpConn.Write(b) + return n, err + } + if c.remoteEndpoint == nil { + return n, fmt.Errorf("remote endpoint nil") + } + written, err := c.udpConn.WriteToUDP(b, c.remoteEndpoint.UDPAddr()) + return written, err +} diff --git a/mod/udp/rudp/conn_handshake.go b/mod/udp/rudp/conn_handshake.go new file mode 100644 index 00000000..def2179b --- /dev/null +++ b/mod/udp/rudp/conn_handshake.go @@ -0,0 +1,209 @@ +package rudp + +import ( + "context" + "crypto/rand" + "fmt" + "net" + "time" + + "github.com/cryptopunkscc/astrald/mod/udp" +) + +type ConnState int + +const ( + StateClosed ConnState = iota // no connection / after Close() + StateListen // (server only) waiting for SYN + StateSynSent // client sent SYN, waiting for SYN|ACK + StateSynReceived // server got SYN, sent SYN|ACK, waiting for final ACK + StateEstablished // handshake complete, normal data flow + StateFinSent // Close() called, FIN sent, waiting for ACK + StateFinReceived // FIN received, waiting for local Close() + StateTimeWait // (optional) short wait after FIN to absorb retransmits +) + +func (c *Conn) StartClientHandshake(ctx context.Context) error { + return c.startClientHandshakeDirect(ctx) +} + +// startClientHandshakeDirect performs the 3-way handshake using direct socket reads. +func (c *Conn) startClientHandshakeDirect(ctx context.Context) error { + seq, err := randUint32NZ() + if err != nil { + return fmt.Errorf("failed to generate initial sequence number: %w", err) + } + c.initialSeqNumLocal = seq + c.connID = seq + c.setState(StateSynSent) + + if err := c.sendHandshakeControl(FlagSYN, seq, 0); err != nil { + return err + } + + buf := make([]byte, 1500) + deadlineInterval := 300 * time.Millisecond + + for { + if ctx.Err() != nil { + return udp.ErrHandshakeTimeout + } + _ = c.udpConn.SetReadDeadline(time.Now().Add(deadlineInterval)) + n, addr, err := c.udpConn.ReadFromUDP(buf) + + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + continue + } + continue + } + // Filter by remote IP (ignore port mismatch in case of NAT rebinding; only IP match) + if c.remoteEndpoint != nil && !addr.IP.Equal(net.IP(c.remoteEndpoint.IP)) { + continue + } + if n < 13 { + continue + } + pkt := &Packet{} + if err := pkt.Unmarshal(buf[:n]); err != nil { + continue + } + + // Expect SYN|ACK + if pkt.Flags&(FlagSYN|FlagACK) == (FlagSYN|FlagACK) && pkt.Ack == seq+1 && pkt.Seq != 0 { + c.initialSeqNumRemote = pkt.Seq + // finalize local cumulative ack/next sequence + c.sendMu.Lock() + c.ackedSeqNum = seq + 1 + c.nextSeqNum = seq + 1 + c.sendMu.Unlock() + // Send final ACK + if err := c.SendControlPacket(FlagACK, seq+1, c.initialSeqNumRemote+1); err != nil { + return err + } + c.onEstablished() + // Start recvLoop after establishment for outbound + go c.recvLoop() + return nil + } + // ignore other control/data until handshake completes + } +} + +func (c *Conn) StartServerHandshake(ctx context.Context, synPkt *Packet) error { + c.initialSeqNumRemote = synPkt.Seq + c.connID = synPkt.Seq + seq, err := randUint32NZ() + if err != nil { + return fmt.Errorf("failed to generate initial sequence number: %w", err) + } + c.initialSeqNumLocal = seq + c.setState(StateSynReceived) + + // send SYN|ACK and register for retransmission + err = c.sendHandshakeControl(FlagSYN|FlagACK, c.initialSeqNumLocal, c.initialSeqNumRemote+1) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return udp.ErrHandshakeTimeout + case pkt := <-c.inCh: + if pkt.Flags&FlagACK != 0 && pkt.Ack == c.initialSeqNumLocal+1 { + // final ACK received + c.sendMu.Lock() + delete(c.unacked, c.initialSeqNumLocal) + if len(c.unacked) == 0 && c.rtxTimer != nil { + c.rtxTimer.Stop() + c.rtxTimer = nil + } + c.ackedSeqNum = c.initialSeqNumLocal + 1 + c.nextSeqNum = c.initialSeqNumLocal + 1 + c.sendMu.Unlock() + c.onEstablished() + // fused receive loop will now dispatch directly + return nil + } + } + } +} + +// sendHandshakeControl builds, sends and registers a handshake control packet for unified retransmissions +func (c *Conn) sendHandshakeControl(flags uint8, seq, ack uint32) error { + pkt := &Packet{Seq: seq, Ack: ack, Flags: flags, Len: 0} + b, err := pkt.Marshal() + if err != nil { + return fmt.Errorf("marshal handshake pkt: %w", err) + } + if c.udpConn == nil { + return udp.ErrConnClosed + } + if c.remoteEndpoint == nil { + return fmt.Errorf("remote endpoint nil") + } + + _, err = c.sendDatagram(b) + if err != nil { + return err + } + + // Register packet and decide if timer start is needed - do this OUTSIDE the lock + needTimer := false + c.sendMu.Lock() + if _, exists := c.unacked[seq]; !exists { // only register first time + c.unacked[seq] = &Unacked{ + pkt: pkt, + sentTime: time.Now(), + rtxCount: 0, + length: 0, + isHandshake: true, + } + if c.rtxTimer == nil { + needTimer = true + } + } + c.sendMu.Unlock() + + // Start timer AFTER releasing lock to avoid deadlock + if needTimer { + c.startRtxTimer() + } + + return nil +} + +// SendControlPacket retained for non-handshake control (e.g., FIN, pure ACK), not tracked +func (c *Conn) SendControlPacket(flags uint8, seq, ack uint32) error { + pkt := &Packet{Seq: seq, Ack: ack, Flags: flags, Len: 0} + data, err := pkt.Marshal() + if err != nil { + return fmt.Errorf(`SendControlPacket failed to marshal control packet: %w`, err) + } + if c.udpConn == nil { + return udp.ErrConnClosed + } + if c.remoteEndpoint == nil { + return fmt.Errorf("remote endpoint nil") + } + _, err = c.sendDatagram(data) + if err != nil { + return fmt.Errorf(`SendControlPacket failed to send control packet: %w`, err) + } + return nil +} + +func randUint32NZ() (uint32, error) { + var b [4]byte + for { + _, err := rand.Read(b[:]) + if err != nil { + return 0, fmt.Errorf("failed to generate random uint32: %w", err) + } + v := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) + if v != 0 { + return v, nil + } + } +} diff --git a/mod/udp/rudp/fragmenter.go b/mod/udp/rudp/fragmenter.go new file mode 100644 index 00000000..2b822e36 --- /dev/null +++ b/mod/udp/rudp/fragmenter.go @@ -0,0 +1,55 @@ +package rudp + +import "bytes" + +// Fragmenter turns buffered bytes into wire packets and reproduces the exact +// same boundaries for retransmission. +type Fragmenter interface { + // MakeNew decides payload size and builds a new Packet at nextSeq. + // 'allowed' is the sender's remaining window in bytes. + // Returns (packet, payloadLen, ok). ok=false if it chooses not to send (e.g., Nagle). + MakeNew(nextSeq uint32, allowed int, buf *bytes.Buffer) (*Packet, int, bool) +} + +// BasicFragmenter is a simple implementation of Fragmenter that splits a +// buffer into Packets of at most MSS size. +type BasicFragmenter struct { + MSS int +} + +// NewBasicFragmenter creates a new BasicFragmenter with the given maximum +// segment size (MSS). +func NewBasicFragmenter(mss int) *BasicFragmenter { + return &BasicFragmenter{MSS: mss} +} + +// MakeNew implements the Fragmenter interface for BasicFragmenter. +func (f *BasicFragmenter) MakeNew(nextSeq uint32, allowed int, + buf *bytes.Buffer) (*Packet, int, bool) { + if f.MSS <= 0 { + return nil, 0, false + } + if buf.Len() == 0 { + return nil, 0, false + } + if allowed <= 0 { + return nil, 0, false + } + + maxLen := f.MSS + if allowed < maxLen { + maxLen = allowed + } + if buf.Len() < maxLen { + maxLen = buf.Len() + } + + payload := buf.Bytes()[:maxLen] + packet := &Packet{ + Seq: nextSeq, + Len: uint16(len(payload)), + Payload: payload, + Flags: FlagACK, // Data packets should have ACK flag set + } + return packet, len(payload), true +} diff --git a/mod/udp/rudp/fragmenter_test.go b/mod/udp/rudp/fragmenter_test.go new file mode 100644 index 00000000..5ec3c55f --- /dev/null +++ b/mod/udp/rudp/fragmenter_test.go @@ -0,0 +1,201 @@ +package rudp + +import ( + "bytes" + "testing" +) + +func TestBasicFragmenter_SingleFragment(t *testing.T) { + mss := 100 + frag := NewBasicFragmenter(mss) + data := make([]byte, 80) + for i := range data { + data[i] = byte(i) + } + packet, packetLen, ok := frag.MakeNew(0, mss, bytes.NewBuffer(data)) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLen != 80 { + t.Errorf("expected packetLen=80, got %d", packetLen) + } + if packet == nil { + t.Fatalf("expected non-nil packet") + } + if packet.Seq != 0 { + t.Errorf("expected Seq=0, got %d", packet.Seq) + } + if packet.Len != 80 { + t.Errorf("expected Len=80, got %d", packet.Len) + } + for i := range packet.Payload { + if packet.Payload[i] != byte(i) { + t.Errorf("payload mismatch at %d: got %d, want %d", i, packet.Payload[i], byte(i)) + } + } +} + +func TestBasicFragmenter_MultipleFragments(t *testing.T) { + mss := 50 + frag := NewBasicFragmenter(mss) + data := make([]byte, 120) + for i := range data { + data[i] = byte(i) + } + buf := bytes.NewBuffer(data) + + nextSeq := uint32(0) + total := 0 + fragments := 0 + for buf.Len() > 0 { + packet, packetLength, ok := frag.MakeNew(nextSeq, mss, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packet == nil { + t.Fatalf("expected non-nil packet") + } + if int(packet.Len) > mss { + t.Errorf("fragment too large: got %d, want <= %d", packet.Len, mss) + } + for i := 0; i < int(packet.Len); i++ { + if packet.Payload[i] != byte(int(nextSeq)+i) { + t.Errorf("payload mismatch at %d: got %d, want %d", int(nextSeq)+i, packet.Payload[i], byte(int(nextSeq)+i)) + } + } + n := packetLength + if n > buf.Len() { + n = buf.Len() + } // clamp for safety + _ = buf.Next(n) + + nextSeq += uint32(packetLength) + total += packetLength + fragments++ + } + if total != 120 { + t.Errorf("expected total=120, got %d", total) + } + if fragments != 3 { + t.Errorf("expected fragments=3, got %d", fragments) + } +} + +func TestBasicFragmenter_ZeroLen(t *testing.T) { + mss := 50 + frag := NewBasicFragmenter(mss) + buf := bytes.NewBuffer(nil) + packet, packetLength, ok := frag.MakeNew(0, mss, buf) + if ok { + t.Errorf("expected ok=false for zero-len buffer") + } + if packet != nil { + t.Errorf("expected nil packet for zero-len buffer") + } + if packetLength != 0 { + t.Errorf("expected packetLength=0 for zero-len buffer, got %d", packetLength) + } +} + +func TestBasicFragmenter_AllowedLessThanMSS(t *testing.T) { + mss := 100 + allowed := 40 + frag := NewBasicFragmenter(mss) + data := make([]byte, 80) + for i := range data { + data[i] = byte(i) + } + buf := bytes.NewBuffer(data) + packet, packetLength, ok := frag.MakeNew(0, allowed, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLength != allowed { + t.Errorf("expected packetLength=%d, got %d", allowed, packetLength) + } + if packet.Len != uint16(allowed) { + t.Errorf("expected Len=%d, got %d", allowed, packet.Len) + } +} + +func TestBasicFragmenter_AllowedLessThanBuffer(t *testing.T) { + mss := 100 + allowed := 60 + frag := NewBasicFragmenter(mss) + data := make([]byte, 80) + for i := range data { + data[i] = byte(i) + } + buf := bytes.NewBuffer(data) + packet, packetLength, ok := frag.MakeNew(0, allowed, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLength != allowed { + t.Errorf("expected packetLength=%d, got %d", allowed, packetLength) + } + if packet.Len != uint16(allowed) { + t.Errorf("expected Len=%d, got %d", allowed, packet.Len) + } +} + +func TestBasicFragmenter_BufferSmallerThanAllowedAndMSS(t *testing.T) { + mss := 100 + allowed := 80 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + for i := range data { + data[i] = byte(i) + } + buf := bytes.NewBuffer(data) + packet, packetLength, ok := frag.MakeNew(0, allowed, buf) + if !ok { + t.Fatalf("expected ok=true, got false") + } + if packetLength != 50 { + t.Errorf("expected packetLength=50, got %d", packetLength) + } + if packet.Len != 50 { + t.Errorf("expected Len=50, got %d", packet.Len) + } +} + +func TestBasicFragmenter_NegativeOrZeroAllowed(t *testing.T) { + mss := 100 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + buf := bytes.NewBuffer(data) + packet, packetLength, ok := frag.MakeNew(0, 0, buf) + if ok || packet != nil || packetLength != 0 { + t.Errorf("expected no packet for allowed=0") + } + packet, packetLength, ok = frag.MakeNew(0, -10, buf) + if ok || packet != nil || packetLength != 0 { + t.Errorf("expected no packet for allowed<0") + } +} + +func TestBasicFragmenter_ZeroMSS(t *testing.T) { + mss := 0 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + buf := bytes.NewBuffer(data) + packet, packetLength, ok := frag.MakeNew(0, 100, buf) + if ok || packet != nil || packetLength != 0 { + t.Errorf("expected no packet for MSS=0") + } +} + +func TestBasicFragmenter_FlagsSet(t *testing.T) { + mss := 100 + frag := NewBasicFragmenter(mss) + data := make([]byte, 50) + buf := bytes.NewBuffer(data) + packet, _, ok := frag.MakeNew(0, 100, buf) + if !ok || packet == nil { + t.Fatalf("expected valid packet") + } + if packet.Flags&FlagACK == 0 { + t.Errorf("expected ACK flag set in data packet") + } +} diff --git a/mod/udp/rudp/integration_test.go b/mod/udp/rudp/integration_test.go new file mode 100644 index 00000000..094f0b69 --- /dev/null +++ b/mod/udp/rudp/integration_test.go @@ -0,0 +1,299 @@ +package rudp + +import ( + "context" + "net" + "sort" + "testing" + "time" + + "github.com/cryptopunkscc/astrald/astral" + udpmod "github.com/cryptopunkscc/astrald/mod/udp" +) + +// TestListenerDialHelloWorld exercises a minimal end-to-end handshake and +// one-shot data transfer ("Hello World") between an outbound client Conn +// and an inbound server Conn accepted through Listener.Accept(). +func TestListenerDialHelloWorld(t *testing.T) { + baseCtx := astral.NewContext(context.Background()) + + // Start listener on an IPv4 loopback ephemeral port (force IPv4 to avoid ::/127.0.0.1 mismatch) + l, err := Listen(baseCtx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, Config{}, 2*time.Second) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer l.Close() + + serverAddr := l.Addr().(*net.UDPAddr) + // Force IPv4 127.0.0.1 target (avoid ::1 / unspecified ambiguity) + ipv4Dest := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverAddr.Port} + + // Channel to receive accepted server connection + acceptedCh := make(chan *Conn, 1) + + // Accept in background + go func() { + acceptCtx, cancel := baseCtx.WithTimeout(3 * time.Second) + defer cancel() + c, err := l.Accept(acceptCtx) + if err != nil { + return + } + acceptedCh <- c + }() + + // Dial UDP (raw) for outbound side + udpConn, err := net.DialUDP("udp4", nil, ipv4Dest) + if err != nil { + // fallback try udp if udp4 failed + udpConn, err = net.DialUDP("udp", nil, ipv4Dest) + } + if err != nil { + l.Close() + t.Fatalf("DialUDP failed: %v", err) + } + defer udpConn.Close() + + // Build endpoints + localEP, _ := udpmod.ParseEndpoint(udpConn.LocalAddr().String()) + remoteEP, _ := udpmod.ParseEndpoint(udpConn.RemoteAddr().String()) + + // Create outbound reliable Conn + outConn, err := NewConn(udpConn, localEP, remoteEP, Config{}, true, nil, baseCtx) + if err != nil { + t.Fatalf("NewConn outbound failed: %v", err) + } + t.Logf("client local=%v remote=%v", udpConn.LocalAddr(), udpConn.RemoteAddr()) + + // Run client handshake + hCtx, hCancel := baseCtx.WithTimeout(2 * time.Second) + defer hCancel() + if err := outConn.StartClientHandshake(hCtx); err != nil { + outConn.Close() + l.Close() + t.Fatalf("client handshake failed: %v", err) + } + t.Logf("client handshake complete") + + // Wait for server side acceptance + var serverConn *Conn + select { + case serverConn = <-acceptedCh: + if serverConn != nil { + t.Logf("server accepted remote=%v", serverConn.RemoteEndpoint()) + } + case <-time.After(3 * time.Second): + // handshake should have completed well before this + outConn.Close() + l.Close() + t.Fatalf("timeout waiting for Accept()") + } + if serverConn == nil { + outConn.Close() + l.Close() + t.Fatalf("nil serverConn returned") + } + defer serverConn.Close() + + // Send payload client->server + msg := []byte("Hello World") + if _, err := outConn.Write(msg); err != nil { + // ensure cleanup before failing + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("client write failed: %v", err) + } + t.Logf("client wrote payload len=%d", len(msg)) + + // Read at server side (no direct read deadline API; use goroutine + timeout) + readCh := make(chan struct{}) + var got []byte + var readErr error + go func() { + b := make([]byte, 64) + if n, err := serverConn.Read(b); err != nil { + readErr = err + } else { + got = append(got, b[:n]...) + } + close(readCh) + }() + select { + case <-readCh: + case <-time.After(10 * time.Second): + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("timeout waiting for server read") + } + if readErr != nil { + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("server read error: %v", readErr) + } + if string(got) != string(msg) { + outConn.Close() + serverConn.Close() + l.Close() + t.Fatalf("unexpected server payload: got %q want %q", string(got), string(msg)) + } + t.Logf("server read payload: %s", string(got)) + + // (Optional) Echo back from server to client to validate reverse path + if _, err := serverConn.Write([]byte("ACK")); err == nil { + // attempt client read with timeout + clientReadCh := make(chan []byte, 1) + go func() { + b := make([]byte, 8) + if n, e := outConn.Read(b); e == nil { + clientReadCh <- b[:n] + } + close(clientReadCh) + }() + select { + case resp := <-clientReadCh: + if len(resp) > 0 && string(resp) != "ACK" { + // Non-fatal; log mismatch + // t.Logf("unexpected client echo: %q", string(resp)) + } + case <-time.After(2 * time.Second): + // ignore echo timeout (non-fatal for core test) + } + } +} + +// TestDiagFirst32Packets performs a focused diagnostic of the first ~32 data packets +// to inspect sequence alignment between sender and receiver. It writes exactly +// 32 * MSS bytes (fragmented into MSS-sized packets) and then logs the receiver's +// expected sequence, number of in-order bytes buffered, out-of-order queue size, +// and sample sequence gaps. This is a non-fatal diagnostic (will t.Skip on environments +// where it cannot bind or handshake cleanly). +func TestDiagFirst32Packets(t *testing.T) { + // Keep this test quick. + baseCtx := astral.NewContext(context.Background()) + + // Custom config to keep things deterministic and fast. + cfg := Config{ + MaxSegmentSize: DefaultMSS, // 1187 + MaxWindowPackets: 128, // allow >32 easily + AckDelay: time.Microsecond, // effectively immediate + RecvBufBytes: 1 << 20, + SendBufBytes: 1 << 20, + } + cfg.Normalize() + + l, err := Listen(baseCtx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, cfg, 2*time.Second) + if err != nil { + t.Skipf("listener setup failed (skip diag): %v", err) + } + defer l.Close() + + serverAddr := l.Addr().(*net.UDPAddr) + ipv4Dest := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverAddr.Port} + + acceptedCh := make(chan *Conn, 1) + go func() { + acceptCtx, cancel := baseCtx.WithTimeout(2 * time.Second) + defer cancel() + c, aerr := l.Accept(acceptCtx) + if aerr == nil { + acceptedCh <- c + } + }() + + udpConn, err := net.DialUDP("udp4", nil, ipv4Dest) + if err != nil { + udpConn, err = net.DialUDP("udp", nil, ipv4Dest) + } + if err != nil { + l.Close() + t.Skipf("dial failed (skip diag): %v", err) + } + defer udpConn.Close() + + localEP, _ := udpmod.ParseEndpoint(udpConn.LocalAddr().String()) + remoteEP, _ := udpmod.ParseEndpoint(udpConn.RemoteAddr().String()) + + outConn, err := NewConn(udpConn, localEP, remoteEP, cfg, true, nil, baseCtx) + if err != nil { + t.Skipf("NewConn outbound failed (skip diag): %v", err) + } + defer outConn.Close() + + hCtx, hCancel := baseCtx.WithTimeout(2 * time.Second) + if err := outConn.StartClientHandshake(hCtx); err != nil { + hCancel() + t.Skipf("handshake failed (skip diag): %v", err) + } + hCancel() + + var serverConn *Conn + select { + case serverConn = <-acceptedCh: + case <-time.After(2 * time.Second): + t.Skip("server accept timeout (skip diag)") + } + if serverConn == nil { + t.Skip("nil serverConn (skip diag)") + } + defer serverConn.Close() + + // Prepare exactly 32 * MSS bytes. + packets := 32 + bytesToSend := packets * cfg.MaxSegmentSize + payload := make([]byte, bytesToSend) + for i := range payload { + payload[i] = byte(i) + } + + written := 0 + start := time.Now() + for written < bytesToSend { + n, werr := outConn.Write(payload[written:]) + if werr != nil { + t.Fatalf("write error after %d bytes: %v", written, werr) + } + written += n + } + elapsedWrite := time.Since(start) + // Allow receiver a short window to process. + time.Sleep(50 * time.Millisecond) + + // Snapshot receiver internal state. + serverConn.recvMu.Lock() + expected := serverConn.expected + recvLen := serverConn.recvRB.Length() + oosz := len(serverConn.recvOO) + // Collect first few out-of-order keys + keys := make([]uint32, 0, oosz) + for k := range serverConn.recvOO { + keys = append(keys, k) + } + serverConn.recvMu.Unlock() + if len(keys) > 0 { + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + } + if len(keys) > 16 { + keys = keys[:16] + } + + // Log sender side nextSeq / acked + serverConn.sendMu.Lock() + acked := serverConn.ackedSeqNum + next := serverConn.nextSeqNum + serverConn.sendMu.Unlock() + + clientNext := outConn.nextSeqNum + clientAcked := outConn.ackedSeqNum + + t.Logf("diag: wrote=%d bytes (32*MSS=%d) writeElapsed=%v", written, bytesToSend, elapsedWrite) + t.Logf("diag: server expected=%d recvLen=%d recvOO=%d (firstOO=%v)", expected, recvLen, oosz, keys) + t.Logf("diag: server acked=%d next=%d | client acked=%d next=%d", acked, next, clientAcked, clientNext) + + // Basic assertion: at least one in-order advancement OR explicit log for gap. + if recvLen == 0 { + t.Logf("diag: WARNING no in-order data buffered; likely initial gap (expected not reached)") + } +} diff --git a/mod/udp/rudp/listener.go b/mod/udp/rudp/listener.go new file mode 100644 index 00000000..8623e659 --- /dev/null +++ b/mod/udp/rudp/listener.go @@ -0,0 +1,153 @@ +package rudp + +import ( + "net" + "sync" + "sync/atomic" + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +const defaultAcceptBacklog = 32 +const defaultHandshakeTimeout = 5 * time.Second + +// Listener accepts inbound RUDP connections (server side) and returns only +// fully established connections via Accept(). +type Listener struct { + udpConn *net.UDPConn + cfg Config + baseCtx *astral.Context + mu sync.Mutex + conns map[string]*Conn // remoteKey -> Conn (includes handshaking ones) + acceptCh chan *Conn + closed atomic.Bool +} + +// Listen creates a new Listener bound to addr. handshakeTimeout==0 falls back to a small default. +func Listen(ctx *astral.Context, addr *net.UDPAddr, cfg Config, handshakeTimeout time.Duration) (*Listener, error) { + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + cfg.Normalize() + if handshakeTimeout <= 0 { + handshakeTimeout = defaultHandshakeTimeout + } + l := &Listener{ + udpConn: conn, + cfg: cfg, + baseCtx: ctx, + conns: make(map[string]*Conn), + acceptCh: make(chan *Conn, defaultAcceptBacklog), + } + go l.readLoop(handshakeTimeout) + return l, nil +} + +// Addr returns the underlying listening address. +func (l *Listener) Addr() net.Addr { return l.udpConn.LocalAddr() } + +// Accept blocks until an established connection is available, the context is canceled, or the listener is closed. +func (l *Listener) Accept(ctx *astral.Context) (*Conn, error) { + for { + if l.closed.Load() { + return nil, udp.ErrListenerClosed + } + select { + case c, ok := <-l.acceptCh: + if !ok { + return nil, udp.ErrListenerClosed + } + return c, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// Close shuts down the listener and all active connections. +func (l *Listener) Close() error { + if !l.closed.CompareAndSwap(false, true) { + return udp.ErrListenerClosed + } + // Closing the UDP socket unblocks readLoop. + _ = l.udpConn.Close() + l.mu.Lock() + for _, c := range l.conns { + c.Close() + } + l.conns = nil + l.mu.Unlock() + close(l.acceptCh) + return nil +} + +// readLoop performs demultiplexing and inbound connection setup. +func (l *Listener) readLoop(handshakeTimeout time.Duration) { + buf := make([]byte, 64*1024) + for { + if l.closed.Load() { + return + } + n, addr, err := l.udpConn.ReadFromUDP(buf) + if err != nil { + if l.closed.Load() { + return + } + continue + } + if n < 13 { // minimal header length + continue + } + pkt := &Packet{} + if err := pkt.Unmarshal(buf[:n]); err != nil { + continue + } + + remoteKey := addr.String() + l.mu.Lock() + conn := l.conns[remoteKey] + if conn == nil && pkt.Flags&FlagSYN != 0 { // new inbound attempt + remoteEP, perr := udp.ParseEndpoint(addr.String()) + if perr != nil { + l.mu.Unlock() + continue + } + localEP, _ := udp.ParseEndpoint(l.udpConn.LocalAddr().String()) + // Per-handshake timeout context derived from baseCtx + hCtx, cancel := l.baseCtx.WithTimeout(handshakeTimeout) + c, cerr := NewConn(l.udpConn, localEP, remoteEP, l.cfg, false, pkt, hCtx) + if cerr != nil { // immediate constructor error; cancel context + cancel() + l.mu.Unlock() + continue + } + c.OnEstablished(func(ec *Conn) { + if l.closed.Load() { + cancel() + return + } + select { + case l.acceptCh <- ec: + default: + } + cancel() + }) + c.OnClosed(func(ec *Conn) { + l.mu.Lock() + delete(l.conns, remoteKey) + l.mu.Unlock() + cancel() + }) + l.conns[remoteKey] = c + conn = c + } + l.mu.Unlock() + + if conn != nil { + conn.ProcessPacket(pkt) + } + } +} diff --git a/mod/udp/rudp/packet.go b/mod/udp/rudp/packet.go new file mode 100644 index 00000000..9d0de18b --- /dev/null +++ b/mod/udp/rudp/packet.go @@ -0,0 +1,119 @@ +package rudp + +import ( + "bytes" + "encoding/binary" + "time" + + "github.com/cryptopunkscc/astrald/mod/udp" +) + +const ( + FlagSYN = 1 << 0 + FlagACK = 1 << 1 + FlagFIN = 1 << 2 +) + +// Packet represents a src UDP packet with TCP-like header +// Handshake usage: +// +// ConnID: carried in Seq (ISN) +// Seq: initial sequence number (ISN) +// Ack: cumulative acknowledgment of peer's ISN+1 +// Flags: SYN, ACK, FIN (DATA if Len > 0) +// Len: 0 for control packets (SYN, SYN|ACK, ACK, FIN) +type Packet struct { + Seq uint32 // Sequence number (first byte seq of this segment) + Ack uint32 // Acknowledgment number (cumulative ack: all bytes < Ack received) + Flags uint8 // Bit flags: SYN, ACK, FIN, etc. + Win uint16 // Advertised receive window in BYTES + Len uint16 // Payload length + Payload []byte // Actual data +} + +// Marshal serializes the Packet into bytes for transmission +// Format: [Seq:4][Ack:4][Flags:1][Win:2][Len:2][Payload:N] +func (p *Packet) Marshal() ([]byte, error) { + buf := new(bytes.Buffer) + if err := binary.Write(buf, binary.BigEndian, p.Seq); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, p.Ack); err != nil { + return nil, err + } + if err := buf.WriteByte(p.Flags); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, p.Win); err != nil { + return nil, err + } + if err := binary.Write(buf, binary.BigEndian, p.Len); err != nil { + return nil, err + } + if _, err := buf.Write(p.Payload); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// Unmarshal parses bytes into the Packet struct fields +func (p *Packet) Unmarshal(data []byte) error { + if len(data) < 13 { // 4+4+1+2+2 = 13 bytes header + return udp.ErrPacketTooShort + } + + p.Seq = binary.BigEndian.Uint32(data[0:4]) + p.Ack = binary.BigEndian.Uint32(data[4:8]) + p.Flags = data[8] + p.Win = binary.BigEndian.Uint16(data[9:11]) + p.Len = binary.BigEndian.Uint16(data[11:13]) + + // Validate payload length + if int(p.Len) != len(data)-13 { + return udp.ErrInvalidPayloadLength + } + + p.Payload = data[13:] + return nil +} + +// UnmarshalPacket parses bytes into a Packet instance. +func UnmarshalPacket(data []byte) (*Packet, error) { + if len(data) < 13 { // Minimum header size: 4+4+1+2+2 + return nil, udp.ErrMalformedPacket + } + + pkt := &Packet{} + r := bytes.NewReader(data) + if err := binary.Read(r, binary.BigEndian, &pkt.Seq); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &pkt.Ack); err != nil { + return nil, err + } + flags, err := r.ReadByte() + if err != nil { + return nil, err + } + pkt.Flags = flags + if err := binary.Read(r, binary.BigEndian, &pkt.Win); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &pkt.Len); err != nil { + return nil, err + } + if int(pkt.Len) > len(data)-13 { + return nil, udp.ErrMalformedPacket + } + pkt.Payload = make([]byte, pkt.Len) + if _, err := r.Read(pkt.Payload); err != nil { + return nil, err + } + + return pkt, nil +} + +type SentPacket struct { + pkt *Packet + sentTime time.Time +} diff --git a/mod/udp/rudp/packet_test.go b/mod/udp/rudp/packet_test.go new file mode 100644 index 00000000..f4e2760e --- /dev/null +++ b/mod/udp/rudp/packet_test.go @@ -0,0 +1,74 @@ +package rudp + +import ( + "bytes" + "testing" +) + +func TestPacketMarshalUnmarshal(t *testing.T) { + tests := []struct { + name string + packet Packet + expectErr bool + }{ + { + name: "Valid Packet", + packet: Packet{ + Seq: 1, + Ack: 2, + Flags: 0x01, + Win: 1024, + Len: 5, + Payload: []byte("hello"), + }, + expectErr: false, + }, + { + name: "Empty Payload", + packet: Packet{ + Seq: 10, + Ack: 20, + Flags: 0x02, + Win: 2048, + Len: 0, + Payload: []byte{}, + }, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test Marshal + data, err := tt.packet.Marshal() + if (err != nil) != tt.expectErr { + t.Fatalf("Marshal error = %v, expectErr = %v", err, tt.expectErr) + } + + // Test Unmarshal + var unmarshaled Packet + err = unmarshaled.Unmarshal(data) + if (err != nil) != tt.expectErr { + t.Fatalf("Unmarshal error = %v, expectErr = %v", err, tt.expectErr) + } + + // Verify the unmarshaled packet matches the original + if !bytes.Equal(unmarshaled.Payload, tt.packet.Payload) || + unmarshaled.Seq != tt.packet.Seq || + unmarshaled.Ack != tt.packet.Ack || + unmarshaled.Flags != tt.packet.Flags || + unmarshaled.Win != tt.packet.Win || + unmarshaled.Len != tt.packet.Len { + t.Errorf("Unmarshaled packet does not match original. Got %+v, want %+v", unmarshaled, tt.packet) + } + }) + } +} + +func TestUnmarshalPacket(t *testing.T) { + data := []byte("testdata") + _, err := UnmarshalPacket(data) + if err == nil { + t.Error("Expected error for invalid packet data, got nil") + } +} diff --git a/mod/udp/rudp/receive.go b/mod/udp/rudp/receive.go new file mode 100644 index 00000000..0663914c --- /dev/null +++ b/mod/udp/rudp/receive.go @@ -0,0 +1,162 @@ +package rudp + +import ( + "net" + "time" +) + +// recvLoop fuses raw UDP receive and packet dispatch. During handshake (state < Established) +// packets are delivered into inCh for the blocking handshake loops. After establishment, +// packets are dispatched directly to the appropriate handlers without using inCh. +func (c *Conn) recvLoop() { + const maxPayloadSize = 64 * 1024 + buf := make([]byte, maxPayloadSize) + for { + if c.isClosed() { + return + } + if err := c.udpConn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + continue + } + n, addr, err := c.udpConn.ReadFromUDP(buf) + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + continue + } + if c.isClosed() { + return + } + continue + } + // address filter (compare only IP; optional port check commented for NAT flexibility) + if !addr.IP.Equal(net.IP(c.remoteEndpoint.IP)) { + continue + } + + if n < 13 { // minimum header size + continue + } + packetData := make([]byte, n) + copy(packetData, buf[:n]) + pkt := &Packet{} + if err := pkt.Unmarshal(packetData); err != nil { + continue + } + if int(pkt.Len) > maxPayloadSize { // sanity + continue + } + + // Handshake phase: deliver via channel for Start*Handshake loops + if !c.inState(StateEstablished) { + select { + case c.inCh <- pkt: + default: // if channel full (unlikely in handshake), drop + } + continue + } + + // Established: direct dispatch + // Handle data packets first (they may have ACK piggybacked) + if pkt.Len > 0 { + // Process piggyback ACK if present + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + } + // Process data payload + c.handleDataPacket(pkt) + continue + } + + // Pure ACK packets (no data payload) + if pkt.Flags&FlagACK != 0 { + c.HandleAckPacket(pkt) + continue + } + + // Pure control packets (SYN, FIN, etc.) + if pkt.Flags&(FlagSYN|FlagFIN) != 0 { + c.HandleControlPacket(pkt) + continue + } + } +} + +func (c *Conn) handleDataPacket(packet *Packet) { + if packet.Len == 0 { // ignore empty as data + return + } + seq := packet.Seq + plen := uint32(packet.Len) + + c.recvMu.Lock() + exp := c.expected + ackDelay := c.cfg.AckDelay // snapshot once inside lock + + switch { + case seq < exp: // duplicate / already received + c.queueAckLocked() + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + + case seq == exp: // in-order + if int(packet.Len) > int(c.recvRB.Free()) { + // No space -> request ACK (window advertisement) and drop + c.queueAckLocked() + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + } + // Write this packet fully + if n, _ := c.recvRB.Write(packet.Payload); n != int(packet.Len) { + // Failed partial write (shouldn't happen with ringbuffer); request ACK and abort + c.queueAckLocked() + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + } + exp += plen + c.expected = exp + // Drain any now-contiguous buffered packets + for { + nextPkt, ok := c.recvOO[exp] + if !ok { + break + } + if int(nextPkt.Len) > int(c.recvRB.Free()) { + // Break if capacity insufficient; will retry on next in-order arrival + break + } + if n, _ := c.recvRB.Write(nextPkt.Payload); n != int(nextPkt.Len) { + // On partial write, stop draining to preserve consistency + break + } + delete(c.recvOO, exp) + exp += uint32(nextPkt.Len) + c.expected = exp + } + c.queueAckLocked() + c.recvCond.Broadcast() + c.recvMu.Unlock() + + // mirror expected -> ackedSeqNum for piggyback + c.sendMu.Lock() + if c.ackedSeqNum < exp { + c.ackedSeqNum = exp + } + c.sendMu.Unlock() + c.triggerAck(ackDelay) + return + + default: // seq > exp (future / out-of-order) + if int(packet.Len) <= int(c.recvRB.Free()) { + if _, exists := c.recvOO[seq]; !exists { + c.recvOO[seq] = packet + } + } + c.queueAckLocked() // request duplicate ACK + c.recvMu.Unlock() + c.triggerAck(ackDelay) + return + } +} diff --git a/mod/udp/rudp/retransmissions.go b/mod/udp/rudp/retransmissions.go new file mode 100644 index 00000000..043a5579 --- /dev/null +++ b/mod/udp/rudp/retransmissions.go @@ -0,0 +1,161 @@ +package rudp + +import ( + "sort" + "time" + + "github.com/cryptopunkscc/astrald/mod/udp" +) + +type Unacked struct { + pkt *Packet // Packet metadata (seq, len) + sentTime time.Time // Last sent time + rtxCount int // Retransmit count + length int // Payload length + isHandshake bool // True if this entry is for a handshake control packet +} + +// startRtxTimer arms the retransmission timer if not already running +func (c *Conn) startRtxTimer() { + c.sendMu.Lock() + if c.rtxTimer != nil { + c.sendMu.Unlock() + return + } + interval := c.cfg.RetransmissionInterval + c.rtxTimer = time.AfterFunc(interval, func() { + c.handleRetransmissionTimeout() + c.sendMu.Lock() + if len(c.unacked) > 0 && !c.isClosed() { + c.rtxTimer.Reset(interval) + } else { + c.rtxTimer = nil + } + c.sendMu.Unlock() + }) + c.sendMu.Unlock() +} + +// queueAckLocked marks that an ACK should be (re)sent. Caller must hold recvMu. +func (c *Conn) queueAckLocked() { c.ackPending = true } + +// triggerAck decides immediate vs delayed ACK after recvMu has been released. +func (c *Conn) triggerAck(ackDelay time.Duration) { + if ackDelay == 0 { + c.sendPureACK() + } else { + c.scheduleAck() + } +} + +// scheduleAck sets / resets delayed ACK timer +func (c *Conn) scheduleAck() { + c.recvMu.Lock() + if !c.ackPending || c.isClosed() { + c.recvMu.Unlock() + return + } + d := c.cfg.AckDelay + if d <= 0 { + ackNeeded := c.ackPending + c.ackPending = false + c.recvMu.Unlock() + if ackNeeded { + c.sendPureACK() + } + return + } + if c.ackTimer != nil { + c.ackTimer.Reset(d) + } else { + c.ackTimer = time.AfterFunc(d, c.fireAck) + } + c.recvMu.Unlock() +} + +func (c *Conn) fireAck() { + c.recvMu.Lock() + if !c.ackPending || c.isClosed() { + c.recvMu.Unlock() + return + } + c.ackPending = false + c.recvMu.Unlock() + c.sendPureACK() +} + +// sendPureACK sends a standalone ACK reflecting current expected sequence. +func (c *Conn) sendPureACK() { + if !c.inState(StateEstablished) || c.isClosed() { + return + } + // snapshot expected & window + c.recvMu.Lock() + exp := c.expected + winFree := uint32(0) + if c.recvRB != nil { + winFree = uint32(c.recvRB.Free()) + } + c.recvMu.Unlock() + // clamp window to 16-bit + win := uint16(0) + if winFree > 0xFFFF { + win = 0xFFFF + } else { + win = uint16(winFree) + } + pkt := &Packet{Seq: 0, Ack: exp, Flags: FlagACK, Win: win, Len: 0} + b, err := pkt.Marshal() + if err != nil { + return + } + // best-effort send via unified path + _, _ = c.sendDatagram(b) +} + +// handleRetransmissionTimeoutLocked assumes sendMu is held and performs retransmissions. +// Returns true if the retransmission limit was exceeded for any packet. +func (c *Conn) handleRetransmissionTimeoutLocked() (limitExceeded bool) { + if len(c.unacked) == 0 { + return false + } + + seqs := make([]uint32, 0, len(c.unacked)) + for s := range c.unacked { + seqs = append(seqs, s) + } + sort.Slice(seqs, func(i, j int) bool { return seqs[i] < seqs[j] }) + + for _, s := range seqs { + u := c.unacked[s] + if u.rtxCount >= c.cfg.RetransmissionLimit { + limitExceeded = true + break + } + + // Update ACK field to latest cumulative ACK and retransmit + u.pkt.Ack = c.ackedSeqNum + b, err := u.pkt.Marshal() + if err == nil { + _, _ = c.sendDatagram(b) + } + u.rtxCount++ + u.sentTime = time.Now() + } + return +} + +// handleRetransmissionTimeout provides backward-compatible external behavior (locking & close semantics) +func (c *Conn) handleRetransmissionTimeout() { + c.sendMu.Lock() + limitExceeded := c.handleRetransmissionTimeoutLocked() + c.sendMu.Unlock() + if limitExceeded { + select { + case c.ErrChan <- udp.ErrRetransmissionLimitExceeded: + default: + } + c.Close() + close(c.ErrChan) + } +} diff --git a/mod/udp/rudp/send.go b/mod/udp/rudp/send.go new file mode 100644 index 00000000..e8d86687 --- /dev/null +++ b/mod/udp/rudp/send.go @@ -0,0 +1,108 @@ +package rudp + +import ( + "time" +) + +// Write enqueues data into the send ring buffer. It blocks until enough space is available. +// Signals the sender goroutine after enqueue. +func (c *Conn) writeSend(p []byte) (n int, err error) { + c.sendMu.Lock() + defer c.sendMu.Unlock() + writeLen := len(p) + for writeLen > int(c.sendRB.Free()) { + c.sendCond.Wait() + } + _, err = c.sendRB.Write(p) + if err != nil { + return 0, err + } + // Only the senderLoop needs to be awakened; writers waiting for space are not helped by a write. + c.sendCond.Signal() + return writeLen, nil +} + +// windowFull returns true if the packet window is full (no more packets can be sent). +func (c *Conn) windowFull() bool { + return len(c.unacked) >= c.cfg.MaxWindowPackets +} + +// commitPacketLocked registers the packet as unacked and advances sequence numbers. Caller holds sendMu. +func (c *Conn) commitPacketLocked(pkt *Packet) (startTimer bool) { + seq := pkt.Seq + c.unacked[seq] = &Unacked{pkt: pkt, sentTime: time.Now(), rtxCount: 0, length: int(pkt.Len)} + c.nextSeqNum += uint32(pkt.Len) + return len(c.unacked) == 1 +} + +// armRetransmitTimer starts retransmission timer if needed (no lock held). +func (c *Conn) armRetransmitTimer(need bool) { + if need { + c.startRtxTimer() + } +} + +// senderLoop runs as a goroutine and is responsible for sending packets from the send buffer. +// Inlined fragmentation & packet build: we avoid an extra lock/unlock by performing +// (select bytes -> copy -> build packet -> commit) under one critical section, then +// releasing the lock before marshaling and sending. +// FIFO model: bytes are consumed from sendRB as soon as they are packetized. Retransmissions +// use copies stored in unacked map. No random access over the ring is performed. +func (c *Conn) senderLoop() { + defer func() { c.sendCond.Broadcast() }() + for { + c.sendMu.Lock() + for { + if c.isClosed() || !c.inState(StateEstablished) { + c.sendMu.Unlock() + return + } + if c.sendRB.Length() > 0 && !c.windowFull() { + break + } + c.sendCond.Wait() + } + + // Decide fragment size and drain payload while still holding the lock. + ask := int(c.sendRB.Length()) + if ask > c.cfg.MaxSegmentSize { + ask = c.cfg.MaxSegmentSize + } + // Allocate buffer and read from ring (destructive read). + payload := make([]byte, ask) + readN, _ := c.sendRB.Read(payload) + if readN == 0 { // nothing actually read; loop and re-evaluate predicates + c.sendMu.Unlock() + continue + } + if readN != ask { // shrink payload if ring gave us fewer bytes + payload = payload[:readN] + } + + seq := c.nextSeqNum + pkt := &Packet{Seq: seq, Ack: c.ackedSeqNum, Flags: FlagACK, Len: uint16(len(payload)), Payload: payload} + startTimer := c.commitPacketLocked(pkt) + // We freed ring space; signal at least one waiting writer or (rarely) another waiter. + c.sendCond.Signal() + c.sendMu.Unlock() + + // Manage retransmission timer outside lock + c.armRetransmitTimer(startTimer) + + // Marshal & send outside lock to minimize critical section time. + raw, err := pkt.Marshal() + if err != nil { + select { + case c.ErrChan <- err: + default: + } + c.Close() + return + } + if _, err := c.sendDatagram(raw); err != nil { + // Suppress error (packet retained in unacked for retransmission). Continue loop. + continue + } + // next iteration + } +} diff --git a/mod/udp/src/config.go b/mod/udp/src/config.go new file mode 100644 index 00000000..b4bd5bbe --- /dev/null +++ b/mod/udp/src/config.go @@ -0,0 +1,19 @@ +package udp + +import ( + "time" + + "github.com/cryptopunkscc/astrald/mod/udp/rudp" +) + +type Config struct { + ListenPort int `yaml:"listen_port,omitempty"` // Port to listen on for incoming connections (default 1791) + PublicEndpoints []string `yaml:"public_endpoints,omitempty"` + DialTimeout time.Duration `yaml:"dial_timeout,omitempty"` // Timeout for dialing connections (default 1 minute) + TransportConfig rudp.Config `yaml:"transport_config,omitempty"` // Flow control settings for UDP connections +} + +var defaultConfig = Config{ + DialTimeout: time.Minute, + ListenPort: 1791, +} diff --git a/mod/udp/src/deps.go b/mod/udp/src/deps.go new file mode 100644 index 00000000..24afdf34 --- /dev/null +++ b/mod/udp/src/deps.go @@ -0,0 +1,29 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/core" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/nodes" + "github.com/cryptopunkscc/astrald/mod/objects" +) + +// Deps represents the dependencies required by the src UDP module. +type Deps struct { + Exonet exonet.Module + Nodes nodes.Module + Objects objects.Module +} + +func (mod *Module) LoadDependencies() (err error) { + err = core.Inject(mod.node, &mod.Deps) + if err != nil { + return + } + + mod.Exonet.SetDialer("udp", mod) + mod.Exonet.SetParser("udp", mod) + mod.Exonet.SetUnpacker("udp", mod) + mod.Nodes.AddResolver(mod) + + return +} diff --git a/mod/udp/src/dial.go b/mod/udp/src/dial.go new file mode 100644 index 00000000..15fc615f --- /dev/null +++ b/mod/udp/src/dial.go @@ -0,0 +1,57 @@ +package udp + +import ( + "net" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" + "github.com/cryptopunkscc/astrald/mod/udp/rudp" +) + +var _ exonet.Dialer = &Module{} + +// Dial establishes a reliable (rudp) connection and returns it only after the +// RUDP client handshake succeeds. Timeout / cancellation behavior: +// - If the caller's context has a deadline, it governs both dial and handshake. +// - Otherwise net.Dialer.Timeout (DialTimeout in config, if >0) limits only the dial phase. +// - The handshake then reuses the original context (may block if no deadline provided). +func (mod *Module) Dial(ctx *astral.Context, endpoint exonet.Endpoint) (exonet.Conn, error) { + switch endpoint.Network() { + case "udp": + default: + return nil, exonet.ErrUnsupportedNetwork + } + + dialer := net.Dialer{} + if mod.config.DialTimeout > 0 { + dialer.Timeout = mod.config.DialTimeout + } + + conn, err := dialer.DialContext(ctx, "udp", endpoint.Address()) + if err != nil { + return nil, err + } + + localEndpoint, _ := udp.ParseEndpoint(conn.LocalAddr().String()) + remoteEndpoint, _ := udp.ParseEndpoint(conn.RemoteAddr().String()) + + udpConn, ok := conn.(*net.UDPConn) + if !ok { + return nil, exonet.ErrUnsupportedNetwork + } + + reliableConn, err := rudp.NewConn(udpConn, localEndpoint, remoteEndpoint, + mod.config.TransportConfig, true, nil, ctx) + if err != nil { + return nil, err + } + + err = reliableConn.StartClientHandshake(ctx) + if err != nil { + reliableConn.Close() + return nil, err + } + + return reliableConn, nil +} diff --git a/mod/udp/src/endpoint_resolver.go b/mod/udp/src/endpoint_resolver.go new file mode 100644 index 00000000..92ec39bb --- /dev/null +++ b/mod/udp/src/endpoint_resolver.go @@ -0,0 +1,22 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/sig" +) + +// NOTE: Should we expose our UDP endpoints the same way we expose TCP ones? + +func (mod *Module) ResolveEndpoints(ctx *astral.Context, nodeID *astral.Identity) (_ <-chan exonet.Endpoint, err error) { + if !nodeID.IsEqual(mod.node.Identity()) { + return sig.ArrayToChan([]exonet.Endpoint{}), nil + } + + var all []exonet.Endpoint + + all = append(all, mod.publicEndpoints...) + all = append(all, mod.localEndpoints()...) + + return sig.ArrayToChan(all), nil +} diff --git a/mod/udp/src/loader.go b/mod/udp/src/loader.go new file mode 100644 index 00000000..540b3022 --- /dev/null +++ b/mod/udp/src/loader.go @@ -0,0 +1,40 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/log" + "github.com/cryptopunkscc/astrald/core" + "github.com/cryptopunkscc/astrald/core/assets" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +type Loader struct{} + +func (Loader) Load(node astral.Node, assets assets.Assets, l *log.Logger) (core.Module, error) { + mod := &Module{ + node: node, + log: l, + config: defaultConfig, + } + + _ = assets.LoadYAML(udp.ModuleName, &mod.config) + + // Parse public endpoints + for _, pe := range mod.config.PublicEndpoints { + endpoint, err := udp.ParseEndpoint(pe) + if err != nil { + l.Error("error parsing public endpoint \"%v\": %v", pe, err) + continue + } + + mod.publicEndpoints = append(mod.publicEndpoints, endpoint) + } + + return mod, nil +} + +func init() { + if err := core.RegisterModule(udp.ModuleName, Loader{}); err != nil { + panic(err) + } +} diff --git a/mod/udp/src/module.go b/mod/udp/src/module.go new file mode 100644 index 00000000..cb043d20 --- /dev/null +++ b/mod/udp/src/module.go @@ -0,0 +1,81 @@ +package udp + +import ( + "context" + "net" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/astral/log" + "github.com/cryptopunkscc/astrald/core/assets" + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" + "github.com/cryptopunkscc/astrald/tasks" +) + +// Module represents the UDP module and implements the exonet.Dialer interface. +type Module struct { + Deps + config Config // Configuration for the module + node astral.Node + assets assets.Assets + log *log.Logger + ctx context.Context + publicEndpoints []exonet.Endpoint +} + +func (mod *Module) Run(ctx *astral.Context) error { + mod.ctx = ctx + + err := tasks.Group(NewServer(mod)).Run(ctx) + if err != nil { + return err + } + + <-ctx.Done() + + return nil +} + +func (mod *Module) ListenPort() int { + return mod.config.ListenPort +} + +func (mod *Module) localIPs() ([]udp.IP, error) { + list := make([]udp.IP, 0) + + ifaceAddrs, err := InterfaceAddrs() + if err != nil { + return nil, err + } + + for _, a := range ifaceAddrs { + ipnet, ok := a.(*net.IPNet) + if !ok { + continue + } + ip := udp.IP(ipnet.IP) + list = append(list, ip) + } + + return list, nil +} + +func (mod *Module) localEndpoints() (list []exonet.Endpoint) { + ips, err := mod.localIPs() + if err != nil { + return + } + + for _, ip := range ips { + if ip.IsLoopback() { + continue + } + if ip.IsGlobalUnicast() || ip.IsPrivate() { + list = append(list, &udp.Endpoint{ + IP: ip, + Port: astral.Uint16(mod.ListenPort()), + }) + } + } + return +} diff --git a/mod/udp/src/net.go b/mod/udp/src/net.go new file mode 100644 index 00000000..9b37ab60 --- /dev/null +++ b/mod/udp/src/net.go @@ -0,0 +1,7 @@ +package udp + +import "net" + +// NOTE: same as in tcp module - consider moving to a common package + +var InterfaceAddrs = net.InterfaceAddrs diff --git a/mod/udp/src/parse.go b/mod/udp/src/parse.go new file mode 100644 index 00000000..a67b7b35 --- /dev/null +++ b/mod/udp/src/parse.go @@ -0,0 +1,16 @@ +package udp + +import ( + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +func (mod *Module) Parse(network string, address string) (exonet.Endpoint, error) { + switch network { + case "udp": + default: + return nil, exonet.ErrUnsupportedNetwork + } + + return udp.ParseEndpoint(address) +} diff --git a/mod/udp/src/server.go b/mod/udp/src/server.go new file mode 100644 index 00000000..3d105577 --- /dev/null +++ b/mod/udp/src/server.go @@ -0,0 +1,69 @@ +package udp + +import ( + "net" + "time" + + "github.com/cryptopunkscc/astrald/astral" + "github.com/cryptopunkscc/astrald/mod/udp/rudp" +) + +// Server implements UDP listening with connection acceptance via rudp.Listener +type Server struct { + *Module + rListener *rudp.Listener + acceptCh chan *rudp.Conn +} + +// NewServer creates a new src UDP server +func NewServer(module *Module) *Server { + return &Server{ + Module: module, + acceptCh: make(chan *rudp.Conn, 16), + } +} + +// Run starts the server and listens for incoming connections +func (s *Server) Run(ctx *astral.Context) error { + + addr := &net.UDPAddr{Port: s.config.ListenPort} + hto := s.config.DialTimeout + if hto <= 0 { + hto = 5 * time.Second + } + rListener, err := rudp.Listen(ctx, addr, s.Module.config.TransportConfig, hto) + if err != nil { + s.log.Errorv(0, "failed to start rudp listener: %v", err) + return err + } + s.rListener = rListener + s.log.Info("started server at %v", rListener.Addr()) + defer s.log.Info("stopped server at %v", rListener.Addr()) + + // Accept loop + go func() { + acceptCtx := ctx + for { + conn, err := rListener.Accept(acceptCtx) + if err != nil { + return + } + select { + case s.acceptCh <- conn: + default: + // drop if application not consuming fast enough + } + } + }() + + <-ctx.Done() + return s.Close() +} + +// Close gracefully shuts down the server +func (s *Server) Close() error { + if s.rListener != nil { + _ = s.rListener.Close() + } + return nil +} diff --git a/mod/udp/src/unpack.go b/mod/udp/src/unpack.go new file mode 100644 index 00000000..0f9403e3 --- /dev/null +++ b/mod/udp/src/unpack.go @@ -0,0 +1,25 @@ +package udp + +import ( + "bytes" + + "github.com/cryptopunkscc/astrald/mod/exonet" + "github.com/cryptopunkscc/astrald/mod/udp" +) + +var _ exonet.Unpacker = &Module{} + +func (mod *Module) Unpack(network string, data []byte) (exonet.Endpoint, error) { + switch network { + case "udp": + default: + return nil, exonet.ErrUnsupportedNetwork + } + return Unpack(data) +} + +func Unpack(buf []byte) (e *udp.Endpoint, err error) { + e = &udp.Endpoint{} + _, err = e.ReadFrom(bytes.NewReader(buf)) + return +}