Skip to content

Commit b635546

Browse files
committed
[gardening] Privatization of Mocking
1 parent 7b33caf commit b635546

File tree

4 files changed

+131
-159
lines changed

4 files changed

+131
-159
lines changed

Sources/SystemInternals/Exports.swift

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,45 @@ extension String {
141141
self.init(cString: platformString)
142142
#endif
143143
}
144+
}
144145

146+
// TLS
147+
#if os(Windows)
148+
internal typealias _PlatformTLSKey = DWORD
149+
#else
150+
internal typealias _PlatformTLSKey = pthread_key_t
151+
#endif
145152

153+
internal func makeTLSKey() -> _PlatformTLSKey {
154+
#if os(Windows)
155+
let raw: DWORD = FlsAlloc(nil)
156+
if raw == FLS_OUT_OF_INDEXES {
157+
fatalError("Unable to create key")
158+
}
159+
return raw
160+
#else
161+
var raw = pthread_key_t()
162+
guard 0 == pthread_key_create(&raw, nil) else {
163+
fatalError("Unable to create key")
164+
}
165+
return raw
166+
#endif
167+
}
168+
internal func setTLS(_ key: _PlatformTLSKey, _ p: UnsafeMutableRawPointer?) {
169+
#if os(Windows)
170+
guard FlsSetValue(key, p) else {
171+
fatalError("Unable to set TLS")
172+
}
173+
#else
174+
guard 0 == pthread_setspecific(key, p) else {
175+
fatalError("Unable to set TLS")
176+
}
177+
#endif
178+
}
179+
internal func getTLS(_ key: _PlatformTLSKey) -> UnsafeMutableRawPointer? {
180+
#if os(Windows)
181+
FlsGetValue(key)
182+
#else
183+
pthread_getspecific(key)
184+
#endif
146185
}

Sources/SystemInternals/Mocking.swift

Lines changed: 79 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
//
1919

