diff --git a/changelog.md b/changelog.md index 665346ad7009e..4bb8b0055dd51 100644 --- a/changelog.md +++ b/changelog.md @@ -39,6 +39,9 @@ errors. [//]: # "Additions:" +- Standard posix calls on Linux and Windows are not all thread safe, switch to + ugly but thread-safe extensions where affected in `nativesockets` and `times`. + - `setutils.symmetricDifference` along with its operator version `` setutils.`-+-` `` and in-place version `setutils.toggle` have been added to more efficiently calculate the symmetric difference of bitsets. diff --git a/lib/posix/posix.nim b/lib/posix/posix.nim index 5046daaa2ed95..a2f734c578ec9 100644 --- a/lib/posix/posix.nim +++ b/lib/posix/posix.nim @@ -1073,6 +1073,14 @@ else: proc gethostbyaddr*(a1: cstring, a2: cint, a3: cint): ptr Hostent {. importc, header: "".} proc gethostbyname*(a1: cstring): ptr Hostent {.importc, header: "".} +when defined(linux): + proc gethostbyaddr_r*(a1: pointer, a2: SockLen, a3: cint, + ret: ptr Hostent, buf: cstring, buflen: csize_t, + res: ptr ptr Hostent, h_errnop: ptr cint): cint {. + importc, header: "".} + proc gethostbyname_r*(name: cstring, ret: ptr Hostent, + buf: cstring, buflen: csize_t, res: ptr ptr Hostent, + h_errnop: ptr cint): cint {.importc, header: "".} proc gethostent*(): ptr Hostent {.importc, header: "".} proc getnameinfo*(a1: ptr SockAddr, a2: SockLen, @@ -1090,6 +1098,13 @@ proc getprotoent*(): ptr Protoent {.importc, header: "".} proc getservbyname*(a1, a2: cstring): ptr Servent {.importc, header: "".} proc getservbyport*(a1: cint, a2: cstring): ptr Servent {. importc, header: "".} +when defined(linux) and not defined(android): + proc getservbyname_r*(name, proto: cstring, resultBuf: ptr Servent, + buf: cstring, buflen: csize_t, res: ptr ptr Servent): cint {. + importc, header: "".} + proc getservbyport_r*(port: cint, proto: cstring, resultBuf: ptr Servent, + buf: cstring, buflen: csize_t, res: ptr ptr Servent): cint {. + importc, header: "".} proc getservent*(): ptr Servent {.importc, header: "".} proc sethostent*(a1: cint) {.importc, header: "".} diff --git a/lib/pure/nativesockets.nim b/lib/pure/nativesockets.nim index c7868b74c6278..69c1ee9af5eaa 100644 --- a/lib/pure/nativesockets.nim +++ b/lib/pure/nativesockets.nim @@ -25,6 +25,8 @@ when hostOS == "solaris": const useWinVersion = defined(windows) or defined(nimdoc) const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr) or defined(nuttx) +const smallBufInitSize = when defined(testReentrantBufs): 1 else: 1024 +const largeBufInitSize = when defined(testReentrantBufs): 1 else: 4096 when useWinVersion: import std/winlean @@ -215,6 +217,23 @@ proc getProtoByName*(name: string): int {.since: (1, 3, 5).} = ## Returns a protocol code from the database that matches the protocol `name`. when useWinVersion: let protoent = winlean.getprotobyname(name.cstring) + elif defined(linux) and not defined(android): + var protoent: ptr posix.Protoent + {.emit: ";\n#ifdef __GLIBC__".} + proc getprotobyname_r(name: cstring, resultBuf: ptr posix.Protoent, + buf: cstring, buflen: csize_t, + res: ptr ptr posix.Protoent): cint + {.importc, header: "".} + var pe: posix.Protoent + var buf = newString(smallBufInitSize) + while true: + let ret = getprotobyname_r(name.cstring, addr pe, + buf.cstring, csize_t(buf.len), addr protoent) + if ret != ERANGE: break + buf.setLen(buf.len * 2) + {.emit: ";\n#else".} + protoent = posix.getprotobyname(name.cstring) + {.emit: ";\n#endif".} else: let protoent = posix.getprotobyname(name.cstring) @@ -371,6 +390,15 @@ when not useNimNetLite: ## On posix this will search through the `/etc/services` file. when useWinVersion: var s = winlean.getservbyname(name, proto) + elif defined(linux) and not defined(android): + var se: posix.Servent + var s: ptr posix.Servent + var buf = newString(smallBufInitSize) + while true: + let ret = getservbyname_r(name.cstring, proto.cstring, addr se, + buf.cstring, csize_t(buf.len), addr s) + if ret != ERANGE: break + buf.setLen(buf.len * 2) else: var s = posix.getservbyname(name, proto) if s == nil: raiseOSError(osLastError(), "Service not found.") @@ -389,6 +417,15 @@ when not useNimNetLite: ## On posix this will search through the `/etc/services` file. when useWinVersion: var s = winlean.getservbyport(uint16(port).cint, proto) + elif defined(linux) and not defined(android): + var se: posix.Servent + var s: ptr posix.Servent + var buf = newString(smallBufInitSize) + while true: + let ret = getservbyport_r(uint16(port).cint, proto.cstring, addr se, + buf.cstring, csize_t(buf.len), addr s) + if ret != ERANGE: break + buf.setLen(buf.len * 2) else: var s = posix.getservbyport(uint16(port).cint, proto) if s == nil: raiseOSError(osLastError(), "Service not found.") @@ -424,14 +461,29 @@ when not useNimNetLite: var s = winlean.gethostbyaddr(cast[ptr InAddr](myAddr), addrLen.cuint, cint(family)) if s == nil: raiseOSError(osLastError()) - else: - var s = + elif defined(linux): + var he: posix.Hostent + var h_errnop: cint + var s: ptr posix.Hostent + var buf = newString(largeBufInitSize) + while true: when defined(android4): - posix.gethostbyaddr(cast[cstring](myAddr), addrLen.cint, - cint(family)) + let ret = gethostbyaddr_r(cast[cstring](myAddr), addrLen.SockLen, + cint(family), addr he, + buf.cstring, csize_t(buf.len), + addr s, addr h_errnop) else: - posix.gethostbyaddr(myAddr, addrLen.SockLen, - cint(family)) + let ret = gethostbyaddr_r(myAddr, addrLen.SockLen, + cint(family), addr he, + buf.cstring, csize_t(buf.len), + addr s, addr h_errnop) + if ret != ERANGE: break + buf.setLen(buf.len * 2) + if s == nil: + raiseOSError(osLastError(), $hstrerror(h_errnop)) + else: + # macOS: gethostbyaddr is thread-safe via TLS + var s = posix.gethostbyaddr(myAddr, addrLen.SockLen, cint(family)) if s == nil: raiseOSError(osLastError(), $hstrerror(h_errno)) @@ -476,8 +528,23 @@ when not useNimNetLite: ## This function will lookup the IP address of a hostname. when useWinVersion: var s = winlean.gethostbyname(name) + elif defined(linux): + var he: posix.Hostent + var h_errnop: cint + var s: ptr posix.Hostent + var buf = newString(largeBufInitSize) + while true: + let ret = gethostbyname_r(name.cstring, addr he, + buf.cstring, csize_t(buf.len), + addr s, addr h_errnop) + if ret != ERANGE: break + buf.setLen(buf.len * 2) + if s == nil: + raiseOSError(osLastError(), $hstrerror(h_errnop)) else: var s = posix.gethostbyname(name) + if s == nil: + raiseOSError(osLastError(), $hstrerror(h_errno)) if s == nil: raiseOSError(osLastError()) result = Hostent( name: $s.h_name, diff --git a/lib/pure/times.nim b/lib/pure/times.nim index 2951ac6cdb078..cff1d6e289e4b 100644 --- a/lib/pure/times.nim +++ b/lib/pure/times.nim @@ -264,7 +264,8 @@ elif defined(windows): tm_yday*: cint ## Day of year [0,365]. tm_isdst*: cint ## Daylight Savings flag. - proc localtime(a1: var CTime): ptr Tm {.importc, header: "", sideEffect.} + # Windows CRT's localtime_s has reversed args vs C11 + proc localtime_s(a2: ptr Tm, a1: var CTime): cint {.importc, header: "", sideEffect.} type Month* = enum ## Represents a month. Note that the enum starts at `1`, @@ -1322,20 +1323,25 @@ else: when defined(windows): if unix < 0: var a = 0.CTime - let tmPtr = localtime(a) - if not tmPtr.isNil: - let tm = tmPtr[] - return ((0 - tm.toAdjUnix).int, false) - return (0, false) + var tm: Tm + if localtime_s(addr tm, a) != 0: + return (0, false) + return ((0 - tm.toAdjUnix).int, false) # In case of a 32-bit time_t, we fallback to the closest available # timezone information. var a = clamp(unix, low(CTime).int64, high(CTime).int64).CTime - let tmPtr = localtime(a) - if not tmPtr.isNil: - let tm = tmPtr[] - return ((a.int64 - tm.toAdjUnix).int, tm.tm_isdst > 0) - return (0, false) + var tm: Tm + when defined(windows): + if localtime_s(addr tm, a) != 0: + return (0, false) + else: + # localtime_r doesn't call tzset() implicitly unlike localtime(). + # tzset() must be called before localtime_r to pick up TZ changes. + tzset() + if localtime_r(a, tm).isNil: + return (0, false) + return ((a.int64 - tm.toAdjUnix).int, tm.tm_isdst > 0) proc localZonedTimeFromTime(time: Time): ZonedTime {.gcsafe.} = let (offset, dst) = getLocalOffsetAndDst(time.seconds) diff --git a/tests/stdlib/tnativesockets.nim b/tests/stdlib/tnativesockets.nim index 8242beb836c94..6eb9a8dcca77e 100644 --- a/tests/stdlib/tnativesockets.nim +++ b/tests/stdlib/tnativesockets.nim @@ -10,6 +10,10 @@ block: let hostname = getHostname() doAssert hostname.len > 0 +block: + doAssertRaises(OSError): + discard getHostByName("nonexistent.invalid") + when defined(windows): assertAll: toInt(IPPROTO_IP) == 0 diff --git a/tests/stdlib/tnativesockets_reentrant.nim b/tests/stdlib/tnativesockets_reentrant.nim new file mode 100644 index 0000000000000..d2b0e9c5abaa6 --- /dev/null +++ b/tests/stdlib/tnativesockets_reentrant.nim @@ -0,0 +1,35 @@ +discard """ + matrix: "--mm:refc -d:testReentrantBufs; --mm:orc -d:testReentrantBufs" + joinable: false +""" + +## Tests that the _r network calls handle ERANGE buffer resizing correctly. +## Compiled with -d:testReentrantBufs which sets initial buffers to 1 byte, +## forcing the while-loop to grow buffers through several ERANGE iterations. + +import std/nativesockets +import std/assertions + +when defined(linux): + block: # getProtoByName + doAssert getProtoByName("tcp") == 6 + doAssert getProtoByName("udp") == 17 + + block: # getServByName + let s = getServByName("http", "tcp") + doAssert s.name == "http" + doAssert s.proto == "tcp" + + block: # getServByPort + let s = getServByPort(Port(htons(80)), "tcp") + doAssert s.name == "http" + doAssert s.proto == "tcp" + + block: # getHostByName + let he = getHostByName("localhost") + doAssert he.name.len > 0 + doAssert he.addrList.len > 0 + + block: # getHostByAddr + let he = getHostByAddr("127.0.0.1") + doAssert he.name.len > 0