Skip to content

Commit 8bec997

Browse files
oldergodsvc-squareup-copybara
authored andcommitted
Introduce @RequestHeader and its FeatureBinding
GitOrigin-RevId: b88f05de7c3759005aeeb0219e23ff2dc9be2bae
1 parent e53efc0 commit 8bec997

File tree

5 files changed

+242
-0
lines changed

5 files changed

+242
-0
lines changed

misk-actions/api/misk-actions.api

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ public abstract interface annotation class misk/web/RequestContentType : java/la
199199
public abstract fun value ()[Ljava/lang/String;
200200
}
201201

202+
public abstract interface annotation class misk/web/RequestHeader : java/lang/annotation/Annotation {
203+
public abstract fun value ()Ljava/lang/String;
204+
}
205+
202206
public abstract interface annotation class misk/web/RequestHeaders : java/lang/annotation/Annotation {
203207
}
204208

misk-actions/src/main/kotlin/misk/web/Http.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ annotation class Description(val text: String)
2727
@Target(AnnotationTarget.VALUE_PARAMETER)
2828
annotation class RequestHeaders
2929

30+
/**
31+
* Extracts the named request header as a `String` or a `String?`. If the parameter is not nullable,
32+
* and has no default value, and the header is absent, the request will fail with an HTTP 400.
33+
*/
34+
@Target(AnnotationTarget.VALUE_PARAMETER)
35+
annotation class RequestHeader(val value: String = "")
36+
3037
@Target(AnnotationTarget.VALUE_PARAMETER)
3138
annotation class PathParam(val value: String = "")
3239

