diff --git a/Generator/Sources/Internal/Crawlers/Crawler.swift b/Generator/Sources/Internal/Crawlers/Crawler.swift index c37975d0..5e43bc3c 100644 --- a/Generator/Sources/Internal/Crawlers/Crawler.swift +++ b/Generator/Sources/Internal/Crawlers/Crawler.swift @@ -76,7 +76,8 @@ final class Crawler: SyntaxVisitor { attributes: attributes(from: node.attributes), accessibility: accessibility(from: node.modifiers) ?? (container as? HasAccessibility)?.accessibility ?? .internal, name: node.name.filteredDescription, - genericParameters: (genericParameters(from: node.primaryAssociatedTypeClause?.primaryAssociatedTypes) + associatedTypes(from: node.memberBlock.members)).merged(), + associatedTypes: associatedTypes(from: node.memberBlock.members), + primaryAssociatedTypes: genericParameters(from: node.primaryAssociatedTypeClause?.primaryAssociatedTypes), genericRequirements: genericRequirements(from: node.genericWhereClause?.requirements), inheritedTypes: inheritedTypes, members: [] diff --git a/Generator/Sources/Internal/GeneratorHelper.swift b/Generator/Sources/Internal/GeneratorHelper.swift index ac96eb4a..5d826cdf 100644 --- a/Generator/Sources/Internal/GeneratorHelper.swift +++ b/Generator/Sources/Internal/GeneratorHelper.swift @@ -40,7 +40,7 @@ struct GeneratorHelper { let matchableWhereConstraints = method.signature.parameters.enumerated().map { index, parameter -> String in let type = parameter.type.isOptional ? "OptionalMatchedType" : "MatchedType" - return "M\(index + 1).\(type) == \(genericSafeType(from: parameter.type.withoutAttributes(except: ["@Sendable"]).unoptionaled.description))" + return "M\(index + 1).\(type) == \(genericSafeType(from: parameter.type.withoutAttributes(except: ["@MainActor", "@Sendable"]).unoptionaled.description))" } let methodWhereConstraints = method.signature.whereConstraints return " where \((matchableWhereConstraints + methodWhereConstraints).joined(separator: ", "))" @@ -57,7 +57,7 @@ struct GeneratorHelper { private static func parameterMatchers(for parameters: [MethodParameter]) -> String { guard parameters.isEmpty == false else { return "let matchers: [Cuckoo.ParameterMatcher] = []" } - let tupleType = parameters.map { $0.type.withoutAttributes(except: ["@Sendable"]).description }.joined(separator: ", ") + let tupleType = parameters.map { $0.type.withoutAttributes(except: ["@MainActor", "@Sendable"]).description }.joined(separator: ", ") let matchers = parameters // Enumeration is done after filtering out parameters without usable names. .enumerated() diff --git a/Generator/Sources/Internal/Templates/MockTemplate.swift b/Generator/Sources/Internal/Templates/MockTemplate.swift index 4b057f3b..2ca05475 100644 --- a/Generator/Sources/Internal/Templates/MockTemplate.swift +++ b/Generator/Sources/Internal/Templates/MockTemplate.swift @@ -15,11 +15,17 @@ extension Templates { {% if container.hasParent %} extension {{ container.parentFullyQualifiedName }} { {% endif %} +{% if container.hasPrimaryAssociatedTypes %} +// runtime support for constrained protocols with primary associated types +@available(iOS 16, macOS 13, watchOS 9, tvOS 16, *) +{% endif %} {{ container.accessibility|withSpace }}class {{ container.mockName }}{{ container.genericParameters }}:{% if container.isNSObjectProtocol %} NSObject,{% endif %} {{ container.name }}{% if container.isImplementation %}{{ container.genericArguments }}{% endif %},{% if container.isImplementation %} Cuckoo.ClassMock{% else %} Cuckoo.ProtocolMock{% endif %}, @unchecked Sendable { - {% if container.isGeneric and not container.isImplementation %} + {% if container.isGeneric and not container.isImplementation and not container.hasOnlyPrimaryAssociatedTypes %} {{ container.accessibility|withSpace }}typealias MocksType = \(typeErasureClassName) - {% else %} + {% elif container.isImplementation %} {{ container.accessibility|withSpace }}typealias MocksType = {{ container.name }}{{ container.genericArguments }} + {% else %} + {{ container.accessibility|withSpace }}typealias MocksType = any {{ container.name }}{{ container.genericArguments }} {% endif %} {{ container.accessibility|withSpace }}typealias Stubbing = __StubbingProxy_{{ container.name }} {{ container.accessibility|withSpace }}typealias Verification = __VerificationProxy_{{ container.name }} @@ -31,7 +37,7 @@ extension {{ container.parentFullyQualifiedName }} { {{ container.accessibility|withSpace }}let cuckoo_manager = Cuckoo.MockManager.preconfiguredManager ?? Cuckoo.MockManager(hasParent: {{ container.isImplementation }}) - {% if container.isGeneric and not container.isImplementation %} + {% if container.isGeneric and not container.isImplementation and not container.hasOnlyPrimaryAssociatedTypes %} \(Templates.typeErasure.indented()) private var __defaultImplStub: \(typeErasureClassName)? @@ -43,7 +49,7 @@ extension {{ container.parentFullyQualifiedName }} { } {{ container.accessibility|withSpace }}func enableDefaultImplementation<\(staticGenericParameter): {{ container.name }}>(mutating stub: UnsafeMutablePointer<\(staticGenericParameter)>) where {{ container.genericProtocolIdentity }} { - __defaultImplStub = \(typeErasureClassName)(from: stub, keeping: nil) + __defaultImplStub = \(typeErasureClassName)(from: stub, keeping: stub.pointee) cuckoo_manager.enableDefaultStubImplementation() } {% else %} diff --git a/Generator/Sources/Internal/Templates/TypeErasureTemplate.swift b/Generator/Sources/Internal/Templates/TypeErasureTemplate.swift index 7b05109f..1de3c957 100644 --- a/Generator/Sources/Internal/Templates/TypeErasureTemplate.swift +++ b/Generator/Sources/Internal/Templates/TypeErasureTemplate.swift @@ -3,7 +3,7 @@ import Foundation extension Templates { static let typeErasure = """ {{ container.accessibility|withSpace }}class \(typeErasureClassName): {{ container.name }}, @unchecked Sendable { - private let reference: Any + private let reference: () -> any {{ container.name }}{{ container.genericPrimaryAssociatedTypeArguments }} {% for property in container.properties %} private let _getter_storage$${{ property.name }}: () -> {{ property.type }} @@ -18,9 +18,9 @@ extension Templates { } {% endfor %} - {# For developers: The `keeping reference: Any?` is necessary because when called from the `enableDefaultImplementation(stub:)` method + {# For developers: The `keeping reference: Any` is necessary because when called from the `enableDefaultImplementation(stub:)` method instead of `enableDefaultImplementation(mutating:)`, we need to prevent the struct getting deallocated. #} - init<\(staticGenericParameter): {{ container.name }}>(from defaultImpl: UnsafeMutablePointer<\(staticGenericParameter)>, keeping reference: @escaping @autoclosure () -> Any?) where {{ container.genericProtocolIdentity }} { + init<\(staticGenericParameter): {{ container.name }}>(from defaultImpl: UnsafeMutablePointer<\(staticGenericParameter)>, keeping reference: @escaping @autoclosure () -> \(staticGenericParameter)) where {{ container.genericProtocolIdentity }} { self.reference = reference {% for property in container.properties %} @@ -30,7 +30,9 @@ extension Templates { {% endif %} {% endfor %} {% for method in container.methods %} + {% if not method.hasGenericParams %} _storage${{ forloop.counter }}${{ method.name }} = defaultImpl.pointee.{{ method.name }} + {% endif %} {% endfor %} } {% if container.initializers %} @@ -43,9 +45,20 @@ extension Templates { {% endfor %} {% for method in container.methods +%} + {% if not method.hasGenericParams %} private let _storage${{ forloop.counter }}${{ method.name }}: ({{ method.inputTypes }}) {% if method.isAsync %} async{% endif %} {% if method.isThrowing %} throws{% endif %} -> {{ method.returnType }} + {% endif %} {{ container.accessibility|withSpace }}func {{ method.name|escapeReservedKeywords }}{{ method.signature }} { + {% if method.hasGenericParams and method.hasNonPrimaryAssociatedTypeParams %} + func openExistential<\(staticGenericParameter): {{ container.name }}{{ container.genericPrimaryAssociatedTypeArguments }}>(_ opened: \(staticGenericParameter)) {% if method.isAsync %} async{% endif %} {% if method.isThrowing %} throws{% endif %} -> {{ method.returnType }} { + return {% if method.isThrowing %} try{% endif %} {% if method.isAsync %} await{% endif %} opened.{{ method.name }}{{ method.staticGenericCall }} + } + return {% if method.isThrowing %} try{% endif %} {% if method.isAsync %} await{% endif %} openExistential(reference()) + {% elif method.hasGenericParams %} + return {% if method.isThrowing %} try{% endif %} {% if method.isAsync %} await{% endif %} reference().{{ method.name }}({{ method.call }}) + {% else %} return {% if method.isThrowing %} try{% endif %} {% if method.isAsync %} await{% endif %} _storage${{ forloop.counter }}${{ method.name }}({{ method.parameterNames }}) + {% endif %} } {% endfor %} } diff --git a/Generator/Sources/Internal/Tokens/Capabilities/HasGenerics.swift b/Generator/Sources/Internal/Tokens/Capabilities/HasGenerics.swift index 036242bf..1d385205 100644 --- a/Generator/Sources/Internal/Tokens/Capabilities/HasGenerics.swift +++ b/Generator/Sources/Internal/Tokens/Capabilities/HasGenerics.swift @@ -19,15 +19,39 @@ extension HasGenerics { var isGeneric: Bool { !genericParameters.isEmpty } + + var hasPrimaryAssociatedTypes: Bool { + guard let protocolDeclaration = asProtocol else { return false } + return !protocolDeclaration.primaryAssociatedTypes.isEmpty + } + + var hasOnlyPrimaryAssociatedTypes: Bool { + guard let protocolDeclaration = asProtocol else { return false } + return protocolDeclaration.primaryAssociatedTypes.count == protocolDeclaration.genericParameters.count + } func genericsSerialize() -> GeneratorContext { - let genericProtocolIdentity = isProtocol ? genericParameters.map { "\(Templates.staticGenericParameter).\($0.name) == \($0.name)" }.joined(separator: ", ") : nil - + let genericProtocolIdentity = isProtocol + ? genericParameters + .map { "\(Templates.staticGenericParameter).\($0.name) == \($0.name)" } + .joined(separator: ", ") + : nil + let genericPrimaryAssociatedTypeArguments: String? + if let protocolDeclaration = asProtocol, hasPrimaryAssociatedTypes { + let arguments = protocolDeclaration.primaryAssociatedTypes.map { $0.name }.joined(separator: ", ") + genericPrimaryAssociatedTypeArguments = "<\(arguments)>" + } else { + genericPrimaryAssociatedTypeArguments = nil + } + return [ "isGeneric": isGeneric, "genericParameters": genericParametersString, "genericArguments": genericArgumentsString, + "hasPrimaryAssociatedTypes": hasPrimaryAssociatedTypes, + "hasOnlyPrimaryAssociatedTypes": hasOnlyPrimaryAssociatedTypes, "genericProtocolIdentity": genericProtocolIdentity, + "genericPrimaryAssociatedTypeArguments": genericPrimaryAssociatedTypeArguments, ] .compactMapValues { $0 } } diff --git a/Generator/Sources/Internal/Tokens/ComplexType.swift b/Generator/Sources/Internal/Tokens/ComplexType.swift index b8c7b911..67fbe976 100644 --- a/Generator/Sources/Internal/Tokens/ComplexType.swift +++ b/Generator/Sources/Internal/Tokens/ComplexType.swift @@ -241,6 +241,83 @@ extension ComplexType { nil } } + + func replaceType(named typeName: String, with replacement: String) -> ComplexType? { + switch self { + case .attributed(let attributes, let baseType): + return baseType.replaceType(named: typeName, with: replacement) + .map { ComplexType.attributed(attributes: attributes, baseType: $0) } + case .optional(let wrappedType, let isImplicit): + return wrappedType.replaceType(named: typeName, with: replacement) + .map { ComplexType.optional(wrappedType: $0, isImplicit: isImplicit) } + case .array(let elementType): + return elementType.replaceType(named: typeName, with: replacement) + .map { ComplexType.array(elementType: $0) } + case .dictionary(let keyType, let valueType): + let newKey = keyType.replaceType(named: typeName, with: replacement) + let newValue = valueType.replaceType(named: typeName, with: replacement) + if newKey == nil && newValue == nil { return nil } + return .dictionary( + keyType: newKey ?? keyType, + valueType: newValue ?? valueType + ) + case .closure(let closure): + var changed = false + let newParams: [ComplexType.Closure.Parameter] = closure.parameters.map { param in + if let newType = param.type.replaceType(named: typeName, with: replacement) { + changed = true + return ComplexType.Closure.Parameter(label: param.label, type: newType) + } else { + return param + } + } + let newReturn = closure.returnType.replaceType(named: typeName, with: replacement) + if !changed && newReturn == nil { return nil } + return .closure(.init(parameters: newParams, effects: closure.effects, returnType: newReturn ?? closure.returnType)) + case .type(let name): + return name == typeName ? ComplexType.type(replacement) : nil + } + } + + func replaceTypes(named typeNames: [String], with replacement: (String) -> String) -> ComplexType? { + var changed = false + var type = self + for typeName in typeNames { + if let replaced = type.replaceType(named: typeName, with: replacement(typeName)) { + changed = true + type = replaced + } + } + + return changed ? type : nil + } + + func containsType(named typeName: String) -> Bool { + switch self { + case .attributed(_, let baseType): + baseType.containsType(named: typeName) + case .optional(let wrappedType, _): + wrappedType.containsType(named: typeName) + case .array(let elementType): + elementType.containsType(named: typeName) + case .dictionary(let keyType, let valueType): + keyType.containsType(named: typeName) || valueType.containsType(named: typeName) + case .closure(let closure): + (closure.parameters.map(\.type) + [closure.returnType]).contains(where: { $0.containsType(named: typeName)}) + case .type(let name): + name == typeName + } + } + + func containsTypes(named typeNames: [String]) -> Bool { + typeNames.contains(where: { containsType(named: $0) }) + } +} + +extension String { + func forceCast(as type: ComplexType) -> String { + "\(self) as! \(type.withoutAttributes(except: ["@MainActor", "@Sendable"]).description)" + } } extension ComplexType.Closure.Effects { diff --git a/Generator/Sources/Internal/Tokens/Method.swift b/Generator/Sources/Internal/Tokens/Method.swift index f20a4be2..b817bed4 100644 --- a/Generator/Sources/Internal/Tokens/Method.swift +++ b/Generator/Sources/Internal/Tokens/Method.swift @@ -55,24 +55,42 @@ extension Method { var hasOptionalParams: Bool { signature.parameters.contains { $0.type.isOptional } } + + var hasGenericParams: Bool { + !signature.genericParameters.isEmpty + } + + var hasNonPrimaryAssociatedTypeParams: Bool { + guard let parent = parent?.asProtocol else { return false } + return signature.containsTypes(named: parent.nonPrimaryAssociatedTypes.map(\.name)) + } func serialize() -> [String : Any] { - let call = signature.parameters - .map { parameter in - let name = escapeReservedKeywords(for: parameter.usableName) - let value = "\(parameter.isInout ? "&" : "")\(name)\(parameter.type.containsAttribute(named: "@autoclosure") ? "()" : "")" - if parameter.name == "_" { - return value - } else { - return "\(parameter.name): \(value)" - } - } - .joined(separator: ", ") - guard let parent else { fatalError("Failed to find parent of method \(fullSignature). Please file a bug.") } - + + let call = signature.parameters + .map(\.call) + .joined(separator: ", ") + + let staticGenericCall: String + if let parent = parent.asProtocol, !parent.nonPrimaryAssociatedTypes.isEmpty { + let nonPrimary = parent.nonPrimaryAssociatedTypes.map(\.name) + + let staticGenericCallableParameters = signature.parameters + .map { $0.callAndCastTypes(named: nonPrimary, as: { Templates.staticGenericParameter + ".\($0)" }) } + .joined(separator: ", ") + + staticGenericCall = if let returnType, returnType.containsTypes(named: nonPrimary) { + "(\(staticGenericCallableParameters))".forceCast(as: returnType) + } else { + "(\(staticGenericCallableParameters))" + } + } else { + staticGenericCall = "(\(call))" + } + let stubFunctionPrefix = parent.isClass ? "Class" : "Protocol" let returnString = returnType?.isVoid == false ? "" : "NoReturn" let throwingString = isThrowing ? "Throwing" : "" @@ -108,16 +126,19 @@ extension Method { "throwTypeError": signature.throwType?.type ?? "", "fullyQualifiedName": fullyQualifiedName, "call": call, + "staticGenericCall": staticGenericCall, "parameterSignature": signature.parameters.map { $0.description }.joined(separator: ", "), "parameterSignatureWithoutNames": signature.parameters.map { "\($0.name): \($0.type)" }.joined(separator: ", "), "argumentSignature": signature.parameters.map { $0.type.description }.joined(separator: ", "), "stubFunction": stubFunction, - "inputTypes": signature.parameters.map { $0.type.withoutAttributes(except: ["@escaping", "@Sendable"]).description }.joined(separator: ", "), - "genericInputTypes": signature.parameters.map { $0.type.withoutAttributes(except: ["@Sendable"]).description }.joined(separator: ", "), + "inputTypes": signature.parameters.map { $0.type.withoutAttributes(except: ["@escaping", "@MainActor", "@Sendable"]).description }.joined(separator: ", "), + "genericInputTypes": signature.parameters.map { $0.type.withoutAttributes(except: ["@MainActor", "@Sendable"]).description }.joined(separator: ", "), "isOptional": isOptional, "hasClosureParams": hasClosureParams, "hasOptionalParams": hasOptionalParams, + "hasGenericParams": hasGenericParams, "genericParameters": signature.genericParameters.sourceDescription, + "hasNonPrimaryAssociatedTypeParams": hasNonPrimaryAssociatedTypeParams, "hasUnavailablePlatforms": hasUnavailablePlatforms, "unavailablePlatformsCheck": unavailablePlatformsCheck, ] diff --git a/Generator/Sources/Internal/Tokens/MethodParameter.swift b/Generator/Sources/Internal/Tokens/MethodParameter.swift index 8f921a5a..35b323dc 100644 --- a/Generator/Sources/Internal/Tokens/MethodParameter.swift +++ b/Generator/Sources/Internal/Tokens/MethodParameter.swift @@ -25,6 +25,27 @@ struct MethodParameter: Token { var isEscaping: Bool { type.isClosure && (type.containsAttribute(named: "@escaping") || type.isOptional) } + + var call: String { + let escapedName = escapeReservedKeywords(for: usableName) + let value = "\(isInout ? "&" : "")\(escapedName)\(type.containsAttribute(named: "@autoclosure") ? "()" : "")" + if name == "_" { + return value + } else { + return "\(name): \(value)" + } + } + + func callAndCastTypes(named typeNames: [String], as replacement: (String) -> String) -> String { + let replaced = type.replaceTypes(named: typeNames, with: replacement) + + let callToCast = call + if let replaced { + return callToCast.forceCast(as: replaced) + } else { + return callToCast + } + } func serialize() -> [String: Any] { return [ diff --git a/Generator/Sources/Internal/Tokens/MethodSignature.swift b/Generator/Sources/Internal/Tokens/MethodSignature.swift index 68c94958..d33fefc7 100644 --- a/Generator/Sources/Internal/Tokens/MethodSignature.swift +++ b/Generator/Sources/Internal/Tokens/MethodSignature.swift @@ -64,3 +64,14 @@ extension Method.Signature { && whereConstraints == other.whereConstraints } } + +extension Method.Signature { + func containsType(named typeName: String) -> Bool { + parameters.map(\.type) + .contains(where: { $0.containsType(named: typeName) }) + } + + func containsTypes(named typeNames: [String]) -> Bool { + typeNames.contains(where: { containsType(named: $0) }) + } +} diff --git a/Generator/Sources/Internal/Tokens/ProtocolDeclaration.swift b/Generator/Sources/Internal/Tokens/ProtocolDeclaration.swift index e09f6a96..ab12e6a3 100644 --- a/Generator/Sources/Internal/Tokens/ProtocolDeclaration.swift +++ b/Generator/Sources/Internal/Tokens/ProtocolDeclaration.swift @@ -4,7 +4,8 @@ struct ProtocolDeclaration: ContainerToken { var attributes: [Attribute] var accessibility: Accessibility var name: String - var genericParameters: [GenericParameter] + var associatedTypes: [GenericParameter] + var primaryAssociatedTypes: [GenericParameter] var genericRequirements: [String] var inheritedTypes: [String] var members: [Token] @@ -16,7 +17,8 @@ struct ProtocolDeclaration: ContainerToken { attributes: attributes, accessibility: accessibility, name: name, - genericParameters: genericParameters, + associatedTypes: associatedTypes, + primaryAssociatedTypes: primaryAssociatedTypes, genericRequirements: genericRequirements, inheritedTypes: inheritedTypes, members: members, @@ -30,11 +32,20 @@ struct ProtocolDeclaration: ContainerToken { attributes: attributes, accessibility: accessibility, name: name, - genericParameters: genericParameters, + associatedTypes: associatedTypes, + primaryAssociatedTypes: primaryAssociatedTypes, genericRequirements: genericRequirements, inheritedTypes: inheritedTypes, members: members, isNSObjectProtocol: isNSObjectProtocol ) } + + var genericParameters: [GenericParameter] { + (associatedTypes + primaryAssociatedTypes).merged() + } + + var nonPrimaryAssociatedTypes: [GenericParameter] { + associatedTypes.filter { !primaryAssociatedTypes.map(\.name).contains($0.name) } + } } diff --git a/Tests/Swift/GenericProtocolTest.swift b/Tests/Swift/GenericProtocolTest.swift index 5124826a..33a090ff 100644 --- a/Tests/Swift/GenericProtocolTest.swift +++ b/Tests/Swift/GenericProtocolTest.swift @@ -59,6 +59,12 @@ private class GenericProtocolConformerClass: GenericProtocol { } func closureParameter(closure: @escaping () -> Void) {} + + func genericParameter(value: T) {} + + func genericAndAassociatedTypeParameters(value: T, theC: C, theV: V) -> (C?) -> V { + { _ in theV } + } } private struct GenericProtocolConformerStruct: GenericProtocol { @@ -94,8 +100,50 @@ private struct GenericProtocolConformerStruct: GenericProtocol func noReturnAsync() async {} func closureParameter(closure: @escaping () -> Void) {} + + func genericParameter(value: T) {} + + func genericAndAassociatedTypeParameters(value: T, theC: C, theV: V) -> (C?) -> V { + { _ in theV } + } +} + +private final class MixedPrimaryAssociatedTypeProtocolConformerClass: MixedPrimaryAssociatedTypeProtocol { + + let defaultOutput: Output + + init(defaultOutput: Output) { + self.defaultOutput = defaultOutput + } + + func convertToOutput( + value: sending T, + input: sending Input, + convert: @escaping @MainActor @Sendable (T, Input) async throws -> Output? + ) async rethrows -> Output { + if let result = try await convert(value, input) { + return result + } + return defaultOutput + } } +private struct MixedPrimaryAssociatedTypeProtocolConformerStruct: MixedPrimaryAssociatedTypeProtocol { + var defaultOutput: Output + + func convertToOutput( + value: T, + input: Input, + convert: @escaping @MainActor @Sendable (T, Input) async throws -> Output? + ) async rethrows -> Output { + if let result = try await convert(value, input) { + return result + } + return defaultOutput + } +} + +@available(iOS 16.0.0, macOS 13.0.0, watchOS 9.0, tvOS 16, *) final class GenericProtocolTest: XCTestCase { private func createMock(value: V, classy: MockTestedClass = MockTestedClass()) -> MockGenericProtocol { return MockGenericProtocol(theC: classy, theV: value) @@ -145,6 +193,96 @@ final class GenericProtocolTest: XCTestCase { verify(mock).optionalProperty.set(isNil()) } + func testCallSomeC() { + let mock = createMock(value: 1337) + let classy = MockTestedClass() + let expected = 42 + stub(mock) { mock in + when(mock.callSomeC(theC: classy)).thenReturn(expected) + } + + let actual = mock.callSomeC(theC: classy) + + XCTAssertEqual(actual, expected) + verify(mock).callSomeC(theC: classy) + } + + func testCallSomeV() { + let mock = createMock(value: 1) + let expected = 99 + stub(mock) { mock in + when(mock.callSomeV(theV: equal(to: 1))).thenReturn(expected) + } + + let actual = mock.callSomeV(theV: 1) + + XCTAssertEqual(actual, expected) + verify(mock).callSomeV(theV: equal(to: 1)) + } + + func testCompute() { + let mock = createMock(value: 5) + let classy = MockTestedClass() + let expected = MockTestedClass() + stub(mock) { mock in + when(mock.compute(classy: classy, value: equal(to: 5))).thenReturn(expected) + } + + let actual = mock.compute(classy: classy, value: 5) + + XCTAssertTrue(actual === expected) + verify(mock).compute(classy: classy, value: equal(to: 5)) + } + + func testClosureParameter() { + let mock = createMock(value: 10) + var closureInvoked = false + stub(mock) { mock in + when(mock.closureParameter(closure: any())).then { closure in + closure() + } + } + + mock.closureParameter { + closureInvoked = true + } + + XCTAssertTrue(closureInvoked) + verify(mock).closureParameter(closure: any()) + } + + func testGenericParameter() { + let mock = createMock(value: 10) + var captured: Any? + let expected = "wow" + stub(mock) { mock in + when(mock.genericParameter(value: equal(to: expected))).then { (value: String) in + captured = value + } + } + + mock.genericParameter(value: expected) + + XCTAssertEqual(captured as? String, expected) + verify(mock).genericParameter(value: equal(to: expected)) + } + + func testGenericAndAssociatedTypeParameters() { + let expected = 2112 + let mock = createMock(value: expected) + let classy = MockTestedClass() + stub(mock) { mock in + when(mock.genericAndAassociatedTypeParameters(value: equal(to: "value"), theC: classy, theV: equal(to: expected))).then { (_: String, _: MockTestedClass, theV: Int) -> (MockTestedClass?) -> Int in + { _ in theV } + } + } + + let closure = mock.genericAndAassociatedTypeParameters(value: "value", theC: classy, theV: expected) + + XCTAssertEqual(closure(classy), expected) + verify(mock).genericAndAassociatedTypeParameters(value: equal(to: "value"), theC: classy, theV: equal(to: expected)) + } + func testNoReturn() { let mock = createMock(value: "Hello. Sniffing through tests? If you're having trouble with Cuckoo, shoot us a message!") var called = false @@ -236,6 +374,85 @@ final class GenericProtocolTest: XCTestCase { verify(mock).noReturnAsync() } + + // MARK: - MockMixedPrimaryAssociatedTypeProtocol + + private typealias ConvertClosure = @MainActor @Sendable (Int, String) async throws -> Int? + + private func createMock() -> MockMixedPrimaryAssociatedTypeProtocol { + MockMixedPrimaryAssociatedTypeProtocol() + } + + func testDefaultOutputProperty() { + let mock = createMock() + let expected = 77 + stub(mock) { mock in + when(mock.defaultOutput.get).thenReturn(expected) + } + + XCTAssertEqual(mock.defaultOutput, expected) + verify(mock).defaultOutput.get() + } + + func testConvertToOutputStubbing() async throws { + let mock = createMock() + let value = 5 + let input = "input" + let expected = 123 + + stub(mock) { mock in + when(mock.convertToOutput(value: anyInt(), input: anyString(), convert: any(ConvertClosure.self))) + .thenReturn(expected) + } + let closure: @MainActor @Sendable (Int, String) async throws -> Int? = { _, _ in nil } + + let actual = try await mock.convertToOutput(value: value, input: input, convert: closure) + + XCTAssertEqual(actual, expected) + verify(mock).convertToOutput(value: equal(to: value), input: equal(to: input), convert: any(ConvertClosure.self)) + } + + func testConvertToOutputDefaultImplementation() async throws { + let defaultImpl = MixedPrimaryAssociatedTypeProtocolConformerClass(defaultOutput: 7) + let mock = createMock() + mock.enableDefaultImplementation(defaultImpl) + + let value = 3 + let input = "abc" + let closure: @MainActor @Sendable (Int, String) async throws -> Int? = { value, input in + value + input.count + } + let expected = value + input.count + + XCTAssertEqual(mock.defaultOutput, defaultImpl.defaultOutput) + let result = try await mock.convertToOutput(value: value, input: input, convert: closure) + + XCTAssertEqual(result, expected) + verify(mock).defaultOutput.get() + verify(mock).convertToOutput(value: equal(to: value), input: equal(to: input), convert: any(ConvertClosure.self)) + } + + func testStructDefaultImplementationReflectsMutations() async throws { + var defaultImpl = MixedPrimaryAssociatedTypeProtocolConformerStruct(defaultOutput: 5) + let mock = createMock() + mock.enableDefaultImplementation(mutating: &defaultImpl) + + XCTAssertEqual(mock.defaultOutput, 5) + + defaultImpl.defaultOutput = 9 + + XCTAssertEqual(mock.defaultOutput, 9) + + let value = 3 + let input = "abc" + + let fallback: @MainActor @Sendable (Int, String) async throws -> Int? = { _, _ in nil } + let result = try await mock.convertToOutput(value: value, input: input, convert: fallback) + + XCTAssertEqual(result, 9) + verify(mock, times(2)).defaultOutput.get() + verify(mock).convertToOutput(value: equal(to: value), input: equal(to: input), convert: any(ConvertClosure.self)) + } } extension MockTestedClass: Matchable{ diff --git a/Tests/Swift/Source/GenericProtocol.swift b/Tests/Swift/Source/GenericProtocol.swift index 2f5a78ec..aff35920 100644 --- a/Tests/Swift/Source/GenericProtocol.swift +++ b/Tests/Swift/Source/GenericProtocol.swift @@ -1,6 +1,6 @@ import Foundation -protocol GenericProtocol { +protocol GenericProtocol { associatedtype C: AnyObject associatedtype V @@ -27,6 +27,12 @@ protocol GenericProtocol { /// Test for a bug that produces uncompilable code when associated types are used along with closures. /// Requires change from WrappableType to ComplexType. func closureParameter(closure: @escaping () -> Void) + + /// Test for a bug that produces uncompilable code when generic functions are used in protocols with associated types. + func genericParameter(value: T) + + /// Test for a bug that produces uncompilable code when generic functions are used in protocols with associated types and function has associated type parameters. + func genericAndAassociatedTypeParameters(value: T, theC: C, theV: V) -> (C?) -> V } protocol PrimaryAssociatedTypeProtocol { @@ -34,3 +40,16 @@ protocol PrimaryAssociatedTypeProtocol { func connect() -> Output } + +protocol MixedPrimaryAssociatedTypeProtocol { + associatedtype Input: Equatable + associatedtype Output: Equatable + + var defaultOutput: Output { get } + + func convertToOutput( + value: sending T, + input: sending Input, + convert: @escaping @MainActor @Sendable (T, Input) async throws -> Output? + ) async rethrows -> Output +}