Skip to content

Commit a42841a

Browse files
committed
test: ensure that HTTP headers work
1 parent 6d6ce5f commit a42841a

File tree

1 file changed

+216
-30
lines changed
  • host/tests/integration_tests/python/runtime

1 file changed

+216
-30
lines changed

host/tests/integration_tests/python/runtime/http.rs

Lines changed: 216 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use datafusion_udf_wasm_host::{
1010
WasmPermissions, WasmScalarUdf,
1111
http::{AllowCertainHttpRequests, Matcher},
1212
};
13+
use wasmtime_wasi_http::types::DEFAULT_FORBIDDEN_HEADERS;
1314
use wiremock::{Mock, MockServer, ResponseTemplate, matchers};
1415

1516
use crate::integration_tests::{
@@ -128,63 +129,161 @@ def perform_request(url: str) -> str:
128129
}
129130

130131
#[tokio::test(flavor = "multi_thread")]
131-
async fn test_urllib3_happy_path() {
132+
async fn test_integration() {
132133
const CODE: &str = r#"
133134
import urllib3
134135
135-
def perform_request(method: str, url: str) -> str:
136-
resp = urllib3.request(method, url)
136+
def _headers_str_to_dict(headers: str) -> dict[str, str]:
137+
headers_dct = {}
138+
if headers is not None:
139+
for k_v in headers.split(";"):
140+
[k, v] = k_v.split(":")
141+
headers_dct[k] = v
142+
return headers_dct
143+
144+
def _headers_dict_to_str(headers: dict[str, str]) -> str:
145+
headers = ";".join((
146+
f"{k}:{v}"
147+
for k, v in headers.items()
148+
if k not in ["content-length", "content-type", "date"]
149+
))
150+
if not headers:
151+
return "n/a"
152+
else:
153+
return headers
154+
155+
def perform_request(method: str, url: str, headers: str | None) -> str:
156+
try:
157+
resp = urllib3.request(
158+
method=method,
159+
url=url,
160+
headers=_headers_str_to_dict(headers),
161+
)
162+
except Exception as e:
163+
return f"ERR: {e}"
137164
138165
resp_status = resp.status
139166
resp_body = resp.data.decode("utf-8")
167+
resp_headers = _headers_dict_to_str(resp.headers)
140168
141-
return f"status={resp_status} body='{resp_body}'"
169+
return f"OK: status={resp_status} headers={resp_headers} body='{resp_body}'"
142170
"#;
143171

144-
let cases = [
172+
let mut cases = vec![
145173
TestCase {
146-
resp_body: "case_1",
174+
resp: Ok(TestResponse {
175+
body: "case_1",
176+
..Default::default()
177+
}),
147178
..Default::default()
148179
},
149180
TestCase {
150-
resp_body: "case_2",
151181
method: "POST",
182+
resp: Ok(TestResponse {
183+
body: "case_2",
184+
..Default::default()
185+
}),
152186
..Default::default()
153187
},
154188
TestCase {
155-
resp_body: "case_3",
156-
path: "/foo",
189+
path: "/foo".to_owned(),
190+
resp: Ok(TestResponse {
191+
body: "case_3",
192+
..Default::default()
193+
}),
157194
..Default::default()
158195
},
159196
TestCase {
160-
resp_body: "case_4",
161-
path: "/201",
162-
resp_status: 500,
197+
path: "/500".to_owned(),
198+
resp: Ok(TestResponse {
199+
status: 500,
200+
body: "case_4",
201+
..Default::default()
202+
}),
203+
..Default::default()
204+
},
205+
TestCase {
206+
path: "/headers_in".to_owned(),
207+
requ_headers: vec![
208+
("foo".to_owned(), &["bar"]),
209+
("multi".to_owned(), &["some", "thing"]),
210+
],
211+
resp: Ok(TestResponse {
212+
body: "case_5",
213+
..Default::default()
214+
}),
215+
..Default::default()
216+
},
217+
TestCase {
218+
path: "/headers_out".to_owned(),
219+
resp: Ok(TestResponse {
220+
headers: vec![
221+
("foo".to_owned(), &["bar"]),
222+
("multi".to_owned(), &["some", "thing"]),
223+
],
224+
body: "case_6",
225+
..Default::default()
226+
}),
227+
..Default::default()
228+
},
229+
TestCase {
230+
base: Some("http://test.com"),
231+
resp: Err("HTTPConnectionPool(host='test.com', port=80): Max retries exceeded with url: / (Caused by ProtocolError('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied')))".to_owned()),
163232
..Default::default()
164233
},
165234
];
235+
cases.extend(DEFAULT_FORBIDDEN_HEADERS.iter().map(|h| TestCase {
236+
path: format!("/forbidden_header/{h}"),
237+
requ_headers: vec![(h.to_string(), &["foo"])],
238+
resp: Err("Err { value: HeaderError_Forbidden }".to_owned()),
239+
..Default::default()
240+
}));
166241

167242
let server = MockServer::start().await;
168243
let mut permissions = AllowCertainHttpRequests::default();
169244

170245
let mut builder_method = StringBuilder::new();
171246
let mut builder_url = StringBuilder::new();
247+
let mut builder_headers = StringBuilder::new();
172248
let mut builder_result = StringBuilder::new();
173249

174250
for case in &cases {
175-
case.mock().mount(&server).await;
251+
if let Some(mock) = case.mock(&server) {
252+
mock.mount(&server).await;
253+
}
176254
permissions.allow(case.matcher(&server));
177255

178256
let TestCase {
257+
base,
179258
method,
180259
path,
181-
resp_body,
182-
resp_status,
260+
requ_headers,
261+
resp,
183262
} = case;
184263

185264
builder_method.append_value(method);
186-
builder_url.append_value(format!("{}{}", server.uri(), path));
187-
builder_result.append_value(format!("status={resp_status} body='{resp_body}'"));
265+
builder_url.append_value(format!(
266+
"{}{}",
267+
base.map(|b| b.to_owned()).unwrap_or_else(|| server.uri()),
268+
path
269+
));
270+
builder_headers.append_option(headers_to_string(requ_headers));
271+
272+
match resp {
273+
Ok(TestResponse {
274+
status,
275+
headers,
276+
body,
277+
}) => {
278+
let resp_headers = headers_to_string(headers).unwrap_or_else(|| "n/a".to_owned());
279+
builder_result.append_value(format!(
280+
"OK: status={status} headers={resp_headers} body='{body}'"
281+
));
282+
}
283+
Err(e) => {
284+
builder_result.append_value(format!("ERR: {e}"));
285+
}
286+
}
188287
}
189288

190289
let udfs = WasmScalarUdf::new(
@@ -202,10 +301,12 @@ def perform_request(method: str, url: str) -> str:
202301
args: vec![
203302
ColumnarValue::Array(Arc::new(builder_method.finish())),
204303
ColumnarValue::Array(Arc::new(builder_url.finish())),
304+
ColumnarValue::Array(Arc::new(builder_headers.finish())),
205305
],
206306
arg_fields: vec![
207307
Arc::new(Field::new("method", DataType::Utf8, true)),
208308
Arc::new(Field::new("url", DataType::Utf8, true)),
309+
Arc::new(Field::new("headers", DataType::Utf8, true)),
209310
],
210311
number_rows: cases.len(),
211312
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
@@ -216,20 +317,39 @@ def perform_request(method: str, url: str) -> str:
216317
assert_eq!(array.as_ref(), &builder_result.finish() as &dyn Array,);
217318
}
218319

320+
#[derive(Debug, Clone)]
321+
struct TestResponse {
322+
status: u16,
323+
headers: Vec<(String, &'static [&'static str])>,
324+
body: &'static str,
325+
}
326+
327+
impl Default for TestResponse {
328+
fn default() -> Self {
329+
Self {
330+
status: 200,
331+
headers: vec![],
332+
body: "",
333+
}
334+
}
335+
}
336+
219337
struct TestCase {
338+
base: Option<&'static str>,
220339
method: &'static str,
221-
path: &'static str,
222-
resp_body: &'static str,
223-
resp_status: u16,
340+
path: String,
341+
requ_headers: Vec<(String, &'static [&'static str])>,
342+
resp: Result<TestResponse, String>,
224343
}
225344

226345
impl Default for TestCase {
227346
fn default() -> Self {
228347
Self {
348+
base: None,
229349
method: "GET",
230-
path: "/",
231-
resp_body: "",
232-
resp_status: 200,
350+
path: "/".to_owned(),
351+
requ_headers: vec![],
352+
resp: Ok(TestResponse::default()),
233353
}
234354
}
235355
}
@@ -243,17 +363,83 @@ impl TestCase {
243363
}
244364
}
245365

