|
| 1 | +package software.amazon.smithy.aws.swift.codegen.customization.machinelearning |
| 2 | + |
| 3 | +import software.amazon.smithy.aws.swift.codegen.middleware.PredictInputEndpointURLHostMiddlewareRenderable |
| 4 | +import software.amazon.smithy.aws.swift.codegen.middleware.handlers.PredictInputEndpointURLHostMiddlewareHandler |
| 5 | +import software.amazon.smithy.model.Model |
| 6 | +import software.amazon.smithy.model.shapes.OperationShape |
| 7 | +import software.amazon.smithy.model.shapes.ServiceShape |
| 8 | +import software.amazon.smithy.swift.codegen.MiddlewareGenerator |
| 9 | +import software.amazon.smithy.swift.codegen.SwiftDelegator |
| 10 | +import software.amazon.smithy.swift.codegen.SwiftDependency |
| 11 | +import software.amazon.smithy.swift.codegen.SwiftSettings |
| 12 | +import software.amazon.smithy.swift.codegen.core.CodegenContext |
| 13 | +import software.amazon.smithy.swift.codegen.core.toProtocolGenerationContext |
| 14 | +import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator |
| 15 | +import software.amazon.smithy.swift.codegen.integration.SwiftIntegration |
| 16 | +import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils |
| 17 | +import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep |
| 18 | +import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware |
| 19 | +import software.amazon.smithy.swift.codegen.model.expectShape |
| 20 | + |
| 21 | +internal val ENABLED_OPERATIONS: Map<String, Set<String>> = mapOf( |
| 22 | + "com.amazonaws.machinelearning#AmazonML_20141212" to setOf( |
| 23 | + "com.amazonaws.machinelearning#Predict" |
| 24 | + ) |
| 25 | +) |
| 26 | + |
| 27 | +class PredictEndpointIntegration(private val enabledOperations: Map<String, Set<String>> = ENABLED_OPERATIONS) : SwiftIntegration { |
| 28 | + |
| 29 | + override fun enabledForService(model: Model, settings: SwiftSettings): Boolean { |
| 30 | + val currentServiceId = model.expectShape<ServiceShape>(settings.service).id.toString() |
| 31 | + return enabledOperations.keys.contains(currentServiceId) |
| 32 | + } |
| 33 | + override fun writeAdditionalFiles(ctx: CodegenContext, delegator: SwiftDelegator) { |
| 34 | + val serviceShape = ctx.model.expectShape<ServiceShape>(ctx.settings.service) |
| 35 | + val protocolGeneratorContext = ctx.toProtocolGenerationContext(serviceShape, delegator)?.let { it } ?: run { return } |
| 36 | + val service = ctx.model.expectShape<ServiceShape>(ctx.settings.service) |
| 37 | + val operationsToGenerate = enabledOperations.getOrDefault(service.id.toString(), setOf()) |
| 38 | + |
| 39 | + operationsToGenerate.forEach { operation -> |
| 40 | + val op = ctx.model.expectShape<OperationShape>(operation) |
| 41 | + val inputSymbol = MiddlewareShapeUtils.inputSymbol(ctx.symbolProvider, ctx.model, op) |
| 42 | + val outputSymbol = MiddlewareShapeUtils.outputSymbol(ctx.symbolProvider, ctx.model, op) |
| 43 | + val outputErrorSymbol = MiddlewareShapeUtils.outputErrorSymbol(op) |
| 44 | + |
| 45 | + val inputType = op.input.get() |
| 46 | + delegator.useFileWriter("${ctx.settings.moduleName}/models/$inputType+EndpointURLHostMiddleware.swift") { writer -> |
| 47 | + writer.addImport(SwiftDependency.CLIENT_RUNTIME.target) |
| 48 | + val predictMiddleware = PredictInputEndpointURLHostMiddlewareHandler(writer, protocolGeneratorContext, inputSymbol, outputSymbol, outputErrorSymbol) |
| 49 | + MiddlewareGenerator(writer, predictMiddleware).generate() |
| 50 | + } |
| 51 | + } |
| 52 | + } |
| 53 | + |
| 54 | + override fun customizeMiddleware( |
| 55 | + ctx: ProtocolGenerator.GenerationContext, |
| 56 | + operationShape: OperationShape, |
| 57 | + operationMiddleware: OperationMiddleware |
| 58 | + ) { |
| 59 | + if (enabledOperations.values.contains(setOf(operationShape.id.toString()))) { |
| 60 | + operationMiddleware.removeMiddleware(operationShape, MiddlewareStep.INITIALIZESTEP, "OperationInputUrlHostMiddleware") |
| 61 | + operationMiddleware.appendMiddleware(operationShape, PredictInputEndpointURLHostMiddlewareRenderable()) |
| 62 | + } |
| 63 | + } |
| 64 | +} |
0 commit comments