misk/src/main/kotlin/misk/web/MiskWebModule.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import misk.web.extractors.PathParamFeatureBinding
5151
import misk.web.extractors.QueryParamFeatureBinding
5252
import misk.web.extractors.RequestBodyException
5353
import misk.web.extractors.RequestBodyFeatureBinding
54+
import misk.web.extractors.RequestHeaderFeatureBinding
5455
import misk.web.extractors.RequestHeadersFeatureBinding
5556
import misk.web.extractors.ResponseBodyFeatureBinding
5657
import misk.web.extractors.WebSocketFeatureBinding
@@ -256,6 +257,7 @@ class MiskWebModule @JvmOverloads constructor(
256257
multibind<FeatureBinding.Factory>().toInstance(PathParamFeatureBinding.Factory)
257258
multibind<FeatureBinding.Factory>().toInstance(QueryParamFeatureBinding.Factory)
258259
multibind<FeatureBinding.Factory>().toInstance(FormValueFeatureBinding.Factory)
260+
multibind<FeatureBinding.Factory>().toInstance(RequestHeaderFeatureBinding.Factory)
259261
multibind<FeatureBinding.Factory>().toInstance(RequestHeadersFeatureBinding.Factory)
260262
multibind<FeatureBinding.Factory>().toInstance(WebSocketFeatureBinding.Factory)
261263
multibind<FeatureBinding.Factory>().toInstance(WebSocketListenerFeatureBinding.Factory)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package misk.web.extractors
2+
3+
import misk.Action
4+
import misk.exceptions.BadRequestException
5+
import misk.web.FeatureBinding
6+
import misk.web.FeatureBinding.Claimer
7+
import misk.web.FeatureBinding.Subject
8+
import misk.web.PathPattern
9+
import misk.web.RequestHeader
10+
import kotlin.reflect.KParameter
11+
import kotlin.reflect.full.findAnnotation
12+
13+
/** Binds parameters annotated [RequestHeader] to HTTP request headers. */
14+
internal class RequestHeaderFeatureBinding private constructor(
15+
private val parameters: List<ParameterBinding>
16+
) : FeatureBinding {
17+
override fun beforeCall(subject: Subject) {
18+
for (element in parameters) {
19+
element.bind(subject)
20+
}
21+
}
22+
23+
internal class ParameterBinding(
24+
val parameter: KParameter,
25+
private val converter: StringConverter,
26+
private val name: String
27+
) {
28+
fun bind(subject: Subject) {
29+
val rawValue = subject.httpCall.requestHeaders[name]
30+
if (rawValue == null) {
31+
when {
32+
parameter.isOptional -> return
33+
parameter.type.isMarkedNullable -> return
34+
else -> throw BadRequestException("Required request header $name not present")
35+
}
36+
}
37+
val value = try {
38+
converter(rawValue)
39+
} catch (e: IllegalArgumentException) {
40+
throw BadRequestException("Invalid format for parameter: $name", e)
41+
}
42+
subject.setParameter(parameter, value)
43+
}
44+
}
45+
46+
companion object Factory : FeatureBinding.Factory {
47+
override fun create(
48+
action: Action,
49+
pathPattern: PathPattern,
50+
claimer: Claimer
51+
): FeatureBinding? {
52+
val bindings = action.parameters.mapNotNull { it.toRequestHeaderBinding() }
53+
if (bindings.isEmpty()) return null
54+
55+
for (binding in bindings) {
56+
claimer.claimParameter(binding.parameter)
57+
}
58+
59+
return RequestHeaderFeatureBinding(bindings)
60+
}
61+
62+
internal fun KParameter.toRequestHeaderBinding(): ParameterBinding? {
63+
val annotation = findAnnotation<RequestHeader>() ?: return null
64+
val name = if (annotation.value.isBlank()) name!! else annotation.value
65+
66+
val stringConverter = converterFor(type)
67+
?: throw IllegalArgumentException("Unable to create converter for $name")
68+
69+
return ParameterBinding(this, stringConverter, name)
70+
}
71+
}
72+
}
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package misk.web.extractors
2+
3+
import misk.MiskTestingServiceModule
4+
import misk.inject.KAbstractModule
5+
import misk.testing.MiskTest
6+
import misk.testing.MiskTestModule
7+
import misk.web.WebActionModule
8+
import misk.web.WebServerTestingModule
9+
import misk.web.actions.WebAction
10+
import misk.web.jetty.JettyService
11+
import okhttp3.OkHttpClient
12+
import okhttp3.Request
13+
import org.assertj.core.api.Assertions.assertThat
14+
import org.junit.jupiter.api.Test
15+
import jakarta.inject.Inject
16+
import misk.web.Get
17+
import misk.web.RequestHeader
18+
import okhttp3.Headers
19+
import okhttp3.Headers.Companion.headersOf
20+
21+
@MiskTest(startService = true)
22+
internal class RequestHeaderParameterTest {
23+
@MiskTestModule
24+
val module = TestModule()
25+
26+
@Inject lateinit var jettyService: JettyService
27+
28+
@Test fun happyPath() {
29+
assertThat(get("/echo-user-agent", headersOf("cash-user-agent", "Cash App 4.0")))
30+
.isEqualTo("your user agent is 'Cash App 4.0'")
31+
assertThat(get("/echo-user-agent", headersOf()))
32+
.isEqualTo("Required request header Cash-User-Agent not present")
33+
}
34+
35+
@Test fun returnsLastHeader() {
36+
assertThat(
37+
get(
38+
"/echo-user-agent",
39+
headersOf(
40+
"cash-user-agent", "Cash App 4.0",
41+
"cash-user-agent", "Cash App 5.0"
42+
)
43+
)
44+
).isEqualTo("your user agent is 'Cash App 5.0'")
45+
}
46+
47+
@Test fun optionalParameter() {
48+
assertThat(get("/echo-optional-user-agent", headersOf("cash-user-agent", "Cash App 4.0")))
49+
.isEqualTo("your user agent is 'Cash App 4.0'")
50+
assertThat(get("/echo-optional-user-agent", headersOf()))
51+
.isEqualTo("your user agent is '<absent>'")
52+
}
53+
54+
@Test fun nullableParameter() {
55+
assertThat(get("/echo-nullable-user-agent", headersOf("cash-user-agent", "Cash App 4.0")))
56+
.isEqualTo("your user agent is 'Cash App 4.0'")
57+
assertThat(get("/echo-nullable-user-agent", headersOf()))
58+
.isEqualTo("your user agent is 'null'")
59+
}
60+
61+
@Test fun nullableOptionalParameter() {
62+
assertThat(get("/echo-nullable-optional-user-agent", headersOf("cash-user-agent", "Cash App 4.0")))
63+
.isEqualTo("your user agent is 'Cash App 4.0'")
64+
assertThat(get("/echo-nullable-optional-user-agent", headersOf()))
65+
.isEqualTo("your user agent is '<absent>'")
66+
}
67+
68+
@Test fun typeConvertedParameter() {
69+
assertThat(get("/echo-user-agent-type", headersOf("cash-user-agent-type", "ANDROID")))
70+
.isEqualTo("your user agent type is ANDROID")
71+
assertThat(get("/echo-user-agent-type", headersOf()))
72+
.isEqualTo("Required request header Cash-User-Agent-Type not present")
73+
}
74+
75+
@Test fun multipleRequestHeaders() {
76+
assertThat(
77+
get(
78+
"/multiple-request-headers",
79+
headersOf(
80+
"user-agent", "Cash App 4.0",
81+
"accept-encoding", "gzip"
82+
)
83+
)
84+
).isEqualTo("your user agent is 'Cash App 4.0' and your accept encoding is 'gzip'")
85+
}
86+
87+
class EchoUserAgentAction @Inject constructor() : WebAction {
88+
@Get("/echo-user-agent")
89+
fun call(
90+
@RequestHeader("Cash-User-Agent") userAgent: String
91+
) = "your user agent is '$userAgent'"
92+
}
93+
94+
class EchoOptionalUserAgentAction @Inject constructor() : WebAction {
95+
@Get("/echo-optional-user-agent")
96+
fun call(
97+
@RequestHeader("Cash-User-Agent") userAgent: String = "<absent>"
98+
) = "your user agent is '$userAgent'"
99+
}
100+
101+
class EchoNullableUserAgentAction @Inject constructor() : WebAction {
102+
@Get("/echo-nullable-user-agent")
103+
fun call(
104+
@RequestHeader("Cash-User-Agent") userAgent: String?
105+
) = "your user agent is '$userAgent'"
106+
}
107+
108+
class EchoNullableOptionalUserAgentAction @Inject constructor() : WebAction {
109+
@Get("/echo-nullable-optional-user-agent")
110+
fun call(
111+
@RequestHeader("Cash-User-Agent") userAgent: String? = "<absent>"
112+
) = "your user agent is '$userAgent'"
113+
}
114+
115+
class EchoUserAgentTypeAction @Inject constructor() : WebAction {
116+
@Get("/echo-user-agent-type")
117+
fun call(
118+
@RequestHeader("Cash-User-Agent-Type") userAgentType: CashUserAgentType
119+
) = "your user agent type is $userAgentType"
120+
}
121+
122+
class MultipleRequestHeadersAction @Inject constructor() : WebAction {
123+
@Get("/multiple-request-headers")
124+
fun call(
125+
@RequestHeader("User-Agent") userAgent: String,
126+
@RequestHeader("Accept-Encoding") acceptEncoding: String,
127+
) = "your user agent is '$userAgent' and your accept encoding is '$acceptEncoding'"
128+
}
129+
130+
enum class CashUserAgentType {
131+
IOS,
132+
ANDROID
133+
}
134+
135+
class TestModule : KAbstractModule() {
136+
override fun configure() {
137+
install(WebServerTestingModule())
138+
install(MiskTestingServiceModule())
139+
install(WebActionModule.create<EchoNullableUserAgentAction>())
140+
install(WebActionModule.create<EchoNullableOptionalUserAgentAction>())
141+
install(WebActionModule.create<EchoOptionalUserAgentAction>())
142+
install(WebActionModule.create<EchoUserAgentAction>())
143+
install(WebActionModule.create<EchoUserAgentTypeAction>())
144+
install(WebActionModule.create<MultipleRequestHeadersAction>())
145+
}
146+
}
147+
148+
private fun get(path: String, headers: Headers): String {
149+
val url = jettyService.httpServerUrl.newBuilder()
150+
.encodedPath(path)
151+
.build()
152+
val request = Request(url, headers)
153+
val httpClient = OkHttpClient()
154+
val response = httpClient.newCall(request).execute()
155+
return response.body.source().readUtf8()
156+
}
157+
}

0 commit comments

Comments
 (0)