@@ -44,11 +44,13 @@ const (
4444 FeatureQuickReconnects = "quick_reconnects"
4545)
4646
47- type registerRPCName string
47+ type rpcName string
4848
4949const (
50- register registerRPCName = "register"
51- reconnect registerRPCName = "reconnect"
50+ register rpcName = "register"
51+ reconnect rpcName = "reconnect"
52+ unregister rpcName = "unregister"
53+ authenticate rpcName = " authenticate"
5254)
5355
5456type TunnelConfig struct {
@@ -121,7 +123,7 @@ type clientRegisterTunnelError struct {
121123 cause error
122124}
123125
124- func newClientRegisterTunnelError (cause error , counter * prometheus.CounterVec , name registerRPCName ) clientRegisterTunnelError {
126+ func newRPCError (cause error , counter * prometheus.CounterVec , name rpcName ) clientRegisterTunnelError {
125127 counter .WithLabelValues (cause .Error (), string (name )).Inc ()
126128 return clientRegisterTunnelError {cause : cause }
127129}
@@ -337,7 +339,7 @@ func ServeTunnel(
337339 if config .NamedTunnel != nil {
338340 _ = UnregisterConnection (ctx , handler .muxer , config )
339341 } else {
340- _ = UnregisterTunnel (handler .muxer , config . GracePeriod , config . TransportLogger )
342+ _ = UnregisterTunnel (handler .muxer , config )
341343 }
342344 }
343345 handler .muxer .Shutdown ()
@@ -417,14 +419,13 @@ func RegisterConnection(
417419 const registerConnection = "registerConnection"
418420
419421 config .TransportLogger .Debug ("initiating RPC stream for RegisterConnection" )
420- rpc , err := connection . NewRPCClient (ctx , muxer , config . TransportLogger , openStreamTimeout )
422+ rpcClient , err := newTunnelRPCClient (ctx , muxer , config , registerConnection )
421423 if err != nil {
422- // RPC stream open error
423- return newClientRegisterTunnelError (err , config .Metrics .rpcFail , registerConnection )
424+ return err
424425 }
425- defer rpc .Close ()
426+ defer rpcClient .Close ()
426427
427- conn , err := rpc .RegisterConnection (
428+ conn , err := rpcClient .RegisterConnection (
428429 ctx ,
429430 config .NamedTunnel .Auth ,
430431 config .NamedTunnel .ID ,
@@ -470,14 +471,14 @@ func UnregisterConnection(
470471 config * TunnelConfig ,
471472) error {
472473 config .TransportLogger .Debug ("initiating RPC stream for UnregisterConnection" )
473- rpc , err := connection . NewRPCClient (ctx , muxer , config . TransportLogger , openStreamTimeout )
474+ rpcClient , err := newTunnelRPCClient (ctx , muxer , config , register )
474475 if err != nil {
475476 // RPC stream open error
476- return newClientRegisterTunnelError ( err , config . Metrics . rpcFail , register )
477+ return err
477478 }
478- defer rpc .Close ()
479+ defer rpcClient .Close ()
479480
480- return rpc .UnregisterConnection (ctx )
481+ return rpcClient .UnregisterConnection (ctx )
481482}
482483
483484func RegisterTunnel (
@@ -494,18 +495,18 @@ func RegisterTunnel(
494495 if config .TunnelEventChan != nil {
495496 config .TunnelEventChan <- ui.TunnelEvent {EventType : ui .RegisteringTunnel }
496497 }
497- tunnelServer , err := connection .NewRPCClient (ctx , muxer , config .TransportLogger , openStreamTimeout )
498+
499+ rpcClient , err := newTunnelRPCClient (ctx , muxer , config , register )
498500 if err != nil {
499- // RPC stream open error
500- return newClientRegisterTunnelError (err , config .Metrics .rpcFail , register )
501+ return err
501502 }
502- defer tunnelServer .Close ()
503+ defer rpcClient .Close ()
503504 // Request server info without blocking tunnel registration; must use capnp library directly.
504- serverInfoPromise := tunnelrpc.TunnelServer {Client : tunnelServer .Client }.GetServerInfo (ctx , func (tunnelrpc.TunnelServer_getServerInfo_Params ) error {
505+ serverInfoPromise := tunnelrpc.TunnelServer {Client : rpcClient .Client }.GetServerInfo (ctx , func (tunnelrpc.TunnelServer_getServerInfo_Params ) error {
505506 return nil
506507 })
507508 LogServerInfo (serverInfoPromise .Result (), connectionID , config .Metrics , logger , config .TunnelEventChan )
508- registration := tunnelServer .RegisterTunnel (
509+ registration := rpcClient .RegisterTunnel (
509510 ctx ,
510511 config .OriginCert ,
511512 config .Hostname ,
@@ -529,7 +530,7 @@ func processRegistrationSuccess(
529530 logger logger.Service ,
530531 connectionID uint8 ,
531532 registration * tunnelpogs.TunnelRegistration ,
532- name registerRPCName ,
533+ name rpcName ,
533534 credentialManager * reconnectCredentialManager ,
534535) error {
535536 for _ , logLine := range registration .LogLines {
@@ -563,7 +564,7 @@ func processRegistrationSuccess(
563564 return nil
564565}
565566
566- func processRegisterTunnelError (err tunnelpogs.TunnelRegistrationError , metrics * TunnelMetrics , name registerRPCName ) error {
567+ func processRegisterTunnelError (err tunnelpogs.TunnelRegistrationError , metrics * TunnelMetrics , name rpcName ) error {
567568 if err .Error () == DuplicateConnectionError {
568569 metrics .regFail .WithLabelValues ("dup_edge_conn" , string (name )).Inc ()
569570 return errDuplicationConnection
@@ -575,18 +576,18 @@ func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics
575576 }
576577}
577578
578- func UnregisterTunnel (muxer * h2mux.Muxer , gracePeriod time. Duration , logger logger. Service ) error {
579- logger .Debug ("initiating RPC stream to unregister" )
579+ func UnregisterTunnel (muxer * h2mux.Muxer , config * TunnelConfig ) error {
580+ config . TransportLogger .Debug ("initiating RPC stream to unregister" )
580581 ctx := context .Background ()
581- tunnelServer , err := connection . NewRPCClient (ctx , muxer , logger , openStreamTimeout )
582+ rpcClient , err := newTunnelRPCClient (ctx , muxer , config , unregister )
582583 if err != nil {
583584 // RPC stream open error
584585 return err
585586 }
586- defer tunnelServer .Close ()
587+ defer rpcClient .Close ()
587588
588589 // gracePeriod is encoded in int64 using capnproto
589- return tunnelServer .UnregisterTunnel (ctx , gracePeriod .Nanoseconds ())
590+ return rpcClient .UnregisterTunnel (ctx , config . GracePeriod .Nanoseconds ())
590591}
591592
592593func LogServerInfo (
@@ -909,3 +910,18 @@ func findCfRayHeader(h1 *http.Request) string {
909910func isLBProbeRequest (req * http.Request ) bool {
910911 return strings .HasPrefix (req .UserAgent (), lbProbeUserAgentPrefix )
911912}
913+
914+ func newTunnelRPCClient (ctx context.Context , muxer * h2mux.Muxer , config * TunnelConfig , rpcName rpcName ) (tunnelpogs.TunnelServer_PogsClient , error ) {
915+ openStreamCtx , openStreamCancel := context .WithTimeout (ctx , openStreamTimeout )
916+ defer openStreamCancel ()
917+ stream , err := muxer .OpenRPCStream (openStreamCtx )
918+ if err != nil {
919+ return tunnelpogs.TunnelServer_PogsClient {}, err
920+ }
921+ rpcClient , err := connection .NewTunnelRPCClient (ctx , stream , config .TransportLogger )
922+ if err != nil {
923+ // RPC stream open error
924+ return tunnelpogs.TunnelServer_PogsClient {}, newRPCError (err , config .Metrics .rpcFail , rpcName )
925+ }
926+ return rpcClient , nil
927+ }
0 commit comments