Skip to content

Commit 39a7ff4

Browse files
oldergodsvc-squareup-copybara
authored andcommitted
Allow Action definitions in superclasses/interfaces
GitOrigin-RevId: 7daa8ddb7b9564cf11aab5cd2a297cd94ca51299
1 parent 4d04cbe commit 39a7ff4

File tree

10 files changed

+286
-49
lines changed

10 files changed

+286
-49
lines changed

misk/src/main/kotlin/misk/Action.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ import kotlin.reflect.KParameter
99
import kotlin.reflect.KType
1010
import kotlin.reflect.full.findAnnotation
1111

12+
/**
13+
* Adapts a function so that it can be called by the framework.
14+
*
15+
* This adapts the parameters and return value as HTTP content (like the request and response
16+
* bodies, path parameters, and query parameters).
17+
*
18+
* This aggregates annotations from overridden functions on inherited interfaces and superclasses.
19+
* For example, putting `@RequestBody` on an interface function parameter is as good as putting it
20+
* on the implementing function's parameter. If both a supertype function and a subtype function
21+
* have the same annotation, the subtype's annotation takes precedence.
22+
*
23+
*/
1224
data class Action(
1325
val name: String,
1426
val function: KFunction<*>,

misk/src/main/kotlin/misk/grpc/GrpcFeatureBinding.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ import misk.web.FeatureBinding.Claimer
1515
import misk.web.FeatureBinding.Subject
1616
import misk.web.PathPattern
1717
import misk.web.WebConfig
18-
import misk.web.actions.findAnnotationWithOverrides
1918
import misk.web.mediatype.MediaTypes
2019
import java.lang.reflect.Type
2120
import kotlin.coroutines.CoroutineContext
21+
import kotlin.reflect.full.findAnnotation
2222

2323
internal class GrpcFeatureBinding(
2424
private val requestAdapter: ProtoAdapter<Any>,
@@ -129,7 +129,7 @@ internal class GrpcFeatureBinding(
129129
"@Grpc functions must have either 1 or 2 parameters: $action"
130130
}
131131

132-
val wireAnnotation = action.function.findAnnotationWithOverrides<WireRpc>() ?: return null
132+
val wireAnnotation = action.function.findAnnotation<WireRpc>() ?: return null
133133

134134
claimer.claimParameter(0)
135135
claimer.claimRequestBody()

misk/src/main/kotlin/misk/web/BoundAction.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import misk.scope.SeedDataTransformer
99
import misk.security.authz.AccessInterceptor
1010
import misk.web.actions.WebAction
1111
import misk.web.actions.asChain
12-
import misk.web.actions.findAnnotationWithOverrides
1312
import misk.web.mediatype.MediaRange
1413
import misk.web.mediatype.MediaTypes
1514
import misk.web.mediatype.compareTo
@@ -22,6 +21,7 @@ import com.google.inject.Provider
2221
import misk.api.HttpRequest
2322
import javax.servlet.http.HttpServletRequest
2423
import kotlin.reflect.KType
24+
import kotlin.reflect.full.findAnnotation
2525

2626
/**
2727
* Decodes an HTTP request into a call to a web action, then encodes its response into an HTTP
@@ -150,7 +150,7 @@ internal class BoundAction<A : WebAction>(
150150
WebActionMetadata(
151151
name = action.name,
152152
function = action.function,
153-
description = action.function.findAnnotationWithOverrides<Description>()?.text,
153+
description = action.function.findAnnotation<Description>()?.text,
154154
functionAnnotations = action.function.annotations,
155155
acceptedMediaRanges = action.acceptedMediaRanges,
156156
responseContentType = action.responseContentType,

misk/src/main/kotlin/misk/web/actions/WebActionFactory.kt

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,17 @@ internal class WebActionFactory @Inject constructor(
6060
): List<BoundAction<A>> {
6161
// Find the function with Get, Post, Put, Delete or ConnectWebSocket annotation.
6262
// Only one such function is allowed.
63-
val actionFunctions = webActionClass.functions.mapNotNull {
64-
if (it.findAnnotationWithOverrides<Get>() != null ||
65-
it.findAnnotationWithOverrides<Post>() != null ||
66-
it.findAnnotationWithOverrides<Patch>() != null ||
67-
it.findAnnotationWithOverrides<Put>() != null ||
68-
it.findAnnotationWithOverrides<Grpc>() != null ||
69-
it.findAnnotationWithOverrides<Delete>() != null ||
70-
it.findAnnotationWithOverrides<ConnectWebSocket>() != null ||
71-
it.findAnnotationWithOverrides<WireRpc>() != null
63+
val functionsWithOverrides = webActionClass.functions.map { it.withOverrides() }
64+
65+
val actionFunctions = functionsWithOverrides.mapNotNull {
66+
if (it.findAnnotation<Get>() != null ||
67+
it.findAnnotation<Post>() != null ||
68+
it.findAnnotation<Patch>() != null ||
69+
it.findAnnotation<Put>() != null ||
70+
it.findAnnotation<Grpc>() != null ||
71+
it.findAnnotation<Delete>() != null ||
72+
it.findAnnotation<ConnectWebSocket>() != null ||
73+
it.findAnnotation<WireRpc>() != null
7274
) {
7375
it as? KFunction<*>
7476
?: throw IllegalArgumentException("expected $it to be a function")
@@ -91,14 +93,14 @@ internal class WebActionFactory @Inject constructor(
9193
val effectivePrefix = pathPrefix.dropLast(1)
9294

9395
actionFunctions.forEach { actionFunction ->
94-
val get = actionFunction.findAnnotationWithOverrides<Get>()
95-
val post = actionFunction.findAnnotationWithOverrides<Post>()
96-
val patch = actionFunction.findAnnotationWithOverrides<Patch>()
97-
val put = actionFunction.findAnnotationWithOverrides<Put>()
98-
val delete = actionFunction.findAnnotationWithOverrides<Delete>()
99-
val webActionGrpc = actionFunction.findAnnotationWithOverrides<Grpc>()
100-
val connectWebSocket = actionFunction.findAnnotationWithOverrides<ConnectWebSocket>()
101-
val grpc = actionFunction.findAnnotationWithOverrides<WireRpc>()
96+
val get = actionFunction.findAnnotation<Get>()
97+
val post = actionFunction.findAnnotation<Post>()
98+
val patch = actionFunction.findAnnotation<Patch>()
99+
val put = actionFunction.findAnnotation<Put>()
100+
val delete = actionFunction.findAnnotation<Delete>()
101+
val webActionGrpc = actionFunction.findAnnotation<Grpc>()
102+
val connectWebSocket = actionFunction.findAnnotation<ConnectWebSocket>()
103+
val grpc = actionFunction.findAnnotation<WireRpc>()
102104

103105
if (get != null) {
104106
collectBoundActions(
@@ -222,7 +224,7 @@ internal class WebActionFactory @Inject constructor(
222224
pathPattern: String,
223225
action: Action
224226
): BoundAction<A> {
225-
// Ensure that default interceptors are called before any user provided interceptors
227+
// Ensure that default interceptors are called before any user provided interceptors.
226228
val networkInterceptors =
227229
beforeContentEncodingNetworkInterceptorFactories.mapNotNull { it.create(action) } +
228230
forContentEncodingNetworkInterceptorFactories.mapNotNull { it.create(action) } +

misk/src/main/kotlin/misk/web/actions/WebActions.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ internal fun WebAction.asChain(
5959
sourceChannel?.let { launch { sourceChannel.bridgeFromSource() } }
6060
sinkChannel?.let { launch { sinkChannel.bridgeToSink() } }
6161
try {
62-
function.callSuspendBy(argsMap)
62+
(function as? FunctionWithOverrides)?.callSuspendBy(argsMap)
63+
?: function.callSuspendBy(argsMap)
6364
} finally {
6465
// Once the action is complete, close the send channel and wait for the jobs to finish
6566
// This blocks any additional sends to the channel, but will allow existing responses in

misk/src/main/kotlin/misk/web/actions/reflect.kt

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,18 @@ package misk.web.actions
33
import java.lang.reflect.Method
44
import java.util.ArrayDeque
55
import kotlin.reflect.KFunction
6+
import kotlin.reflect.KParameter
7+
import kotlin.reflect.full.callSuspendBy
68
import kotlin.reflect.jvm.javaMethod
79

8-
/** Returns an instance of [T] annotating this method or a method it overrides. */
9-
internal inline fun <reified T : Annotation> KFunction<*>.findAnnotationWithOverrides(): T? {
10-
return javaMethod!!.findAnnotationWithOverrides(T::class.java)
11-
}
12-
1310
/**
14-
* Returns an instance of [T] annotating this method or a method it overrides. If multiple
15-
* overridden methods have the annotation, one is chosen arbitrarily.
11+
* Returns a function that delegates everything to this. It promotes all annotations from overridden
12+
* functions to the returned function, and all annotations from overridden parameters to the
13+
* returned function.
14+
*
15+
* Use this to make it easy to get annotations on overridden functions 'for free'.
1616
*/
17-
internal fun <T : Annotation> Method.findAnnotationWithOverrides(annotationClass: Class<T>): T? {
18-
for (method in overrides()) {
19-
val annotation = method.getAnnotation(annotationClass)
20-
if (annotation != null) {
21-
return annotation
22-
}
23-
}
24-
return null
25-
}
17+
internal fun <R> KFunction<R>.withOverrides(): KFunction<R> = FunctionWithOverrides(this)
2618

2719
/** Returns the overrides of this method with overriding methods preceding overridden methods. */
2820
internal fun Method.overrides(): Set<Method> {
@@ -32,7 +24,7 @@ internal fun Method.overrides(): Set<Method> {
3224
}
3325

3426
/** Returns the method that [override] overrides. */
35-
internal fun Class<*>.getOverriddenMethod(override: Method): Method? {
27+
private fun Class<*>.getOverriddenMethod(override: Method): Method? {
3628
return try {
3729
check(this.isAssignableFrom(override.declaringClass))
3830
val overridden = getDeclaredMethod(override.name, *override.parameterTypes)
@@ -105,3 +97,51 @@ internal fun Method.preferNonSynthetic(): Method {
10597

10698
return this
10799
}
100+
101+
internal val KFunction<*>.javaMethod: Method?
102+
get() = (this as? FunctionWithOverrides)?.function?.javaMethod ?: this.javaMethod
103+
104+
internal class FunctionWithOverrides<out R>(
105+
val function: KFunction<R>
106+
) : KFunction<R> by function {
107+
private val methodOverrides = function.javaMethod!!.overrides()
108+
109+
override val annotations: List<Annotation> =
110+
methodOverrides.flatMap { it.annotations.toList() }
111+
112+
override val parameters: List<KParameter> =
113+
function.parameters.mapIndexed { index, parameter ->
114+
ParameterWithOverrides(
115+
parameter,
116+
methodOverrides.flatMap { override ->
117+
when (index) {
118+
0 -> listOf()
119+
else -> override.parameters[index - 1].annotations.toList()
120+
}
121+
}
122+
)
123+
}
124+
125+
override fun callBy(args: Map<KParameter, Any?>): R {
126+
val parameters = args.mapKeys { (key, _) ->
127+
function.parameters[key.index]
128+
}
129+
return function.callBy(parameters)
130+
}
131+
132+
suspend fun callSuspendBy(args: Map<KParameter, Any?>): R {
133+
val parameters = args.mapKeys { (key, _) ->
134+
function.parameters[key.index]
135+
}
136+
return function.callSuspendBy(parameters)
137+
}
138+
139+
override fun toString(): String {
140+
return function.toString()
141+
}
142+
}
143+
144+
private class ParameterWithOverrides(
145+
val parameter: KParameter,
146+
override val annotations: List<Annotation>,
147+
) : KParameter by parameter

misk/src/main/kotlin/misk/web/extractors/ResponseBodyFeatureBinding.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import misk.web.FeatureBinding.Subject
88
import misk.web.Grpc
99
import misk.web.PathPattern
1010
import misk.web.actions.WebSocketListener
11-
import misk.web.actions.findAnnotationWithOverrides
1211
import misk.web.interceptors.ResponseBodyMarshallerFactory
1312
import misk.web.marshal.Marshaller
1413
import jakarta.inject.Inject
1514
import jakarta.inject.Singleton
15+
import kotlin.reflect.full.findAnnotation
1616

1717
internal class ResponseBodyFeatureBinding(
1818
private val responseBodyMarshaller: Marshaller<Any>
@@ -41,7 +41,7 @@ internal class ResponseBodyFeatureBinding(
4141
pathPattern: PathPattern,
4242
claimer: Claimer
4343
): FeatureBinding? {
44-
if (action.dispatchMechanism == DispatchMechanism.GRPC && action.function.findAnnotationWithOverrides<Grpc>() == null) return null
44+
if (action.dispatchMechanism == DispatchMechanism.GRPC && action.function.findAnnotation<Grpc>() == null) return null
4545
if (action.returnType.classifier == Unit::class) return null
4646
if (action.returnType.classifier == WebSocketListener::class) return null
4747

misk/src/main/kotlin/misk/web/metadata/webaction/WebActionMetadata.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import misk.web.DispatchMechanism
55
import misk.web.MiskWebFormBuilder
66
import misk.web.NetworkInterceptor
77
import misk.web.PathPattern
8+
import misk.web.actions.javaMethod
89
import misk.web.formatter.ClassNameFormatter
910
import misk.web.mediatype.MediaRange
1011
import okhttp3.MediaType
1112
import kotlin.reflect.KFunction
1213
import kotlin.reflect.KParameter
1314
import kotlin.reflect.KType
14-
import kotlin.reflect.jvm.javaMethod
1515

1616
/** Metadata front end model for Web Action Misk-Web Tab */
1717
data class WebActionMetadata(

misk/src/test/kotlin/misk/web/actions/ReflectTest.kt

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package misk.web.actions
33
import org.assertj.core.api.Assertions.assertThat
44
import org.junit.jupiter.api.Test
55
import java.lang.reflect.Method
6+
import kotlin.reflect.full.findAnnotation
7+
import kotlin.reflect.full.functions
8+
import kotlin.reflect.jvm.javaMethod
69

710
internal class ReflectTest {
811
@Test
@@ -55,19 +58,75 @@ internal class ReflectTest {
5558
@Test
5659
internal fun annotationWithOverrides() {
5760
assertThat(
58-
Square::class.java.getDeclaredMethod("area")
59-
.findAnnotationWithOverrides(Tag::class.java)!!.name
61+
Square::class.functions.first { it.name == "area" }
62+
.withOverrides()
63+
.findAnnotation<Tag>()!!.name
6064
).isEqualTo("square")
6165
assertThat(
62-
Square::class.java.getDeclaredMethod("perimeter")
63-
.findAnnotationWithOverrides(Tag::class.java)
66+
Square::class.functions.first { it.name == "perimeter" }
67+
.withOverrides()
68+
.findAnnotation<Tag>()
6469
).isNull()
6570
assertThat(
66-
Square::class.java.getDeclaredMethod("edgeCount")
67-
.findAnnotationWithOverrides(Tag::class.java)!!.name
71+
Square::class.functions.first { it.name == "edgeCount" }
72+
.withOverrides()
73+
.findAnnotation<Tag>()!!.name
6874
).isEqualTo("polygon")
6975
}
7076

77+
@Test
78+
internal fun parameterAnnotationWithOverridesRequestWithoutValue() {
79+
val greetEmptyFunction = RealGreetingAction::class.functions.single { it.name == "greetEmpty" }
80+
.withOverrides()
81+
assertThat(greetEmptyFunction.findAnnotation<Tag>()!!.name).isEqualTo("/greet-empty")
82+
83+
val messageEmptyParameter = greetEmptyFunction.parameters.single { it.name == "message" }
84+
assertThat(messageEmptyParameter.findAnnotation<RequestEmpty>()).isNotNull()
85+
86+
assertThat(greetEmptyFunction.callBy(
87+
mapOf(
88+
greetEmptyFunction.parameters[0] to RealGreetingAction(),
89+
greetEmptyFunction.parameters[1] to "hello",
90+
)
91+
)).isEqualTo("806")
92+
}
93+
94+
@Test
95+
internal fun parameterAnnotationWithOverridesParameterWithValue() {
96+
val greetFunction = RealGreetingAction::class.functions.single { it.name == "greet" }
97+
.withOverrides()
98+
assertThat(greetFunction.findAnnotation<Tag>()!!.name).isEqualTo("/greet")
99+
100+
val messageParameter = greetFunction.parameters.single { it.name == "message" }
101+
assertThat(messageParameter.findAnnotation<RequestPayload>()!!.name).isEqualTo("hello")
102+
103+
assertThat(greetFunction.callBy(
104+
mapOf(
105+
greetFunction.parameters[0] to RealGreetingAction(),
106+
greetFunction.parameters[1] to "hello",
107+
)
108+
)).isEqualTo("33")
109+
}
110+
111+
@Target(AnnotationTarget.VALUE_PARAMETER)
112+
annotation class RequestPayload(val name: String)
113+
114+
@Target(AnnotationTarget.VALUE_PARAMETER)
115+
annotation class RequestEmpty
116+
117+
interface GreetingAction {
118+
@Tag("/greet")
119+
fun greet(@RequestPayload("hello") message: String): String
120+
121+
@Tag("/greet-empty")
122+
fun greetEmpty(@RequestEmpty message: String): String
123+
}
124+
125+
class RealGreetingAction:GreetingAction {
126+
override fun greet(message: String): String = "33"
127+
override fun greetEmpty(message: String): String = "806"
128+
}
129+
71130
class Square : Polygon(), Territory {
72131
@Tag("square") override fun area() = error("unused")
73132
override fun perimeter() = error("unused")

0 commit comments

Comments
 (0)