Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ extension FileTranslator {
default: return .infer(.primitive)
}
}
func inferAllOfAnyOfOneOf(_ schemas: [DereferencedJSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? {

func inferAllOfAnyOfOneOf(_ schemas: [JSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? {
// If all schemas are primitive, the allOf/anyOf/oneOf is also primitive.
// These cannot be binary, so only primitive vs complex.
for schema in schemas {
Expand All @@ -266,12 +267,13 @@ extension FileTranslator {
}
return .infer(.primitive)
}
func inferSchema(_ schema: DereferencedJSONSchema) throws -> (

func inferSchema(_ schema: JSONSchema) throws -> (
MultipartPartInfo.RepetitionKind, MultipartPartInfo.ContentTypeSource
)? {
let repetitionKind: MultipartPartInfo.RepetitionKind
let candidateSource: MultipartPartInfo.ContentTypeSource
switch schema {
switch schema.value {
case .null, .not: return nil
case .boolean, .number, .integer:
repetitionKind = .single
Expand All @@ -289,21 +291,30 @@ extension FileTranslator {
case .array(_, let context):
repetitionKind = .array
if let items = context.items {
switch items {
switch items.value {
case .null, .not: return nil
case .boolean, .number, .integer: candidateSource = .infer(.primitive)
case .string(_, let context): candidateSource = try inferStringContent(context)
case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex)
case .reference(let ref, _):
guard let source = try inferSchema(components.lookup(ref))?.1 else { return nil }
candidateSource = source
}
} else {
candidateSource = .infer(.complex)
}
case .reference(let ref, _):
guard let (refRepetitionKind, refCandidateSource) = try inferSchema(components.lookup(ref)) else {
return nil
}
repetitionKind = refRepetitionKind
candidateSource = refCandidateSource
}

return (repetitionKind, candidateSource)
}
guard let (repetitionKind, candidateSource) = try inferSchema(schema.dereferenced(in: components)) else {
return nil
}
guard let (repetitionKind, candidateSource) = try inferSchema(schema) else { return nil }

let finalContentTypeSource: MultipartPartInfo.ContentTypeSource
if let encoding, let contentType = encoding.contentTypes.first, encoding.contentTypes.count == 1 {
finalContentTypeSource = try .explicit(contentType.asGeneratorContentType)
Expand Down
193 changes: 193 additions & 0 deletions Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4230,6 +4230,199 @@ final class SnippetBasedReferenceTests: XCTestCase {
)
}

func testRequestMultipartBodyReferencedSchemaRecursive() throws {
try self.assertRequestInTypesClientServerTranslation(
"""
/foo:
post:
requestBody:
required: true
content:
multipart/form-data:
schema:
$ref: '#/components/schemas/NodeWrapper'
responses:
default:
description: Response
""",
"""
schemas:
NodeWrapper:
type: object
properties:
node:
$ref: '#/components/schemas/Node'
Node:
type: object
properties:
parent:
$ref: '#/components/schemas/Node'
""",
input: """
public struct Input: Sendable, Hashable {
@frozen public enum Body: Sendable, Hashable {
case multipartForm(OpenAPIRuntime.MultipartBody<Components.Schemas.NodeWrapper>)
}
public var body: Operations.post_sol_foo.Input.Body
public init(body: Operations.post_sol_foo.Input.Body) {
self.body = body
}
}
""",
schemas: """
public enum Schemas {
@frozen public enum NodeWrapper: Sendable, Hashable {
public struct nodePayload: Sendable, Hashable {
public var body: Components.Schemas.Node
public init(body: Components.Schemas.Node) {
self.body = body
}
}
case node(OpenAPIRuntime.MultipartPart<Components.Schemas.NodeWrapper.nodePayload>)
case undocumented(OpenAPIRuntime.MultipartRawPart)
}
public struct Node: Codable, Hashable, Sendable {
public var parent: Components.Schemas.Node? {
get {
self.storage.value.parent
}
_modify {
yield &self.storage.value.parent
}
}
public init(parent: Components.Schemas.Node? = nil) {
self.storage = .init(value: .init(parent: parent))
}
public enum CodingKeys: String, CodingKey {
case parent
}
public init(from decoder: any Decoder) throws {
self.storage = try .init(from: decoder)
}
public func encode(to encoder: any Encoder) throws {
try self.storage.encode(to: encoder)
}
private var storage: OpenAPIRuntime.CopyOnWriteBox<Storage>
private struct Storage: Codable, Hashable, Sendable {
var parent: Components.Schemas.Node?
init(parent: Components.Schemas.Node? = nil) {
self.parent = parent
}
typealias CodingKeys = Components.Schemas.Node.CodingKeys
}
}
}
""",
client: """
{ input in
let path = try converter.renderedPath(
template: "/foo",
parameters: []
)
var request: HTTPTypes.HTTPRequest = .init(
soar_path: path,
method: .post
)
suppressMutabilityWarning(&request)
let body: OpenAPIRuntime.HTTPBody?
switch input.body {
case let .multipartForm(value):
body = try converter.setRequiredRequestBodyAsMultipart(
value,
headerFields: &request.headerFields,
contentType: "multipart/form-data",
allowsUnknownParts: true,
requiredExactlyOncePartNames: [],
requiredAtLeastOncePartNames: [],
atMostOncePartNames: [
"node"
],
zeroOrMoreTimesPartNames: [],
encoding: { part in
switch part {
case let .node(wrapped):
var headerFields: HTTPTypes.HTTPFields = .init()
let value = wrapped.payload
let body = try converter.setRequiredRequestBodyAsJSON(
value.body,
headerFields: &headerFields,
contentType: "application/json; charset=utf-8"
)
return .init(
name: "node",
filename: wrapped.filename,
headerFields: headerFields,
body: body
)
case let .undocumented(value):
return value
}
}
)
}
return (request, body)
}
""",
server: """
{ request, requestBody, metadata in
let contentType = converter.extractContentTypeIfPresent(in: request.headerFields)
let body: Operations.post_sol_foo.Input.Body
let chosenContentType = try converter.bestContentType(
received: contentType,
options: [
"multipart/form-data"
]
)
switch chosenContentType {
case "multipart/form-data":
body = try converter.getRequiredRequestBodyAsMultipart(
OpenAPIRuntime.MultipartBody<Components.Schemas.NodeWrapper>.self,
from: requestBody,
transforming: { value in
.multipartForm(value)
},
boundary: contentType.requiredBoundary(),
allowsUnknownParts: true,
requiredExactlyOncePartNames: [],
requiredAtLeastOncePartNames: [],
atMostOncePartNames: [
"node"
],
zeroOrMoreTimesPartNames: [],
decoding: { part in
let headerFields = part.headerFields
let (name, filename) = try converter.extractContentDispositionNameAndFilename(in: headerFields)
switch name {
case "node":
try converter.verifyContentTypeIfPresent(
in: headerFields,
matches: "application/json"
)
let body = try await converter.getRequiredRequestBodyAsJSON(
Components.Schemas.Node.self,
from: part.body,
transforming: {
$0
}
)
return .node(.init(
payload: .init(body: body),
filename: filename
))
default:
return .undocumented(part)
}
}
)
default:
preconditionFailure("bestContentType chose an invalid content type.")
}
return Operations.post_sol_foo.Input(body: body)
}
"""
)
}

func testRequestMultipartBodyReferencedSchemaWithEncoding() throws {
try self.assertRequestInTypesClientServerTranslation(
"""
Expand Down
Loading