Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ extension FileTranslator {
default: return .infer(.primitive)
}
}
func inferAllOfAnyOfOneOf(_ schemas: [DereferencedJSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? {

func inferAllOfAnyOfOneOf(_ schemas: [JSONSchema]) throws -> MultipartPartInfo.ContentTypeSource? {
// If all schemas are primitive, the allOf/anyOf/oneOf is also primitive.
// These cannot be binary, so only primitive vs complex.
for schema in schemas {
Expand All @@ -266,12 +267,13 @@ extension FileTranslator {
}
return .infer(.primitive)
}
func inferSchema(_ schema: DereferencedJSONSchema) throws -> (

func inferSchema(_ schema: JSONSchema) throws -> (
MultipartPartInfo.RepetitionKind, MultipartPartInfo.ContentTypeSource
)? {
let repetitionKind: MultipartPartInfo.RepetitionKind
let candidateSource: MultipartPartInfo.ContentTypeSource
switch schema {
switch schema.value {
case .null, .not: return nil
case .boolean, .number, .integer:
repetitionKind = .single
Expand All @@ -289,21 +291,30 @@ extension FileTranslator {
case .array(_, let context):
repetitionKind = .array
if let items = context.items {
switch items {
switch items.value {
case .null, .not: return nil
case .boolean, .number, .integer: candidateSource = .infer(.primitive)
case .string(_, let context): candidateSource = try inferStringContent(context)
case .object, .all, .one, .any, .fragment, .array: candidateSource = .infer(.complex)
case .reference(let ref, _):
guard let source = try inferSchema(components.lookup(ref))?.1 else { return nil }
candidateSource = source
}
} else {
candidateSource = .infer(.complex)
}
case .reference(let ref, _):
guard let (refRepetitionKind, refCandidateSource) = try inferSchema(components.lookup(ref)) else {
return nil
}
repetitionKind = refRepetitionKind
candidateSource = refCandidateSource
}

return (repetitionKind, candidateSource)
}
guard let (repetitionKind, candidateSource) = try inferSchema(schema.dereferenced(in: components)) else {
return nil
}
guard let (repetitionKind, candidateSource) = try inferSchema(schema) else { return nil }

let finalContentTypeSource: MultipartPartInfo.ContentTypeSource
if let encoding, let contentType = encoding.contentTypes.first, encoding.contentTypes.count == 1 {
finalContentTypeSource = try .explicit(contentType.asGeneratorContentType)
Expand Down
203 changes: 203 additions & 0 deletions Tests/OpenAPIGeneratorReferenceTests/SnippetBasedReferenceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4230,6 +4230,209 @@ final class SnippetBasedReferenceTests: XCTestCase {
)
}

func testRequestMultipartBodyReferencedSchemaRecursive() throws {
try self.assertRequestInTypesClientServerTranslation(
"""
/foo:
post:
requestBody:
required: true
content:
multipart/form-data:
schema:
$ref: '#/components/schemas/Node'
responses:
default:
description: Response
""",
"""
schemas:
Node:
type: object
properties:
log:
type: string
parent:
$ref: '#/components/schemas/Node'
required:
- log
""",
input: """
public struct Input: Sendable, Hashable {
@frozen public enum Body: Sendable, Hashable {
case multipartForm(OpenAPIRuntime.MultipartBody<Components.Schemas.Node>)
}
public var body: Operations.post_sol_foo.Input.Body
public init(body: Operations.post_sol_foo.Input.Body) {
self.body = body
}
}
""",
schemas: """
public enum Schemas {
@frozen public enum Node: Sendable, Hashable {
public struct logPayload: Sendable, Hashable {
public var body: OpenAPIRuntime.HTTPBody
public init(body: OpenAPIRuntime.HTTPBody) {
self.body = body
}
}
case log(OpenAPIRuntime.MultipartPart<Components.Schemas.Node.logPayload>)
public struct parentPayload: Sendable, Hashable {
public var body: Components.Schemas.Node
public init(body: Components.Schemas.Node) {
self.body = body
}
}
case parent(OpenAPIRuntime.MultipartPart<Components.Schemas.Node.parentPayload>)
case undocumented(OpenAPIRuntime.MultipartRawPart)
}
}
""",
client: """
{ input in
let path = try converter.renderedPath(
template: "/foo",
parameters: []
)
var request: HTTPTypes.HTTPRequest = .init(
soar_path: path,
method: .post
)
suppressMutabilityWarning(&request)
let body: OpenAPIRuntime.HTTPBody?
switch input.body {
case let .multipartForm(value):
body = try converter.setRequiredRequestBodyAsMultipart(
value,
headerFields: &request.headerFields,
contentType: "multipart/form-data",
allowsUnknownParts: true,
requiredExactlyOncePartNames: [
"log"
],
requiredAtLeastOncePartNames: [],
atMostOncePartNames: [
"parent"
],
zeroOrMoreTimesPartNames: [],
encoding: { part in
switch part {
case let .log(wrapped):
var headerFields: HTTPTypes.HTTPFields = .init()
let value = wrapped.payload
let body = try converter.setRequiredRequestBodyAsBinary(
value.body,
headerFields: &headerFields,
contentType: "text/plain"
)
return .init(
name: "log",
filename: wrapped.filename,
headerFields: headerFields,
body: body
)
case let .parent(wrapped):
var headerFields: HTTPTypes.HTTPFields = .init()
let value = wrapped.payload
let body = try converter.setRequiredRequestBodyAsJSON(
value.body,
headerFields: &headerFields,
contentType: "application/json; charset=utf-8"
)
return .init(
name: "parent",
filename: wrapped.filename,
headerFields: headerFields,
body: body
)
case let .undocumented(value):
return value
}
}
)
}
return (request, body)
}
""",
server: """
{ request, requestBody, metadata in
let contentType = converter.extractContentTypeIfPresent(in: request.headerFields)
let body: Operations.post_sol_foo.Input.Body
let chosenContentType = try converter.bestContentType(
received: contentType,
options: [
"multipart/form-data"
]
)
switch chosenContentType {
case "multipart/form-data":
body = try converter.getRequiredRequestBodyAsMultipart(
OpenAPIRuntime.MultipartBody<Components.Schemas.Node>.self,
from: requestBody,
transforming: { value in
.multipartForm(value)
},
boundary: contentType.requiredBoundary(),
allowsUnknownParts: true,
requiredExactlyOncePartNames: [
"log"
],
requiredAtLeastOncePartNames: [],
atMostOncePartNames: [
"parent"
],
zeroOrMoreTimesPartNames: [],
decoding: { part in
let headerFields = part.headerFields
let (name, filename) = try converter.extractContentDispositionNameAndFilename(in: headerFields)
switch name {
case "log":
try converter.verifyContentTypeIfPresent(
in: headerFields,
matches: "text/plain"
)
let body = try converter.getRequiredRequestBodyAsBinary(
OpenAPIRuntime.HTTPBody.self,
from: part.body,
transforming: {
$0
}
)
return .log(.init(
payload: .init(body: body),
filename: filename
))
case "parent":
try converter.verifyContentTypeIfPresent(
in: headerFields,
matches: "application/json"
)
let body = try await converter.getRequiredRequestBodyAsJSON(
Components.Schemas.Node.self,
from: part.body,
transforming: {
$0
}
)
return .parent(.init(
payload: .init(body: body),
filename: filename
))
default:
return .undocumented(part)
}
}
)
default:
preconditionFailure("bestContentType chose an invalid content type.")
}
return Operations.post_sol_foo.Input(body: body)
}
"""
)
}

func testRequestMultipartBodyReferencedSchemaWithEncoding() throws {
try self.assertRequestInTypesClientServerTranslation(
"""
Expand Down