@@ -8,22 +8,42 @@ use rocket::{
88 listener:: Endpoint ,
99 mtls:: Certificate ,
1010 request:: { FromRequest , Outcome } ,
11- response:: status:: Custom ,
11+ response:: { status:: Custom , Responder } ,
1212 Request ,
1313} ;
1414use rocket_vsock_listener:: VsockEndpoint ;
1515use tracing:: warn;
1616
1717use crate :: { encode_error, CallContext , RemoteEndpoint , RpcCall } ;
1818
19+ pub struct RpcResponse {
20+ is_json : bool ,
21+ status : Status ,
22+ body : Vec < u8 > ,
23+ }
24+
25+ impl < ' r > Responder < ' r , ' static > for RpcResponse {
26+ fn respond_to ( self , request : & ' r Request < ' _ > ) -> rocket:: response:: Result < ' static > {
27+ use rocket:: http:: ContentType ;
28+ let content_type = if self . is_json {
29+ ContentType :: JSON
30+ } else {
31+ ContentType :: Binary
32+ } ;
33+ let response = Custom ( self . status , self . body ) . respond_to ( request) ?;
34+ rocket:: Response :: build_from ( response)
35+ . header ( content_type)
36+ . ok ( )
37+ }
38+ }
39+
1940#[ derive( Debug , Clone ) ]
2041pub struct QuoteVerifier {
2142 pccs_url : Option < String > ,
2243}
2344
2445pub mod deps {
25- pub use super :: { PrpcHandler , RpcRequest } ;
26- pub use rocket:: response:: status:: Custom ;
46+ pub use super :: { PrpcHandler , RpcRequest , RpcResponse } ;
2747 pub use rocket:: { Data , State } ;
2848}
2949
@@ -64,7 +84,7 @@ macro_rules! declare_prpc_routes {
6484 method: & ' a str ,
6585 rpc_request: $crate:: rocket_helper:: deps:: RpcRequest <' a>,
6686 data: $crate:: rocket_helper:: deps:: Data <' d>,
67- ) -> $crate:: rocket_helper:: deps:: Custom < Vec < u8 >> {
87+ ) -> $crate:: rocket_helper:: deps:: RpcResponse {
6888 $crate:: rocket_helper:: deps:: PrpcHandler :: builder( )
6989 . state( & * * state)
7090 . request( rpc_request)
@@ -80,7 +100,7 @@ macro_rules! declare_prpc_routes {
80100 state: & $crate:: rocket_helper:: deps:: State <$state>,
81101 method: & str ,
82102 rpc_request: $crate:: rocket_helper:: deps:: RpcRequest <' _>,
83- ) -> $crate:: rocket_helper:: deps:: Custom < Vec < u8 >> {
103+ ) -> $crate:: rocket_helper:: deps:: RpcResponse {
84104 $crate:: rocket_helper:: deps:: PrpcHandler :: builder( )
85105 . state( & * * state)
86106 . request( rpc_request)
@@ -99,7 +119,7 @@ macro_rules! prpc_alias {
99119 async fn $name(
100120 state: & $crate:: rocket_helper:: deps:: State <$state>,
101121 rpc_request: $crate:: rocket_helper:: deps:: RpcRequest <' _>,
102- ) -> $crate:: rocket_helper:: deps:: Custom < Vec < u8 >> {
122+ ) -> $crate:: rocket_helper:: deps:: RpcResponse {
103123 $prpc( state, $method, rpc_request) . await
104124 }
105125 } ;
@@ -109,7 +129,7 @@ macro_rules! prpc_alias {
109129 state: & ' a $crate:: rocket_helper:: deps:: State <$state>,
110130 rpc_request: $crate:: rocket_helper:: deps:: RpcRequest <' a>,
111131 data: $crate:: rocket_helper:: deps:: Data <' d>,
112- ) -> $crate:: rocket_helper:: deps:: Custom < Vec < u8 >> {
132+ ) -> $crate:: rocket_helper:: deps:: RpcResponse {
113133 $prpc( state, $method, rpc_request, data) . await
114134 }
115135 } ;
@@ -175,6 +195,7 @@ pub struct RpcRequest<'r> {
175195 limits : & ' r Limits ,
176196 content_type : Option < & ' r ContentType > ,
177197 json : bool ,
198+ is_get : bool ,
178199}
179200
180201#[ rocket:: async_trait]
@@ -190,12 +211,13 @@ impl<'r> FromRequest<'r> for RpcRequest<'r> {
190211 limits : from_request ! ( request) ,
191212 content_type : from_request ! ( request) ,
192213 json : request. method ( ) == Method :: Get || query_field_get_bool ( request, "json" ) ,
214+ is_get : request. method ( ) == Method :: Get ,
193215 } )
194216 }
195217}
196218
197219impl < S > PrpcHandler < ' _ , ' _ , S > {
198- pub async fn handle < Call : RpcCall < S > > ( self ) -> Custom < Vec < u8 > > {
220+ pub async fn handle < Call : RpcCall < S > > ( self ) -> RpcResponse {
199221 let json = self . request . json ;
200222 let result = handle_prpc_impl :: < S , Call > ( self ) . await ;
201223 match result {
@@ -204,7 +226,11 @@ impl<S> PrpcHandler<'_, '_, S> {
204226 let estr = format ! ( "{e:?}" ) ;
205227 warn ! ( "error handling prpc: {estr}" ) ;
206228 let body = encode_error ( json, estr) ;
207- Custom ( Status :: BadRequest , body)
229+ RpcResponse {
230+ is_json : json,
231+ status : Status :: BadRequest ,
232+ body,
233+ }
208234 }
209235 }
210236 }
@@ -232,7 +258,7 @@ impl From<Endpoint> for RemoteEndpoint {
232258
233259pub async fn handle_prpc_impl < S , Call : RpcCall < S > > (
234260 args : PrpcHandler < ' _ , ' _ , S > ,
235- ) -> Result < Custom < Vec < u8 > > > {
261+ ) -> Result < RpcResponse > {
236262 let PrpcHandler {
237263 state,
238264 request,
@@ -267,7 +293,6 @@ pub async fn handle_prpc_impl<S, Call: RpcCall<S>>(
267293 }
268294 _ => None ,
269295 } ;
270- let is_get = data. is_none ( ) ;
271296 let payload = match data {
272297 Some ( data) => {
273298 let limit = limit_for_method ( method, request. limits ) ;
@@ -280,16 +305,22 @@ pub async fn handle_prpc_impl<S, Call: RpcCall<S>>(
280305 . query ( )
281306 . map_or ( vec ! [ ] , |q| q. as_bytes ( ) . to_vec ( ) ) ,
282307 } ;
283- let json = request. json || request. content_type . map ( |t| t. is_json ( ) ) . unwrap_or ( false ) ;
308+ let is_json = request. json || request. content_type . map ( |t| t. is_json ( ) ) . unwrap_or ( false ) ;
284309 let context = CallContext {
285310 state,
286311 attestation,
287312 remote_endpoint : request. remote_addr . cloned ( ) . map ( RemoteEndpoint :: from) ,
288313 remote_app_id,
289314 } ;
290315 let call = Call :: construct ( context) . context ( "failed to construct call" ) ?;
291- let ( status_code, output) = call. call ( method. to_string ( ) , payload, json, is_get) . await ;
292- Ok ( Custom ( Status :: new ( status_code) , output) )
316+ let ( status_code, output) = call
317+ . call ( method. to_string ( ) , payload, is_json, request. is_get )
318+ . await ;
319+ Ok ( RpcResponse {
320+ is_json,
321+ status : Status :: new ( status_code) ,
322+ body : output,
323+ } )
293324}
294325
295326struct RocketCertificate < ' a > ( & ' a rocket:: mtls:: Certificate < ' a > ) ;
0 commit comments