2020
#if ENABLE_MOCKING
21-
public struct Trace {
22-
public struct Entry: Hashable {
23-
var name: String
24-
var arguments: [AnyHashable]
21+
internal struct Trace {
22+
internal struct Entry: Hashable {
23+
private var name: String
24+
private var arguments: [AnyHashable]
2525

26-
public init(name: String, _ arguments: [AnyHashable]) {
26+
internal init(name: String, _ arguments: [AnyHashable]) {
2727
self.name = name
2828
self.arguments = arguments
2929
}
@@ -32,39 +32,20 @@ public struct Trace {
3232
private var entries: [Entry] = []
3333
private var firstEntry: Int = 0
3434

35-
public var isEmpty: Bool { firstEntry >= entries.count }
35+
internal var isEmpty: Bool { firstEntry >= entries.count }
3636

37-
public mutating func dequeue() -> Entry? {
37+
internal mutating func dequeue() -> Entry? {
3838
guard !self.isEmpty else { return nil }
3939
defer { firstEntry += 1 }
4040
return entries[firstEntry]
4141
}
4242

43-
internal mutating func add(_ e: Entry) {
43+
fileprivate mutating func add(_ e: Entry) {
4444
entries.append(e)
4545
}
46-
47-
public mutating func clear() { entries.removeAll() }
48-
}
49-
50-
// TODO: Track
51-
public struct WriteBuffer {
52-
public var enabled: Bool = false
53-
54-
private var buffer: [UInt8] = []
55-
private var chunkSize: Int? = nil
56-
57-
internal mutating func write(_ buf: UnsafeRawBufferPointer) -> Int {
58-
guard enabled else { return 0 }
59-
let chunk = chunkSize ?? buf.count
60-
buffer.append(contentsOf: buf.prefix(chunk))
61-
return chunk
62-
}
63-
64-
public var contents: [UInt8] { buffer }
6546
}
6647

67-
public enum ForceErrno: Equatable {
48+
internal enum ForceErrno: Equatable {
6849
case none
6950
case always(errno: CInt)
7051

@@ -74,70 +55,17 @@ public enum ForceErrno: Equatable {
7455
// Provide access to the driver, context, and trace stack of mocking
7556
public class MockingDriver {
7657
// Record syscalls and their arguments
77-
public var trace = Trace()
58+
internal var trace = Trace()
7859

7960
// Mock errors inside syscalls
80-
public var forceErrno = ForceErrno.none
81-
82-
// A buffer to put `write` bytes into
83-
public var writeBuffer = WriteBuffer()
61+
internal var forceErrno = ForceErrno.none
8462

8563
// Whether we should pretend to be Windows for syntactic operations
8664
// inside FilePath
87-
public var forceWindowsSyntaxForPaths = false
65+
fileprivate var forceWindowsSyntaxForPaths = false
8866
}
8967

90-
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
91-
import Darwin
92-
#elseif os(Linux) || os(FreeBSD) || os(Android)
93-
import Glibc
94-
#elseif os(Windows)
95-
import ucrt
96-
import WinSDK
97-
#else
98-
#error("Unsupported Platform")
99-
#endif
100-
101-
// TLS helper functions
102-
#if os(Windows)
103-
internal typealias TLSKey = DWORD
104-
internal func makeTLSKey() -> TLSKey {
105-
let raw: DWORD = FlsAlloc(nil)
106-
if raw == FLS_OUT_OF_INDEXES {
107-
fatalError("Unable to create key")
108-
}
109-
return raw
110-
}
111-
internal func setTLS(_ key: TLSKey, _ p: UnsafeMutableRawPointer?) {
112-
guard FlsSetValue(key, p) else {
113-
fatalError("Unable to set TLS")
114-
}
115-
}
116-
internal func getTLS(_ key: TLSKey) -> UnsafeMutableRawPointer? {
117-
FlsGetValue(key)
118-
}
119-
120-
#else
121-
122-
internal typealias TLSKey = pthread_key_t
123-
internal func makeTLSKey() -> TLSKey {
124-
var raw = pthread_key_t()
125-
guard 0 == pthread_key_create(&raw, nil) else {
126-
fatalError("Unable to create key")
127-
}
128-
return raw
129-
}
130-
internal func setTLS(_ key: TLSKey, _ p: UnsafeMutableRawPointer?) {
131-
guard 0 == pthread_setspecific(key, p) else {
132-
fatalError("Unable to set TLS")
133-
}
134-
}
135-
internal func getTLS(_ key: TLSKey) -> UnsafeMutableRawPointer? {
136-
pthread_getspecific(key)
137-
}
138-
#endif
139-
140-
private let driverKey: TLSKey = { makeTLSKey() }()
68+
private let driverKey: _PlatformTLSKey = { makeTLSKey() }()
14169

14270
internal var currentMockingDriver: MockingDriver? {
14371
#if !ENABLE_MOCKING
@@ -152,7 +80,7 @@ internal var currentMockingDriver: MockingDriver? {
15280
extension MockingDriver {
15381
/// Enables mocking for the duration of `f` with a clean trace queue
15482
/// Restores prior mocking status and trace queue after execution
155-
public static func withMockingEnabled(
83+
internal static func withMockingEnabled(
15684
_ f: (MockingDriver) throws -> ()
15785
) rethrows {
15886
let priorMocking = currentMockingDriver
@@ -179,7 +107,7 @@ private var contextualMockingEnabled: Bool {
179107
}
180108

181109
extension MockingDriver {
182-
public static var enabled: Bool { mockingEnabled }
110+
internal static var enabled: Bool { mockingEnabled }
183111

184112
public static var forceWindowsPaths: Bool {
185113
currentMockingDriver?.forceWindowsSyntaxForPaths ?? false
@@ -198,11 +126,74 @@ internal var mockingEnabled: Bool {
198126
#endif
199127
}
200128

201-
@inlinable @inline(__always)
129+
@inline(__always) @inlinable
202130
public var forceWindowsPaths: Bool {
203131
#if !ENABLE_MOCKING
204132
return false
205133
#else
206134
return MockingDriver.forceWindowsPaths
207135
#endif
208136
}
137+
138+
139+
#if ENABLE_MOCKING
140+
// Strip the mock_system prefix and the arg list suffix
141+
private func originalSyscallName(_ s: String) -> String {
142+
// `function` must be of format `system_<name>(<parameters>)`
143+
precondition(s.starts(with: "system_"))
144+
return String(s.dropFirst("system_".count).prefix { $0.isLetter })
145+
}
146+
147+
private func mockImpl(
148+
name: String,
149+
_ args: [AnyHashable]
150+
) -> CInt {
151+
let origName = originalSyscallName(name)
152+
guard let driver = currentMockingDriver else {
153+
fatalError("Mocking requested from non-mocking context")
154+
}
155+
driver.trace.add(Trace.Entry(name: origName, args))
156+
157+
switch driver.forceErrno {
158+
case .none: break
159+
case .always(let e):
160+
system_errno = e
161+
return -1
162+
case .counted(let e, let count):
163+
assert(count >= 1)
164+
system_errno = e
165+
driver.forceErrno = count > 1 ? .counted(errno: e, count: count-1) : .none
166+
return -1
167+
}
168+
169+
return 0
170+
}
171+
172+
internal func _mock(
173+
name: String = #function, _ args: AnyHashable...
174+
) -> CInt {
175+
precondition(mockingEnabled)
176+
return mockImpl(name: name, args)
177+
}
178+
internal func _mockInt(
179+
name: String = #function, _ args: AnyHashable...
180+
) -> Int {
181+
Int(mockImpl(name: name, args))
182+
}
183+
184+
internal func _mockOffT(
185+
name: String = #function, _ args: AnyHashable...
186+
) -> COffT {
187+
COffT(mockImpl(name: name, args))
188+
}
189+
#endif // ENABLE_MOCKING
190+
191+
// Force paths to be treated as Windows syntactically if `enabled` is
192+
// true.
193+
internal func _withWindowsPaths(enabled: Bool, _ body: () -> ()) {
194+
guard enabled else { return body() }
195+
MockingDriver.withMockingEnabled { driver in
196+
driver.forceWindowsSyntaxForPaths = true
197+
body()
198+
}
199+
}

0 commit comments

Comments
 (0)