diff --git a/core/lwip.go b/core/lwip.go index 9ba76e48..5e314a2a 100644 --- a/core/lwip.go +++ b/core/lwip.go @@ -129,7 +129,7 @@ func (s *lwipStack) Close() error { s.cancel() // Abort and close all TCP and UDP connections. - tcpConns.Range(func(_, c interface{}) bool { + tcpConns.Range(func(c, _ interface{}) bool { c.(*tcpConn).Abort() return true }) diff --git a/core/tcp_callback_export.go b/core/tcp_callback_export.go index 0b75f999..9c8f587d 100644 --- a/core/tcp_callback_export.go +++ b/core/tcp_callback_export.go @@ -56,16 +56,11 @@ func tcpRecvFn(arg unsafe.Pointer, tpcb *C.struct_tcp_pcb, p *C.struct_pbuf, err } }() - conn, ok := tcpConns.Load(getConnKeyVal(arg)) - if !ok { - // The connection does not exists. - C.tcp_abort(tpcb) - return C.ERR_ABRT - } + var conn = (*tcpConn)(arg) if p == nil { // Peer closed, EOF. - err := conn.(TCPConn).LocalClosed() + err := conn.LocalClosed() switch err.(*lwipError).Code { case LWIP_ERR_ABRT: return C.ERR_ABRT @@ -86,7 +81,7 @@ func tcpRecvFn(arg unsafe.Pointer, tpcb *C.struct_tcp_pcb, p *C.struct_pbuf, err C.pbuf_copy_partial(p, unsafe.Pointer(&buf[0]), p.tot_len, 0) } - rerr := conn.(TCPConn).Receive(buf[:totlen]) + rerr := conn.Receive(buf[:totlen]) if rerr != nil { switch rerr.(*lwipError).Code { case LWIP_ERR_ABRT: @@ -114,52 +109,43 @@ func tcpRecvFn(arg unsafe.Pointer, tpcb *C.struct_tcp_pcb, p *C.struct_pbuf, err //export tcpSentFn func tcpSentFn(arg unsafe.Pointer, tpcb *C.struct_tcp_pcb, len C.u16_t) C.err_t { - if conn, ok := tcpConns.Load(getConnKeyVal(arg)); ok { - err := conn.(TCPConn).Sent(uint16(len)) - switch err.(*lwipError).Code { - case LWIP_ERR_ABRT: - return C.ERR_ABRT - case LWIP_ERR_OK: - return C.ERR_OK - default: - panic("unexpected error") - } - } else { - C.tcp_abort(tpcb) + var conn = (*tcpConn)(arg) + err := conn.Sent(uint16(len)) + switch err.(*lwipError).Code { + case LWIP_ERR_ABRT: return C.ERR_ABRT + case LWIP_ERR_OK: + return C.ERR_OK + default: + panic("unexpected error") } } //export tcpErrFn func tcpErrFn(arg unsafe.Pointer, err C.err_t) { - if conn, ok := tcpConns.Load(getConnKeyVal(arg)); ok { - switch err { - case C.ERR_ABRT: - // Aborted through tcp_abort or by a TCP timer - conn.(TCPConn).Err(errors.New("connection aborted")) - case C.ERR_RST: - // The connection was reset by the remote host - conn.(TCPConn).Err(errors.New("connection reseted")) - default: - conn.(TCPConn).Err(errors.New(fmt.Sprintf("lwip error code %v", int(err)))) - } + var conn = (*tcpConn)(arg) + switch err { + case C.ERR_ABRT: + // Aborted through tcp_abort or by a TCP timer + conn.Err(errors.New("connection aborted")) + case C.ERR_RST: + // The connection was reset by the remote host + conn.Err(errors.New("connection reseted")) + default: + conn.Err(errors.New(fmt.Sprintf("lwip error code %v", int(err)))) } } //export tcpPollFn func tcpPollFn(arg unsafe.Pointer, tpcb *C.struct_tcp_pcb) C.err_t { - if conn, ok := tcpConns.Load(getConnKeyVal(arg)); ok { - err := conn.(TCPConn).Poll() - switch err.(*lwipError).Code { - case LWIP_ERR_ABRT: - return C.ERR_ABRT - case LWIP_ERR_OK: - return C.ERR_OK - default: - panic("unexpected error") - } - } else { - C.tcp_abort(tpcb) + var conn = (*tcpConn)(arg) + err := conn.Poll() + switch err.(*lwipError).Code { + case LWIP_ERR_ABRT: return C.ERR_ABRT + case LWIP_ERR_OK: + return C.ERR_OK + default: + panic("unexpected error") } } diff --git a/core/tcp_conn.go b/core/tcp_conn.go index de5f6595..73b68f99 100644 --- a/core/tcp_conn.go +++ b/core/tcp_conn.go @@ -3,13 +3,17 @@ package core /* #cgo CFLAGS: -I./c/include #include "lwip/tcp.h" + +void tcp_arg_cgo(struct tcp_pcb *pcb, uintptr_t ptr) { + tcp_arg(pcb, (void*)ptr); +} + */ import "C" import ( "errors" "fmt" "io" - "math/rand" "net" "sync" "time" @@ -18,6 +22,8 @@ import ( type tcpConnState uint +var tcpConns sync.Map + const ( // tcpNewConn is the initial state. tcpNewConn tcpConnState = iota @@ -62,8 +68,6 @@ type tcpConn struct { handler TCPConnHandler remoteAddr *net.TCPAddr localAddr *net.TCPAddr - connKeyArg unsafe.Pointer - connKey uint32 canWrite *sync.Cond // Condition variable to implement TCP backpressure. state tcpConnState sndPipeReader *io.PipeReader @@ -73,12 +77,6 @@ type tcpConn struct { } func newTCPConn(pcb *C.struct_tcp_pcb, handler TCPConnHandler) (TCPConn, error) { - connKeyArg := newConnKeyArg() - connKey := rand.Uint32() - setConnKeyVal(unsafe.Pointer(connKeyArg), connKey) - - // Pass the key as arg for subsequent tcp callbacks. - C.tcp_arg(pcb, unsafe.Pointer(connKeyArg)) // Register callbacks. setTCPRecvCallback(pcb) @@ -92,16 +90,14 @@ func newTCPConn(pcb *C.struct_tcp_pcb, handler TCPConnHandler) (TCPConn, error) handler: handler, localAddr: ParseTCPAddr(ipAddrNTOA(pcb.remote_ip), uint16(pcb.remote_port)), remoteAddr: ParseTCPAddr(ipAddrNTOA(pcb.local_ip), uint16(pcb.local_port)), - connKeyArg: connKeyArg, - connKey: connKey, canWrite: sync.NewCond(&sync.Mutex{}), state: tcpNewConn, sndPipeReader: pipeReader, sndPipeWriter: pipeWriter, } - // Associate conn with key and save to the global map. - tcpConns.Store(connKey, conn) + C.tcp_arg_cgo(pcb, C.uintptr_t(uintptr(unsafe.Pointer(conn)))) + tcpConns.Store(conn, true) // Connecting remote host could take some time, do it in another goroutine // to prevent blocking the lwip thread. @@ -453,10 +449,7 @@ func (conn *tcpConn) LocalClosed() error { } func (conn *tcpConn) release() { - if _, found := tcpConns.Load(conn.connKey); found { - freeConnKeyArg(conn.connKeyArg) - tcpConns.Delete(conn.connKey) - } + tcpConns.Delete(conn) conn.sndPipeWriter.Close() conn.sndPipeReader.Close() conn.state = tcpClosed diff --git a/core/tcp_conn_map.go b/core/tcp_conn_map.go deleted file mode 100644 index 881708e4..00000000 --- a/core/tcp_conn_map.go +++ /dev/null @@ -1,65 +0,0 @@ -package core - -/* -#cgo CFLAGS: -I./c/include -#include "lwip/tcp.h" -#include - -void* -new_conn_key_arg() -{ - return malloc(sizeof(uint32_t)); -} - -void -free_conn_key_arg(void *arg) -{ - free(arg); -} - -void -set_conn_key_val(void *arg, uint32_t val) -{ - *((uint32_t*)arg) = val; -} - -uint32_t -get_conn_key_val(void *arg) -{ - return *((uint32_t*)arg); -} -*/ -import "C" -import ( - "sync" - "unsafe" -) - -var tcpConns sync.Map - -// We need such a key-value mechanism because when passing a Go pointer -// to C, the Go pointer will only be valid during the call. -// If we pass a Go pointer to tcp_arg(), this pointer will not be usable -// in subsequent callbacks (e.g.: tcp_recv(), tcp_err()). -// -// Instead we need to pass a C pointer to tcp_arg(), we manually allocate -// the memory in C and return its pointer to Go code. After the connection -// end, the memory should be freed manually. -// -// See also: -// https://github.com/golang/go/issues/12416 -func newConnKeyArg() unsafe.Pointer { - return C.new_conn_key_arg() -} - -func freeConnKeyArg(p unsafe.Pointer) { - C.free_conn_key_arg(p) -} - -func setConnKeyVal(p unsafe.Pointer, val uint32) { - C.set_conn_key_val(p, C.uint32_t(val)) -} - -func getConnKeyVal(p unsafe.Pointer) uint32 { - return uint32(C.get_conn_key_val(p)) -}