@@ -4,17 +4,18 @@ use crate::pb::{hsm_client::HsmClient, Empty, HsmRequest, HsmRequestContext};
44use crate :: wire:: { DaemonConnection , Message } ;
55use anyhow:: { anyhow, Context } ;
66use anyhow:: { Error , Result } ;
7- use log:: { error, info, warn} ;
7+ use log:: { debug , error, info, warn} ;
88use std:: convert:: TryFrom ;
99use std:: env;
1010use std:: os:: unix:: io:: { AsRawFd , FromRawFd } ;
1111use std:: os:: unix:: net:: UnixStream ;
12+ use std:: path:: PathBuf ;
1213use std:: process:: Command ;
1314use std:: str;
1415use std:: sync:: atomic;
1516use std:: sync:: Arc ;
16- # [ cfg ( unix ) ]
17- use tokio:: net :: UnixStream as TokioUnixStream ;
17+ use std :: thread ;
18+ use tokio:: runtime :: Runtime ;
1819use tonic:: transport:: { Endpoint , Uri } ;
1920use tower:: service_fn;
2021use which:: which;
@@ -42,32 +43,35 @@ fn version() -> String {
4243
4344fn setup_node_stream ( ) -> Result < DaemonConnection , Error > {
4445 let ms = unsafe { UnixStream :: from_raw_fd ( 3 ) } ;
45- Ok ( DaemonConnection :: new ( TokioUnixStream :: from_std ( ms ) ? ) )
46+ Ok ( DaemonConnection :: new ( ms ) )
4647}
4748
48- fn start_handler ( local : NodeConnection , counter : Arc < atomic:: AtomicUsize > , grpc : GrpcClient ) {
49- tokio:: spawn ( async {
50- match process_requests ( local, counter, grpc)
51- . await
52- . context ( "processing requests" )
53- {
49+ fn start_handler (
50+ local : NodeConnection ,
51+ counter : Arc < atomic:: AtomicUsize > ,
52+ grpc : GrpcClient ,
53+ runtime : Arc < Runtime > ,
54+ ) {
55+ thread:: spawn ( move || {
56+ match process_requests ( local, counter, grpc, runtime) . context ( "processing requests" ) {
5457 Ok ( ( ) ) => panic ! ( "why did the hsmproxy stop processing requests without an error?" ) ,
5558 Err ( e) => warn ! ( "hsmproxy stopped processing requests with error: {}" , e) ,
5659 }
5760 } ) ;
5861}
5962
60- async fn process_requests (
63+ fn process_requests (
6164 node_conn : NodeConnection ,
6265 request_counter : Arc < atomic:: AtomicUsize > ,
6366 mut server : GrpcClient ,
67+ runtime : Arc < Runtime > ,
6468) -> Result < ( ) , Error > {
6569 let conn = node_conn. conn ;
6670 let context = node_conn. context ;
6771 info ! ( "Pinging server" ) ;
68- server. ping ( Empty :: default ( ) ) . await ?;
72+ runtime . block_on ( server. ping ( Empty :: default ( ) ) ) ?;
6973 loop {
70- if let Ok ( msg) = conn. read ( ) . await {
74+ if let Ok ( msg) = conn. read ( ) {
7175 match msg. msgtype ( ) {
7276 9 => {
7377 eprintln ! ( "Got a message from node: {:?}" , & msg. body) ;
@@ -79,16 +83,16 @@ async fn process_requests(
7983
8084 let ( local, remote) = UnixStream :: pair ( ) ?;
8185 let local = NodeConnection {
82- conn : DaemonConnection :: new ( TokioUnixStream :: from_std ( local) ? ) ,
86+ conn : DaemonConnection :: new ( local) ,
8387 context : Some ( ctx) ,
8488 } ;
8589 let remote = remote. as_raw_fd ( ) ;
8690 let msg = Message :: new_with_fds ( vec ! [ 0 , 109 ] , & vec ! [ remote] ) ;
8791
8892 let grpc = server. clone ( ) ;
8993 // Start new handler for the client
90- start_handler ( local, request_counter. clone ( ) , grpc) ;
91- if let Err ( e) = conn. write ( msg) . await {
94+ start_handler ( local, request_counter. clone ( ) , grpc, runtime . clone ( ) ) ;
95+ if let Err ( e) = conn. write ( msg) {
9296 error ! ( "error writing msg to node_connection: {:?}" , e) ;
9397 return Err ( e) ;
9498 }
@@ -102,22 +106,23 @@ async fn process_requests(
102106 requests : Vec :: new ( ) ,
103107 signer_state : Vec :: new ( ) ,
104108 } ) ;
105- let start_time = tokio :: time :: Instant :: now ( ) ;
109+
106110 eprintln ! (
107111 "WIRE: lightningd -> hsmd: Got a message from node: {:?}" ,
108112 & req
109113 ) ;
110- eprintln ! ( "WIRE: hsmd -> plugin: Forwarding: {:?}" , & req ) ;
111- let res = server . request ( req ) . await ? . into_inner ( ) ;
112- let msg = Message :: from_raw ( res . raw ) ;
114+ let start_time = tokio :: time :: Instant :: now ( ) ;
115+ debug ! ( "Got a message from node: {:?}" , & req ) ;
116+ let res = runtime . block_on ( server . request ( req ) ) ? . into_inner ( ) ;
113117 let delta = start_time. elapsed ( ) ;
118+ let msg = Message :: from_raw ( res. raw ) ;
114119 eprintln ! (
115120 "WIRE: plugin -> hsmd: Got respone from hsmd: {:?} after {}ms" ,
116121 & msg,
117122 delta. as_millis( )
118123 ) ;
119124 eprintln ! ( "WIRE: hsmd -> lightningd: {:?}" , & msg) ;
120- conn. write ( msg) . await ?
125+ conn. write ( msg) ?
121126 }
122127 }
123128 } else {
@@ -126,32 +131,34 @@ async fn process_requests(
126131 }
127132 }
128133}
129- use std:: path:: PathBuf ;
130- async fn grpc_connect ( ) -> Result < GrpcClient , Error > {
131- // We will ignore this uri because uds do not use it
132- // if your connector does use the uri it will be provided
133- // as the request to the `MakeConnection`.
134- // Connect to a Uds socket
135- let channel = Endpoint :: try_from ( "http://[::]:50051" ) ?
136- . connect_with_connector ( service_fn ( |_: Uri | {
137- let sock_path = get_sock_path ( ) . unwrap ( ) ;
138- let mut path = PathBuf :: new ( ) ;
139- if !sock_path. starts_with ( '/' ) {
140- path. push ( env:: current_dir ( ) . unwrap ( ) ) ;
141- }
142- path. push ( & sock_path) ;
143134
144- let path = path. to_str ( ) . unwrap ( ) . to_string ( ) ;
145- info ! ( "Connecting to hsmserver at {}" , path) ;
146- TokioUnixStream :: connect ( path)
147- } ) )
148- . await
149- . context ( "could not connect to the socket file" ) ?;
135+ fn grpc_connect ( runtime : & Runtime ) -> Result < GrpcClient , Error > {
136+ runtime. block_on ( async {
137+ // We will ignore this uri because uds do not use it
138+ // if your connector does use the uri it will be provided
139+ // as the request to the `MakeConnection`.
140+ // Connect to a Uds socket
141+ let channel = Endpoint :: try_from ( "http://[::]:50051" ) ?
142+ . connect_with_connector ( service_fn ( |_: Uri | {
143+ let sock_path = get_sock_path ( ) . unwrap ( ) ;
144+ let mut path = PathBuf :: new ( ) ;
145+ if !sock_path. starts_with ( '/' ) {
146+ path. push ( env:: current_dir ( ) . unwrap ( ) ) ;
147+ }
148+ path. push ( & sock_path) ;
149+
150+ let path = path. to_str ( ) . unwrap ( ) . to_string ( ) ;
151+ info ! ( "Connecting to hsmserver at {}" , path) ;
152+ tokio:: net:: UnixStream :: connect ( path)
153+ } ) )
154+ . await
155+ . context ( "could not connect to the socket file" ) ?;
150156
151- Ok ( HsmClient :: new ( channel) )
157+ Ok ( HsmClient :: new ( channel) )
158+ } )
152159}
153160
154- pub async fn run ( ) -> Result < ( ) , Error > {
161+ pub fn run ( ) -> Result < ( ) , Error > {
155162 let args: Vec < String > = std:: env:: args ( ) . collect ( ) ;
156163
157164 // Start the counter at 1000 so we can inject some message before
@@ -164,8 +171,16 @@ pub async fn run() -> Result<(), Error> {
164171
165172 info ! ( "Starting hsmproxy" ) ;
166173
174+ // Create a dedicated tokio runtime for gRPC operations
175+ let runtime = Arc :: new (
176+ tokio:: runtime:: Builder :: new_current_thread ( )
177+ . enable_all ( )
178+ . build ( )
179+ . context ( "failed to create tokio runtime" ) ?,
180+ ) ;
181+
167182 let node = setup_node_stream ( ) ?;
168- let grpc = grpc_connect ( ) . await ?;
183+ let grpc = grpc_connect ( & runtime ) ?;
169184
170185 process_requests (
171186 NodeConnection {
@@ -174,6 +189,6 @@ pub async fn run() -> Result<(), Error> {
174189 } ,
175190 request_counter,
176191 grpc,
192+ runtime,
177193 )
178- . await
179194}
0 commit comments