diff --git a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift index b974190b..19dcf61f 100644 --- a/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift +++ b/Sources/_OpenAPIGeneratorCore/Translator/Multipart/MultipartContentInspector.swift @@ -257,7 +257,8 @@ extension FileTranslator { default: return .infer(.primitive) } } - func inferAllOfAnyOfOneOf(_ schemas: [DereferencedJSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? { + + func inferAllOfAnyOfOneOf(_ schemas: [JSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? { // If all schemas are primitive, the allOf/anyOf/oneOf is also primitive. // These cannot be binary, so only primitive vs complex. for schema in schemas { @@ -266,12 +267,13 @@ extension FileTranslator { } return .infer(.primitive) } - func inferSchema(_ schema: DereferencedJSONSchema) throws -> ( + + func inferSchema(_ schema: JSONSchema) throws -> ( MultipartPartInfo.RepetitionKind, MultipartPartInfo.ContentTypeSource )? { let repetitionKind: MultipartPartInfo.RepetitionKind let candidateSource: MultipartPartInfo.ContentTypeSource - switch schema { + switch schema.value { case .null, .not: return nil case .boolean, .number, .integer: repetitionKind = .single @@ -289,21 +291,30 @@ extension FileTranslator { case .array(_, let context): repetitionKind = .array if let items = context.items { - switch items { + switch items.value { case .null, .not: return nil case .boolean, .number, .integer: candidateSource = .infer(.primitive) case .string(_, let context): candidateSource = try inferStringContent(context) case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex) + case .reference(let ref, _): + guard let source = try inferSchema(components.lookup(ref))?.1 else { return nil } + candidateSource = source } } else { candidateSource = .infer(.complex) } + case .reference(let ref, _): + guard let (refRepetitionKind, refCandidateSource) = try inferSchema(components.lookup(ref)) else { + return nil + } + repetitionKind = refRepetitionKind + candidateSource = refCandidateSource } + return (repetitionKind, candidateSource) } - guard let (repetitionKind, candidateSource) = try inferSchema(schema.dereferenced(in: components)) else { - return nil - } + guard let (repetitionKind, candidateSource) = try inferSchema(schema) else { return nil } + let finalContentTypeSource: MultipartPartInfo.ContentTypeSource if let encoding, let contentType = encoding.contentTypes.first, encoding.contentTypes.count == 1 { finalContentTypeSource = try .explicit(contentType.asGeneratorContentType) diff --git a/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift b/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift index f3a59d0d..5511845a 100644 --- a/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift +++ b/Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift @@ -4230,6 +4230,199 @@ final class SnippetBasedReferenceTests: XCTestCase { ) } + func testRequestMultipartBodyReferencedSchemaRecursive() throws { + try self.assertRequestInTypesClientServerTranslation( + """ + /foo: + post: + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/NodeWrapper' + responses: + default: + description: Response + """, + """ + schemas: + NodeWrapper: + type: object + properties: + node: + $ref: '#/components/schemas/Node' + Node: + type: object + properties: + parent: + $ref: '#/components/schemas/Node' + """, + input: """ + public struct Input: Sendable, Hashable { + @frozen public enum Body: Sendable, Hashable { + case multipartForm(OpenAPIRuntime.MultipartBody) + } + public var body: Operations.post_sol_foo.Input.Body + public init(body: Operations.post_sol_foo.Input.Body) { + self.body = body + } + } + """, + schemas: """ + public enum Schemas { + @frozen public enum NodeWrapper: Sendable, Hashable { + public struct nodePayload: Sendable, Hashable { + public var body: Components.Schemas.Node + public init(body: Components.Schemas.Node) { + self.body = body + } + } + case node(OpenAPIRuntime.MultipartPart) + case undocumented(OpenAPIRuntime.MultipartRawPart) + } + public struct Node: Codable, Hashable, Sendable { + public var parent: Components.Schemas.Node? { + get { + self.storage.value.parent + } + _modify { + yield &self.storage.value.parent + } + } + public init(parent: Components.Schemas.Node? = nil) { + self.storage = .init(value: .init(parent: parent)) + } + public enum CodingKeys: String, CodingKey { + case parent + } + public init(from decoder: any Decoder) throws { + self.storage = try .init(from: decoder) + } + public func encode(to encoder: any Encoder) throws { + try self.storage.encode(to: encoder) + } + private var storage: OpenAPIRuntime.CopyOnWriteBox + private struct Storage: Codable, Hashable, Sendable { + var parent: Components.Schemas.Node? + init(parent: Components.Schemas.Node? = nil) { + self.parent = parent + } + typealias CodingKeys = Components.Schemas.Node.CodingKeys + } + } + } + """, + client: """ + { input in + let path = try converter.renderedPath( + template: "/foo", + parameters: [] + ) + var request: HTTPTypes.HTTPRequest = .init( + soar_path: path, + method: .post + ) + suppressMutabilityWarning(&request) + let body: OpenAPIRuntime.HTTPBody? + switch input.body { + case let .multipartForm(value): + body = try converter.setRequiredRequestBodyAsMultipart( + value, + headerFields: &request.headerFields, + contentType: "multipart/form-data", + allowsUnknownParts: true, + requiredExactlyOncePartNames: [], + requiredAtLeastOncePartNames: [], + atMostOncePartNames: [ + "node" + ], + zeroOrMoreTimesPartNames: [], + encoding: { part in + switch part { + case let .node(wrapped): + var headerFields: HTTPTypes.HTTPFields = .init() + let value = wrapped.payload + let body = try converter.setRequiredRequestBodyAsJSON( + value.body, + headerFields: &headerFields, + contentType: "application/json; charset=utf-8" + ) + return .init( + name: "node", + filename: wrapped.filename, + headerFields: headerFields, + body: body + ) + case let .undocumented(value): + return value + } + } + ) + } + return (request, body) + } + """, + server: """ + { request, requestBody, metadata in + let contentType = converter.extractContentTypeIfPresent(in: request.headerFields) + let body: Operations.post_sol_foo.Input.Body + let chosenContentType = try converter.bestContentType( + received: contentType, + options: [ + "multipart/form-data" + ] + ) + switch chosenContentType { + case "multipart/form-data": + body = try converter.getRequiredRequestBodyAsMultipart( + OpenAPIRuntime.MultipartBody.self, + from: requestBody, + transforming: { value in + .multipartForm(value) + }, + boundary: contentType.requiredBoundary(), + allowsUnknownParts: true, + requiredExactlyOncePartNames: [], + requiredAtLeastOncePartNames: [], + atMostOncePartNames: [ + "node" + ], + zeroOrMoreTimesPartNames: [], + decoding: { part in + let headerFields = part.headerFields + let (name, filename) = try converter.extractContentDispositionNameAndFilename(in: headerFields) + switch name { + case "node": + try converter.verifyContentTypeIfPresent( + in: headerFields, + matches: "application/json" + ) + let body = try await converter.getRequiredRequestBodyAsJSON( + Components.Schemas.Node.self, + from: part.body, + transforming: { + $0 + } + ) + return .node(.init( + payload: .init(body: body), + filename: filename + )) + default: + return .undocumented(part) + } + } + ) + default: + preconditionFailure("bestContentType chose an invalid content type.") + } + return Operations.post_sol_foo.Input(body: body) + } + """ + ) + } + func testRequestMultipartBodyReferencedSchemaWithEncoding() throws { try self.assertRequestInTypesClientServerTranslation( """