246-
fn mock(&self) -> Mock {
366+
fn mock(&self, server: &MockServer) -> Option<Mock> {
247367
let Self {
368+
base,
248369
method,
249370
path,
250-
resp_body,
251-
resp_status,
371+
requ_headers,
372+
resp,
252373
} = self;
374+
if base.is_some() {
375+
return None;
376+
}
377+
378+
let TestResponse {
379+
status: resp_status,
380+
headers: resp_headers,
381+
body: resp_body,
382+
} = resp.clone().unwrap_or_default();
383+
384+
let mut builder = Mock::given(matchers::method(method))
385+
.and(matchers::path(path.as_str()))
386+
.and(NoForbiddenHeaders::new(
387+
server.address().ip().to_string(),
388+
server.address().port(),
389+
));
390+
391+
for (k, v) in requ_headers {
392+
builder = builder.and(matchers::headers(k.as_str(), v.to_vec()));
393+
}
394+
395+
let mock = builder
396+
.respond_with(
397+
ResponseTemplate::new(resp_status)
398+
.set_body_string(resp_body)
399+
.append_headers(resp_headers.iter().map(|(k, v)| (k, v.join(",")))),
400+
)
401+
.expect(resp.is_ok() as u64);
402+
Some(mock)
403+
}
404+
}
405+
406+
fn headers_to_string(headers: &[(String, &[&str])]) -> Option<String> {
407+
if headers.is_empty() {
408+
None
409+
} else {
410+
let headers = headers
411+
.iter()
412+
.map(|(k, v)| format!("{k}:{}", v.join(",")))
413+
.collect::<Vec<_>>();
414+
Some(headers.join(";"))
415+
}
416+
}
417+
418+
struct NoForbiddenHeaders {
419+
host: String,
420+
port: u16,
421+
}
422+
423+
impl NoForbiddenHeaders {
424+
fn new(host: String, port: u16) -> Self {
425+
Self { host, port }
426+
}
427+
}
428+
429+
impl wiremock::Match for NoForbiddenHeaders {
430+
fn matches(&self, request: &wiremock::Request) -> bool {
431+
// "host" is part of the forbidden headers that the client is not supposed to use, but it is set by our own
432+
// host HTTP lib
433+
let Some(host_val) = request.headers.get(http::header::HOST) else {
434+
return false;
435+
};
436+
if host_val.to_str().expect("always a string") != format!("{}:{}", self.host, self.port) {
437+
return false;
438+
}
253439

254-
Mock::given(matchers::method(method))
255-
.and(matchers::path(*path))
256-
.respond_with(ResponseTemplate::new(*resp_status).set_body_string(*resp_body))
257-
.expect(1)
440+
DEFAULT_FORBIDDEN_HEADERS
441+
.iter()
442+
.filter(|h| *h != http::header::HOST)
443+
.all(|h| !request.headers.contains_key(h))
258444
}
259445
}

0 commit comments

Comments
 (0)