Skip to content

Commit 6803835

Browse files
committed
rpc: set mime application/json
1 parent 50888e4 commit 6803835

File tree

2 files changed

+46
-14
lines changed

2 files changed

+46
-14
lines changed

ra-rpc/src/rocket_helper.rs

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,42 @@ use rocket::{
88
listener::Endpoint,
99
mtls::Certificate,
1010
request::{FromRequest, Outcome},
11-
response::status::Custom,
11+
response::{status::Custom, Responder},
1212
Request,
1313
};
1414
use rocket_vsock_listener::VsockEndpoint;
1515
use tracing::warn;
1616

1717
use crate::{encode_error, CallContext, RemoteEndpoint, RpcCall};
1818

19+
pub struct RpcResponse {
20+
is_json: bool,
21+
status: Status,
22+
body: Vec<u8>,
23+
}
24+
25+
impl<'r> Responder<'r, 'static> for RpcResponse {
26+
fn respond_to(self, request: &'r Request<'_>) -> rocket::response::Result<'static> {
27+
use rocket::http::ContentType;
28+
let content_type = if self.is_json {
29+
ContentType::JSON
30+
} else {
31+
ContentType::Binary
32+
};
33+
let response = Custom(self.status, self.body).respond_to(request)?;
34+
rocket::Response::build_from(response)
35+
.header(content_type)
36+
.ok()
37+
}
38+
}
39+
1940
#[derive(Debug, Clone)]
2041
pub struct QuoteVerifier {
2142
pccs_url: Option<String>,
2243
}
2344

2445
pub mod deps {
25-
pub use super::{PrpcHandler, RpcRequest};
26-
pub use rocket::response::status::Custom;
46+
pub use super::{PrpcHandler, RpcRequest, RpcResponse};
2747
pub use rocket::{Data, State};
2848
}
2949

@@ -64,7 +84,7 @@ macro_rules! declare_prpc_routes {
6484
method: &'a str,
6585
rpc_request: $crate::rocket_helper::deps::RpcRequest<'a>,
6686
data: $crate::rocket_helper::deps::Data<'d>,
67-
) -> $crate::rocket_helper::deps::Custom<Vec<u8>> {
87+
) -> $crate::rocket_helper::deps::RpcResponse {
6888
$crate::rocket_helper::deps::PrpcHandler::builder()
6989
.state(&**state)
7090
.request(rpc_request)
@@ -80,7 +100,7 @@ macro_rules! declare_prpc_routes {
80100
state: &$crate::rocket_helper::deps::State<$state>,
81101
method: &str,
82102
rpc_request: $crate::rocket_helper::deps::RpcRequest<'_>,
83-
) -> $crate::rocket_helper::deps::Custom<Vec<u8>> {
103+
) -> $crate::rocket_helper::deps::RpcResponse {
84104
$crate::rocket_helper::deps::PrpcHandler::builder()
85105
.state(&**state)
86106
.request(rpc_request)
@@ -99,7 +119,7 @@ macro_rules! prpc_alias {
99119
async fn $name(
100120
state: &$crate::rocket_helper::deps::State<$state>,
101121
rpc_request: $crate::rocket_helper::deps::RpcRequest<'_>,
102-
) -> $crate::rocket_helper::deps::Custom<Vec<u8>> {
122+
) -> $crate::rocket_helper::deps::RpcResponse {
103123
$prpc(state, $method, rpc_request).await
104124
}
105125
};
@@ -109,7 +129,7 @@ macro_rules! prpc_alias {
109129
state: &'a $crate::rocket_helper::deps::State<$state>,
110130
rpc_request: $crate::rocket_helper::deps::RpcRequest<'a>,
111131
data: $crate::rocket_helper::deps::Data<'d>,
112-
) -> $crate::rocket_helper::deps::Custom<Vec<u8>> {
132+
) -> $crate::rocket_helper::deps::RpcResponse {
113133
$prpc(state, $method, rpc_request, data).await
114134
}
115135
};
@@ -175,6 +195,7 @@ pub struct RpcRequest<'r> {
175195
limits: &'r Limits,
176196
content_type: Option<&'r ContentType>,
177197
json: bool,
198+
is_get: bool,
178199
}
179200

180201
#[rocket::async_trait]
@@ -190,12 +211,13 @@ impl<'r> FromRequest<'r> for RpcRequest<'r> {
190211
limits: from_request!(request),
191212
content_type: from_request!(request),
192213
json: request.method() == Method::Get || query_field_get_bool(request, "json"),
214+
is_get: request.method() == Method::Get,
193215
})
194216
}
195217
}
196218

197219
impl<S> PrpcHandler<'_, '_, S> {
198-
pub async fn handle<Call: RpcCall<S>>(self) -> Custom<Vec<u8>> {
220+
pub async fn handle<Call: RpcCall<S>>(self) -> RpcResponse {
199221
let json = self.request.json;
200222
let result = handle_prpc_impl::<S, Call>(self).await;
201223
match result {
@@ -204,7 +226,11 @@ impl<S> PrpcHandler<'_, '_, S> {
204226
let estr = format!("{e:?}");
205227
warn!("error handling prpc: {estr}");
206228
let body = encode_error(json, estr);
207-
Custom(Status::BadRequest, body)
229+
RpcResponse {
230+
is_json: json,
231+
status: Status::BadRequest,
232+
body,
233+
}
208234
}
209235
}
210236
}
@@ -232,7 +258,7 @@ impl From<Endpoint> for RemoteEndpoint {
232258

233259
pub async fn handle_prpc_impl<S, Call: RpcCall<S>>(
234260
args: PrpcHandler<'_, '_, S>,
235-
) -> Result<Custom<Vec<u8>>> {
261+
) -> Result<RpcResponse> {
236262
let PrpcHandler {
237263
state,
238264
request,
@@ -267,7 +293,6 @@ pub async fn handle_prpc_impl<S, Call: RpcCall<S>>(
267293
}
268294
_ => None,
269295
};
270-
let is_get = data.is_none();
271296
let payload = match data {
272297
Some(data) => {
273298
let limit = limit_for_method(method, request.limits);
@@ -280,16 +305,22 @@ pub async fn handle_prpc_impl<S, Call: RpcCall<S>>(
280305
.query()
281306
.map_or(vec![], |q| q.as_bytes().to_vec()),
282307
};
283-
let json = request.json || request.content_type.map(|t| t.is_json()).unwrap_or(false);
308+
let is_json = request.json || request.content_type.map(|t| t.is_json()).unwrap_or(false);
284309
let context = CallContext {
285310
state,
286311
attestation,
287312
remote_endpoint: request.remote_addr.cloned().map(RemoteEndpoint::from),
288313
remote_app_id,
289314
};
290315
let call = Call::construct(context).context("failed to construct call")?;
291-
let (status_code, output) = call.call(method.to_string(), payload, json, is_get).await;
292-
Ok(Custom(Status::new(status_code), output))
316+
let (status_code, output) = call
317+
.call(method.to_string(), payload, is_json, request.is_get)
318+
.await;
319+
Ok(RpcResponse {
320+
is_json,
321+
status: Status::new(status_code),
322+
body: output,
323+
})
293324
}
294325

295326
struct RocketCertificate<'a>(&'a rocket::mtls::Certificate<'a>);

vmm/rpc/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ fn main() {
55
.build_scale_ext(false)
66
.disable_package_emission()
77
.enable_serde_extension()
8+
.disable_service_name_emission()
89
.compile_dir("./proto")
910
.expect("failed to compile proto files");
1011
}

0 commit comments

Comments
 (0)