diff --git a/src/core/tcp.ml b/src/core/tcp.ml index a955b252..41b765c0 100644 --- a/src/core/tcp.ml +++ b/src/core/tcp.ml @@ -30,10 +30,12 @@ module type S = sig and type write_error := write_error val dst: flow -> ipaddr * int + val unread : flow -> Cstruct.t -> unit val write_nodelay: flow -> Cstruct.t -> (unit, write_error) result Lwt.t val writev_nodelay: flow -> Cstruct.t list -> (unit, write_error) result Lwt.t val create_connection: ?keepalive:Keepalive.t -> t -> ipaddr * int -> (flow, error) result Lwt.t val listen : t -> port:int -> ?keepalive:Keepalive.t -> (flow -> unit Lwt.t) -> unit + val is_listening : t -> port:int -> (flow -> unit Lwt.t) option val unlisten : t -> port:int -> unit val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t end diff --git a/src/core/tcp.mli b/src/core/tcp.mli index e807e071..7d7e8efb 100644 --- a/src/core/tcp.mli +++ b/src/core/tcp.mli @@ -53,6 +53,10 @@ module type S = sig (** Get the destination IP address and destination port that a flow is currently connected to. *) + val unread : flow -> Cstruct.t -> unit + (** [unread flow buffer] puts [buffer] at the beginning of the receive queue, + so the next [read] from [flow] will receive [buffer]. *) + val write_nodelay: flow -> Cstruct.t -> (unit, write_error) result Lwt.t (** [write_nodelay flow buffer] writes the contents of [buffer] to the flow. The thread blocks until all data has been successfully @@ -83,8 +87,12 @@ module type S = sig executed for each flow that was established. If [keepalive] is provided, this configuration will be applied before calling [callback]. - @raise Invalid_argument if [port < 0] or [port > 65535] - *) + @raise Invalid_argument if [port < 0] or [port > 65535] *) + + val is_listening : t -> port:int -> (flow -> unit Lwt.t) option + (** [is_listening t ~port] returns the [callback] on [port], if it exists. + + @raise Invalid_argument if [port < 0] or [port > 65535] *) val unlisten : t -> port:int -> unit (** [unlisten t ~port] stops any listener on [port]. *) diff --git a/src/core/udp.ml b/src/core/udp.ml index cb5319ba..b845f8ff 100644 --- a/src/core/udp.ml +++ b/src/core/udp.ml @@ -6,6 +6,7 @@ module type S = sig val disconnect : t -> unit Lwt.t type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t val listen : t -> port:int -> callback -> unit + val is_listening : t -> port:int -> callback option val unlisten : t -> port:int -> unit val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t val write: ?src:ipaddr -> ?src_port:int -> ?ttl:int -> dst:ipaddr -> dst_port:int -> t -> Cstruct.t -> diff --git a/src/core/udp.mli b/src/core/udp.mli index ff0a1dbf..924cc56f 100644 --- a/src/core/udp.mli +++ b/src/core/udp.mli @@ -29,6 +29,11 @@ module type S = sig @raise Invalid_argument if [port < 0] or [port > 65535] *) + val is_listening : t -> port:int -> callback option + (** [is_listening t ~port] returns the [callback] on [port], if it exists. + + @raise Invalid_argument if [port < 0] or [port > 65535] *) + val unlisten : t -> port:int -> unit (** [unlisten t ~port] stops any listeners on [port]. *) diff --git a/src/stack-unix/dune b/src/stack-unix/dune index 4bf5d785..fb96eb3a 100644 --- a/src/stack-unix/dune +++ b/src/stack-unix/dune @@ -33,7 +33,7 @@ (library (name tcpv4v6_socket) (public_name tcpip.tcpv4v6-socket) - (modules tcp_socket tcpv4v6_socket) + (modules tcpv4v6_socket) (wrapped false) (instrumentation (backend bisect_ppx)) diff --git a/src/stack-unix/tcp_socket.ml b/src/stack-unix/tcp_socket.ml deleted file mode 100644 index 5a0b30af..00000000 --- a/src/stack-unix/tcp_socket.ml +++ /dev/null @@ -1,68 +0,0 @@ -open Lwt - -type error = [ Tcpip.Tcp.error | `Exn of exn ] -type write_error = [ Tcpip.Tcp.write_error | `Exn of exn ] - -let pp_error ppf = function - | #Tcpip.Tcp.error as e -> Tcpip.Tcp.pp_error ppf e - | `Exn e -> Fmt.exn ppf e - -let pp_write_error ppf = function - | #Tcpip.Tcp.write_error as e -> Tcpip.Tcp.pp_write_error ppf e - | `Exn e -> Fmt.exn ppf e - -let ignore_canceled = function - | Lwt.Canceled -> Lwt.return_unit - | exn -> raise exn - -let disconnect _ = - return_unit - -let read fd = - let buflen = 4096 in - let buf = Cstruct.create buflen in - Lwt.catch (fun () -> - Lwt_cstruct.read fd buf - >>= function - | 0 -> return (Ok `Eof) - | n when n = buflen -> return (Ok (`Data buf)) - | n -> return @@ Ok (`Data (Cstruct.sub buf 0 n)) - ) - (fun exn -> return (Error (`Exn exn))) - -let rec write fd buf = - Lwt.catch - (fun () -> - Lwt_cstruct.write fd buf - >>= function - | n when n = Cstruct.length buf -> return @@ Ok () - | 0 -> return @@ Error `Closed - | n -> write fd (Cstruct.sub buf n (Cstruct.length buf - n)) - ) (function - | Unix.Unix_error(Unix.EPIPE, _, _) -> return @@ Error `Closed - | e -> return (Error (`Exn e))) - -let writev fd bufs = - Lwt_list.fold_left_s - (fun res buf -> - match res with - | Error _ as e -> return e - | Ok () -> write fd buf - ) (Ok ()) bufs - -(* TODO make nodelay a flow option *) -let write_nodelay fd buf = - write fd buf - -(* TODO make nodelay a flow option *) -let writev_nodelay fd bufs = - writev fd bufs - -let close fd = - Lwt.catch - (fun () -> Lwt_unix.close fd) - (function - | Unix.Unix_error (Unix.EBADF, _, _) -> Lwt.return_unit - | e -> Lwt.fail e) - -let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit diff --git a/src/stack-unix/tcpv4v6_socket.ml b/src/stack-unix/tcpv4v6_socket.ml index b8b4809c..e7eb2fb3 100644 --- a/src/stack-unix/tcpv4v6_socket.ml +++ b/src/stack-unix/tcpv4v6_socket.ml @@ -21,12 +21,15 @@ module Log = (val Logs.src_log src : Logs.LOG) open Lwt.Infix type ipaddr = Ipaddr.t -type flow = Lwt_unix.file_descr +type flow = { + mutable buf : Cstruct.t; + fd : Lwt_unix.file_descr; +} type t = { interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *) - mutable active_connections : Lwt_unix.file_descr list; - listen_sockets : (int, Lwt_unix.file_descr list) Hashtbl.t; + mutable active_connections : flow list; + listen_sockets : (int, Lwt_unix.file_descr list * (flow -> unit Lwt.t)) Hashtbl.t; mutable switched_off : unit Lwt.t; } @@ -35,7 +38,75 @@ let set_switched_off t switched_off = let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified -include Tcp_socket +type error = [ Tcpip.Tcp.error | `Exn of exn ] +type write_error = [ Tcpip.Tcp.write_error | `Exn of exn ] + +let pp_error ppf = function + | #Tcpip.Tcp.error as e -> Tcpip.Tcp.pp_error ppf e + | `Exn e -> Fmt.exn ppf e + +let pp_write_error ppf = function + | #Tcpip.Tcp.write_error as e -> Tcpip.Tcp.pp_write_error ppf e + | `Exn e -> Fmt.exn ppf e + +let ignore_canceled = function + | Lwt.Canceled -> Lwt.return_unit + | exn -> raise exn + +let read ({ buf ; fd } as flow) = + if Cstruct.length buf > 0 then begin + flow.buf <- Cstruct.empty; + Lwt.return (Ok (`Data buf)) + end else + let buflen = 4096 in + let buf = Cstruct.create buflen in + Lwt.catch (fun () -> + Lwt_cstruct.read fd buf + >>= function + | 0 -> Lwt.return (Ok `Eof) + | n when n = buflen -> Lwt.return (Ok (`Data buf)) + | n -> Lwt.return @@ Ok (`Data (Cstruct.sub buf 0 n)) + ) + (fun exn -> Lwt.return (Error (`Exn exn))) + +let rec write ({ fd; _ } as flow) buf = + Lwt.catch + (fun () -> + Lwt_cstruct.write fd buf + >>= function + | n when n = Cstruct.length buf -> Lwt.return @@ Ok () + | 0 -> Lwt.return @@ Error `Closed + | n -> write flow (Cstruct.sub buf n (Cstruct.length buf - n)) + ) (function + | Unix.Unix_error(Unix.EPIPE, _, _) -> Lwt.return @@ Error `Closed + | e -> Lwt.return (Error (`Exn e))) + +let writev fd bufs = + Lwt_list.fold_left_s + (fun res buf -> + match res with + | Error _ as e -> Lwt.return e + | Ok () -> write fd buf + ) (Ok ()) bufs + +(* TODO make nodelay a flow option *) +let write_nodelay fd buf = + write fd buf + +(* TODO make nodelay a flow option *) +let writev_nodelay fd bufs = + writev fd bufs + +let close_fd fd = + Lwt.catch + (fun () -> Lwt_unix.close fd) + (function + | Unix.Unix_error (Unix.EBADF, _, _) -> Lwt.return_unit + | e -> Lwt.fail e) + +let close { fd; _ } = close_fd fd + +let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit let connect ~ipv4_only ~ipv6_only ipv4 ipv6 = let interface = @@ -62,11 +133,11 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 = let disconnect t = Lwt_list.iter_p close t.active_connections >>= fun () -> - Lwt_list.iter_p close - (Hashtbl.fold (fun _ fd acc -> fd @ acc) t.listen_sockets []) >>= fun () -> + Lwt_list.iter_p close_fd + (Hashtbl.fold (fun _ (fds, _) acc -> fds @ acc) t.listen_sockets []) >>= fun () -> Lwt.cancel t.switched_off ; Lwt.return_unit -let dst fd = +let dst { fd; _ } = match Lwt_unix.getpeername fd with | Unix.ADDR_UNIX _ -> raise (Failure "unexpected: got a unix instead of tcp sock") @@ -78,6 +149,10 @@ let dst fd = in ip, port +let unread fd buf = + let buf = Cstruct.append buf fd.buf in + fd.buf <- buf + let create_connection ?keepalive t (dst,dst_port) = match match dst, t.interface with @@ -104,19 +179,23 @@ let create_connection ?keepalive t (dst,dst_port) = | None -> () | Some { Tcpip.Tcp.Keepalive.after; interval; probes } -> Tcp_socket_options.enable_keepalive ~fd ~after ~interval ~probes ); - t.active_connections <- fd :: t.active_connections; - Lwt.return (Ok fd)) + let flow = { buf = Cstruct.empty ; fd } in + t.active_connections <- flow :: t.active_connections; + Lwt.return (Ok flow)) (fun exn -> - close fd >>= fun () -> + close_fd fd >>= fun () -> Lwt.return (Error (`Exn exn))) let unlisten t ~port = match Hashtbl.find_opt t.listen_sockets port with | None -> () - | Some fds -> + | Some (fds, _) -> Hashtbl.remove t.listen_sockets port; try List.iter (fun fd -> Unix.close (Lwt_unix.unix_file_descr fd)) fds with _ -> () +let is_listening t ~port = + Option.map snd (Hashtbl.find_opt t.listen_sockets port) + let listen t ~port ?keepalive callback = if port < 0 || port > 65535 then raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)); @@ -147,7 +226,7 @@ let listen t ~port ?keepalive callback = in List.iter (fun (fd, addr) -> Unix.bind (Lwt_unix.unix_file_descr fd) addr; - Hashtbl.replace t.listen_sockets port (List.map fst fds); + Hashtbl.replace t.listen_sockets port (List.map fst fds, callback); Lwt_unix.listen fd 10; (* FIXME: we should not ignore the result *) Lwt.async (fun () -> @@ -156,7 +235,8 @@ let listen t ~port ?keepalive callback = if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_unix.accept fd >|= fun (afd, _) -> - t.active_connections <- afd :: t.active_connections; + let flow = { buf = Cstruct.empty ; fd = afd } in + t.active_connections <- flow :: t.active_connections; (match keepalive with | None -> () | Some { Tcpip.Tcp.Keepalive.after; interval; probes } -> @@ -164,10 +244,10 @@ let listen t ~port ?keepalive callback = Lwt.async (fun () -> Lwt.catch - (fun () -> callback afd) + (fun () -> callback flow) (fun exn -> Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; - close afd)); + close flow)); `Continue) (function | Unix.Unix_error (Unix.EBADF, _, _) -> @@ -179,4 +259,4 @@ let listen t ~port ?keepalive callback = | `Continue -> loop () | `Stop -> Lwt.return_unit in - Lwt.catch loop ignore_canceled >>= fun () -> close fd)) fds + Lwt.catch loop ignore_canceled >>= fun () -> close_fd fd)) fds diff --git a/src/stack-unix/tcpv4v6_socket.mli b/src/stack-unix/tcpv4v6_socket.mli index f4493ad2..726f7e5d 100644 --- a/src/stack-unix/tcpv4v6_socket.mli +++ b/src/stack-unix/tcpv4v6_socket.mli @@ -17,7 +17,6 @@ include Tcpip.Tcp.S with type ipaddr = Ipaddr.t - and type flow = Lwt_unix.file_descr and type error = [ Tcpip.Tcp.error | `Exn of exn ] and type write_error = [ Tcpip.Tcp.write_error | `Exn of exn ] diff --git a/src/stack-unix/udpv4v6_socket.ml b/src/stack-unix/udpv4v6_socket.ml index c6ff9571..dc567c63 100644 --- a/src/stack-unix/udpv4v6_socket.ml +++ b/src/stack-unix/udpv4v6_socket.ml @@ -27,7 +27,7 @@ let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified type t = { interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *) - listen_fds: (int, Lwt_unix.file_descr * Lwt_unix.file_descr option) Hashtbl.t; (* UDP fds bound to a particular port *) + listen_fds: (int, Lwt_unix.file_descr * Lwt_unix.file_descr option * callback) Hashtbl.t; (* UDP fds bound to a particular port *) mutable switched_off : unit Lwt.t; } @@ -38,12 +38,12 @@ let ignore_canceled = function | Lwt.Canceled -> Lwt.return_unit | exn -> raise exn -let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds;interface;_} port = +let get_udpv4v6_listening_fd ?preserve ?(v4_or_v6 = `Both) {listen_fds;interface;_} port = try Lwt.return (match Hashtbl.find listen_fds port with - | (fd, None) -> false, [ fd ] - | (fd, Some fd') -> false, [ fd ; fd' ]) + | (fd, None, _) -> false, [ fd ] + | (fd, Some fd', _) -> false, [ fd ; fd' ]) with Not_found -> (match interface with | `Any -> @@ -76,8 +76,8 @@ let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds; | `V6_only ip -> let fd = Lwt_unix.(socket PF_INET6 SOCK_DGRAM 0) in Lwt_unix.bind fd (Lwt_unix.ADDR_INET (ip, port)) >|= fun () -> - ((fd, None), [ fd ])) >|= fun (fds, r) -> - if preserve then Hashtbl.add listen_fds port fds; + ((fd, None), [ fd ])) >|= fun ((fd1, fd2), r) -> + Option.iter (fun cb -> Hashtbl.add listen_fds port (fd1, fd2, cb)) preserve; true, r @@ -121,7 +121,7 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 = Lwt.return { interface; listen_fds; switched_off = fst (Lwt.wait ()) } let disconnect t = - Hashtbl.fold (fun _ (fd, fd') r -> + Hashtbl.fold (fun _ (fd, fd', _) r -> r >>= fun () -> close fd >>= fun () -> match fd' with None -> Lwt.return_unit | Some fd -> close fd) @@ -146,7 +146,7 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = match t.interface, v4_or_v6 with | `Any, _ | `Ip _, _ | `V4_only _, `V4 | `V6_only _, `V6 -> let p = match src_port with None -> 0 | Some x -> x in - get_udpv4v6_listening_fd ~preserve:false ~v4_or_v6 t p >>= fun (created, fds) -> + get_udpv4v6_listening_fd ~v4_or_v6 t p >>= fun (created, fds) -> ((match fds, v4_or_v6 with | [ fd ], _ -> Lwt.return (Ok fd) | [ v4 ; _v6 ], `V4 -> Lwt.return (Ok v4) @@ -161,19 +161,25 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = let unlisten t ~port = try - let fd, fd' = Hashtbl.find t.listen_fds port in + let fd, fd', _ = Hashtbl.find t.listen_fds port in Hashtbl.remove t.listen_fds port; (match fd' with None -> () | Some fd' -> Unix.close (Lwt_unix.unix_file_descr fd')); Unix.close (Lwt_unix.unix_file_descr fd) with _ -> () +let is_listening t ~port = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)) + else + Option.map (fun (_, _, cb) -> cb) (Hashtbl.find_opt t.listen_fds port) + let listen t ~port callback = if port < 0 || port > 65535 then raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)) else (* FIXME: we should not ignore the result *) Lwt.async (fun () -> - get_udpv4v6_listening_fd t port >|= fun (_, fds) -> + get_udpv4v6_listening_fd ~preserve:callback t port >|= fun (_, fds) -> List.iter (fun fd -> Lwt.async (fun () -> let buf = Cstruct.create 4096 in diff --git a/src/tcp/flow.ml b/src/tcp/flow.ml index e644ab1c..6b96857d 100644 --- a/src/tcp/flow.ml +++ b/src/tcp/flow.ml @@ -83,6 +83,12 @@ struct else Hashtbl.replace t.listeners port (keepalive, cb) + let is_listening t ~port = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)) + else + Option.map snd (Hashtbl.find_opt t.listeners port) + let unlisten t ~port = Hashtbl.remove t.listeners port let _pp_pcb fmt pcb = @@ -581,6 +587,9 @@ struct (* No existing PCB, so check if it is a SYN for a listening function *) (input_no_pcb t (pkt, payload)) + let unread pcb buf = + User_buffer.Rx.add_l pcb.urx buf + (* Blocking read on a PCB *) let read pcb = User_buffer.Rx.take_l pcb.urx diff --git a/src/tcp/user_buffer.ml b/src/tcp/user_buffer.ml index cc95d25c..6f4b99cb 100644 --- a/src/tcp/user_buffer.ml +++ b/src/tcp/user_buffer.ml @@ -59,6 +59,9 @@ module Rx = struct | None -> 0 | Some b -> Cstruct.length b + let add_l t s = + ignore(Lwt_dllist.add_l (Some s) t.q) + let add_r t s = if t.cur_size > t.max_size then let th,u = Lwt.wait () in diff --git a/src/tcp/user_buffer.mli b/src/tcp/user_buffer.mli index 63f984d3..d3e21930 100644 --- a/src/tcp/user_buffer.mli +++ b/src/tcp/user_buffer.mli @@ -19,6 +19,7 @@ module Rx : sig type t val create : max_size:int32 -> wnd:Window.t -> t + val add_l : t -> Cstruct.t -> unit val add_r : t -> Cstruct.t option -> unit Lwt.t val take_l : t -> Cstruct.t option Lwt.t val cur_size : t -> int32 diff --git a/src/udp/udp.ml b/src/udp/udp.ml index 1031bccc..d0400a66 100644 --- a/src/udp/udp.ml +++ b/src/udp/udp.ml @@ -40,6 +40,12 @@ module Make (Ip : Tcpip.Ip.S) (Random : Mirage_random.S) = struct else Hashtbl.replace t.listeners port callback + let is_listening t ~port = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)) + else + Hashtbl.find_opt t.listeners port + let unlisten t ~port = Hashtbl.remove t.listeners port (* TODO: ought we to check to make sure the destination is relevant