From 82f69d740753cb22cfcead4e63f0721bfed93cb9 Mon Sep 17 00:00:00 2001 From: Paul Toffoloni Date: Tue, 19 Aug 2025 17:12:06 +0200 Subject: [PATCH 1/3] Remove `dereferenced(in:)` from `MultipartContentInspector` --- .../Multipart/MultipartContentInspector.swift | 27 ++- .../SnippetBasedReferenceTests.swift | 203 ++++++++++++++++++ 2 files changed, 221 insertions(+), 9 deletions(-) diff --git a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift index b974190b..54da84b2 100644 --- a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift +++ b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift @@ -257,7 +257,8 @@ extension FileTranslator { default: return .infer(.primitive) } } - func inferAllOfAnyOfOneOf(_ schemas: [DereferencedJSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? { + + func inferAllOfAnyOfOneOf(_ schemas: [JSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? { // If all schemas are primitive, the allOf/anyOf/oneOf is also primitive. // These cannot be binary, so only primitive vs complex. for schema in schemas { @@ -266,12 +267,13 @@ extension FileTranslator { } return .infer(.primitive) } - func inferSchema(_ schema: DereferencedJSONSchema) throws -> ( + + func inferSchema(_ schema: JSONSchema) throws -> ( MultipartPartInfo.RepetitionKind, MultipartPartInfo.ContentTypeSource )? { let repetitionKind: MultipartPartInfo.RepetitionKind let candidateSource: MultipartPartInfo.ContentTypeSource - switch schema { + switch schema.value { case .null, .not: return nil case .boolean, .number, .integer: repetitionKind = .single @@ -279,7 +281,7 @@ extension FileTranslator { case .string(_, let context): repetitionKind = .single candidateSource = try inferStringContent(context) - case .object, .fragment: + case .object(_, _), .fragment: repetitionKind = .single candidateSource = .infer(.complex) case .all(of: let schemas, _), .one(of: let schemas, _), .any(of: let schemas, _): @@ -289,21 +291,28 @@ extension FileTranslator { case .array(_, let context): repetitionKind = .array if let items = context.items { - switch items { + switch items.value { case .null, .not: return nil case .boolean, .number, .integer: candidateSource = .infer(.primitive) case .string(_, let context): candidateSource = try inferStringContent(context) - case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex) + case .object(_, _), .all, .one, .any, .fragment, .array(_, _): candidateSource = .infer(.complex) + default: candidateSource = .infer(.complex) } } else { candidateSource = .infer(.complex) } + case .reference(let ref, _): + guard let (refRepetitionKind, refCandidateSource) = try inferSchema(components.lookup(ref)) else { + return nil + } + repetitionKind = refRepetitionKind + candidateSource = refCandidateSource } + return (repetitionKind, candidateSource) } - guard let (repetitionKind, candidateSource) = try inferSchema(schema.dereferenced(in: components)) else { - return nil - } + guard let (repetitionKind, candidateSource) = try inferSchema(schema) else { return nil } + let finalContentTypeSource: MultipartPartInfo.ContentTypeSource if let encoding, let contentType = encoding.contentTypes.first, encoding.contentTypes.count == 1 { finalContentTypeSource = try .explicit(contentType.asGeneratorContentType) diff --git a/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift b/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift index f3a59d0d..8aee47a5 100644 --- a/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift +++ b/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift @@ -4230,6 +4230,209 @@ final class SnippetBasedReferenceTests: XCTestCase { ) } + func testRequestMultipartBodyReferencedSchemaRecursive() throws { + try self.assertRequestInTypesClientServerTranslation( + """ + /foo: + post: + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Node' + responses: + default: + description: Response + """, + """ + schemas: + Node: + type: object + properties: + log: + type: string + parent: + $ref: '#/components/schemas/Node' + required: + - log + """, + input: """ + public struct Input: Sendable, Hashable { + @frozen public enum Body: Sendable, Hashable { + case multipartForm(OpenAPIRuntime.MultipartBody) + } + public var body: Operations.post_sol_foo.Input.Body + public init(body: Operations.post_sol_foo.Input.Body) { + self.body = body + } + } + """, + schemas: """ + public enum Schemas { + @frozen public enum Node: Sendable, Hashable { + public struct logPayload: Sendable, Hashable { + public var body: OpenAPIRuntime.HTTPBody + public init(body: OpenAPIRuntime.HTTPBody) { + self.body = body + } + } + case log(OpenAPIRuntime.MultipartPart) + public struct parentPayload: Sendable, Hashable { + public var body: Components.Schemas.Node + public init(body: Components.Schemas.Node) { + self.body = body + } + } + case parent(OpenAPIRuntime.MultipartPart) + case undocumented(OpenAPIRuntime.MultipartRawPart) + } + } + """, + client: """ + { input in + let path = try converter.renderedPath( + template: "/foo", + parameters: [] + ) + var request: HTTPTypes.HTTPRequest = .init( + soar_path: path, + method: .post + ) + suppressMutabilityWarning(&request) + let body: OpenAPIRuntime.HTTPBody? + switch input.body { + case let .multipartForm(value): + body = try converter.setRequiredRequestBodyAsMultipart( + value, + headerFields: &request.headerFields, + contentType: "multipart/form-data", + allowsUnknownParts: true, + requiredExactlyOncePartNames: [ + "log" + ], + requiredAtLeastOncePartNames: [], + atMostOncePartNames: [ + "parent" + ], + zeroOrMoreTimesPartNames: [], + encoding: { part in + switch part { + case let .log(wrapped): + var headerFields: HTTPTypes.HTTPFields = .init() + let value = wrapped.payload + let body = try converter.setRequiredRequestBodyAsBinary( + value.body, + headerFields: &headerFields, + contentType: "text/plain" + ) + return .init( + name: "log", + filename: wrapped.filename, + headerFields: headerFields, + body: body + ) + case let .parent(wrapped): + var headerFields: HTTPTypes.HTTPFields = .init() + let value = wrapped.payload + let body = try converter.setRequiredRequestBodyAsJSON( + value.body, + headerFields: &headerFields, + contentType: "application/json; charset=utf-8" + ) + return .init( + name: "parent", + filename: wrapped.filename, + headerFields: headerFields, + body: body + ) + case let .undocumented(value): + return value + } + } + ) + } + return (request, body) + } + """, + server: """ + { request, requestBody, metadata in + let contentType = converter.extractContentTypeIfPresent(in: request.headerFields) + let body: Operations.post_sol_foo.Input.Body + let chosenContentType = try converter.bestContentType( + received: contentType, + options: [ + "multipart/form-data" + ] + ) + switch chosenContentType { + case "multipart/form-data": + body = try converter.getRequiredRequestBodyAsMultipart( + OpenAPIRuntime.MultipartBody.self, + from: requestBody, + transforming: { value in + .multipartForm(value) + }, + boundary: contentType.requiredBoundary(), + allowsUnknownParts: true, + requiredExactlyOncePartNames: [ + "log" + ], + requiredAtLeastOncePartNames: [], + atMostOncePartNames: [ + "parent" + ], + zeroOrMoreTimesPartNames: [], + decoding: { part in + let headerFields = part.headerFields + let (name, filename) = try converter.extractContentDispositionNameAndFilename(in: headerFields) + switch name { + case "log": + try converter.verifyContentTypeIfPresent( + in: headerFields, + matches: "text/plain" + ) + let body = try converter.getRequiredRequestBodyAsBinary( + OpenAPIRuntime.HTTPBody.self, + from: part.body, + transforming: { + $0 + } + ) + return .log(.init( + payload: .init(body: body), + filename: filename + )) + case "parent": + try converter.verifyContentTypeIfPresent( + in: headerFields, + matches: "application/json" + ) + let body = try await converter.getRequiredRequestBodyAsJSON( + Components.Schemas.Node.self, + from: part.body, + transforming: { + $0 + } + ) + return .parent(.init( + payload: .init(body: body), + filename: filename + )) + default: + return .undocumented(part) + } + } + ) + default: + preconditionFailure("bestContentType chose an invalid content type.") + } + return Operations.post_sol_foo.Input(body: body) + } + """ + ) + } + func testRequestMultipartBodyReferencedSchemaWithEncoding() throws { try self.assertRequestInTypesClientServerTranslation( """ From 32b60447850f1fe094b6125b43954abe3d60710d Mon Sep 17 00:00:00 2001 From: Paul Toffoloni Date: Tue, 19 Aug 2025 17:23:03 +0200 Subject: [PATCH 2/3] Fix ref array inference --- .../Translator/Multipart/MultipartContentInspector.swift | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift index 54da84b2..19dcf61f 100644 --- a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift +++ b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift @@ -281,7 +281,7 @@ extension FileTranslator { case .string(_, let context): repetitionKind = .single candidateSource = try inferStringContent(context) - case .object(_, _), .fragment: + case .object, .fragment: repetitionKind = .single candidateSource = .infer(.complex) case .all(of: let schemas, _), .one(of: let schemas, _), .any(of: let schemas, _): @@ -295,8 +295,10 @@ extension FileTranslator { case .null, .not: return nil case .boolean, .number, .integer: candidateSource = .infer(.primitive) case .string(_, let context): candidateSource = try inferStringContent(context) - case .object(_, _), .all, .one, .any, .fragment, .array(_, _): candidateSource = .infer(.complex) - default: candidateSource = .infer(.complex) + case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex) + case .reference(let ref, _): + guard let source = try inferSchema(components.lookup(ref))?.1 else { return nil } + candidateSource = source } } else { candidateSource = .infer(.complex) From 5db200f0054e1edcb6906a79d4d7d197ea8ab357 Mon Sep 17 00:00:00 2001 From: Paul Toffoloni Date: Mon, 25 Aug 2025 17:01:02 +0200 Subject: [PATCH 3/3] Update test --- .../SnippetBasedReferenceTests.swift | 108 ++++++++---------- 1 file changed, 49 insertions(+), 59 deletions(-) diff --git a/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift b/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift index 8aee47a5..5511845a 100644 --- a/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift +++ b/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift @@ -4240,27 +4240,28 @@ final class SnippetBasedReferenceTests: XCTestCase { content: multipart/form-data: schema: - $ref: '#/components/schemas/Node' + $ref: '#/components/schemas/NodeWrapper' responses: default: description: Response """, """ schemas: + NodeWrapper: + type: object + properties: + node: + $ref: '#/components/schemas/Node' Node: type: object properties: - log: - type: string parent: $ref: '#/components/schemas/Node' - required: - - log """, input: """ public struct Input: Sendable, Hashable { @frozen public enum Body: Sendable, Hashable { - case multipartForm(OpenAPIRuntime.MultipartBody) + case multipartForm(OpenAPIRuntime.MultipartBody) } public var body: Operations.post_sol_foo.Input.Body public init(body: Operations.post_sol_foo.Input.Body) { @@ -4270,23 +4271,46 @@ final class SnippetBasedReferenceTests: XCTestCase { """, schemas: """ public enum Schemas { - @frozen public enum Node: Sendable, Hashable { - public struct logPayload: Sendable, Hashable { - public var body: OpenAPIRuntime.HTTPBody - public init(body: OpenAPIRuntime.HTTPBody) { - self.body = body - } - } - case log(OpenAPIRuntime.MultipartPart) - public struct parentPayload: Sendable, Hashable { + @frozen public enum NodeWrapper: Sendable, Hashable { + public struct nodePayload: Sendable, Hashable { public var body: Components.Schemas.Node public init(body: Components.Schemas.Node) { self.body = body } } - case parent(OpenAPIRuntime.MultipartPart) + case node(OpenAPIRuntime.MultipartPart) case undocumented(OpenAPIRuntime.MultipartRawPart) } + public struct Node: Codable, Hashable, Sendable { + public var parent: Components.Schemas.Node? { + get { + self.storage.value.parent + } + _modify { + yield &self.storage.value.parent + } + } + public init(parent: Components.Schemas.Node? = nil) { + self.storage = .init(value: .init(parent: parent)) + } + public enum CodingKeys: String, CodingKey { + case parent + } + public init(from decoder: any Decoder) throws { + self.storage = try .init(from: decoder) + } + public func encode(to encoder: any Encoder) throws { + try self.storage.encode(to: encoder) + } + private var storage: OpenAPIRuntime.CopyOnWriteBox + private struct Storage: Codable, Hashable, Sendable { + var parent: Components.Schemas.Node? + init(parent: Components.Schemas.Node? = nil) { + self.parent = parent + } + typealias CodingKeys = Components.Schemas.Node.CodingKeys + } + } } """, client: """ @@ -4308,31 +4332,15 @@ final class SnippetBasedReferenceTests: XCTestCase { headerFields: &request.headerFields, contentType: "multipart/form-data", allowsUnknownParts: true, - requiredExactlyOncePartNames: [ - "log" - ], + requiredExactlyOncePartNames: [], requiredAtLeastOncePartNames: [], atMostOncePartNames: [ - "parent" + "node" ], zeroOrMoreTimesPartNames: [], encoding: { part in switch part { - case let .log(wrapped): - var headerFields: HTTPTypes.HTTPFields = .init() - let value = wrapped.payload - let body = try converter.setRequiredRequestBodyAsBinary( - value.body, - headerFields: &headerFields, - contentType: "text/plain" - ) - return .init( - name: "log", - filename: wrapped.filename, - headerFields: headerFields, - body: body - ) - case let .parent(wrapped): + case let .node(wrapped): var headerFields: HTTPTypes.HTTPFields = .init() let value = wrapped.payload let body = try converter.setRequiredRequestBodyAsJSON( @@ -4341,7 +4349,7 @@ final class SnippetBasedReferenceTests: XCTestCase { contentType: "application/json; charset=utf-8" ) return .init( - name: "parent", + name: "node", filename: wrapped.filename, headerFields: headerFields, body: body @@ -4368,42 +4376,24 @@ final class SnippetBasedReferenceTests: XCTestCase { switch chosenContentType { case "multipart/form-data": body = try converter.getRequiredRequestBodyAsMultipart( - OpenAPIRuntime.MultipartBody.self, + OpenAPIRuntime.MultipartBody.self, from: requestBody, transforming: { value in .multipartForm(value) }, boundary: contentType.requiredBoundary(), allowsUnknownParts: true, - requiredExactlyOncePartNames: [ - "log" - ], + requiredExactlyOncePartNames: [], requiredAtLeastOncePartNames: [], atMostOncePartNames: [ - "parent" + "node" ], zeroOrMoreTimesPartNames: [], decoding: { part in let headerFields = part.headerFields let (name, filename) = try converter.extractContentDispositionNameAndFilename(in: headerFields) switch name { - case "log": - try converter.verifyContentTypeIfPresent( - in: headerFields, - matches: "text/plain" - ) - let body = try converter.getRequiredRequestBodyAsBinary( - OpenAPIRuntime.HTTPBody.self, - from: part.body, - transforming: { - $0 - } - ) - return .log(.init( - payload: .init(body: body), - filename: filename - )) - case "parent": + case "node": try converter.verifyContentTypeIfPresent( in: headerFields, matches: "application/json" @@ -4415,7 +4405,7 @@ final class SnippetBasedReferenceTests: XCTestCase { $0 } ) - return .parent(.init( + return .node(.init( payload: .init(body: body), filename: filename ))