@@ -146,7 +146,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
146146 return nil , nil , nil , nil , err
147147 }
148148
149- enrollURL , err := url . JoinPath (c .dnServer , message .EnrollEndpoint )
149+ enrollURL , err := urlPath (c .dnServer , message .EnrollEndpoint )
150150 if err != nil {
151151 return nil , nil , nil , nil , err
152152 }
@@ -172,7 +172,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
172172 }
173173
174174 // Decode the response
175- r := message.EnrollResponse {}
175+ r := message.APIResponse [message. EnrollResponseData ] {}
176176 b , err := io .ReadAll (resp .Body )
177177 if err != nil {
178178 return nil , nil , nil , nil , & APIError {e : fmt .Errorf ("error reading response body: %s" , err ), ReqID : reqID }
@@ -480,7 +480,7 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
480480 }
481481 pbb := bytes .NewBuffer (postBody )
482482
483- endpointV1URL , err := url . JoinPath (c .dnServer , message .EndpointV1 )
483+ endpointV1URL , err := urlPath (c .dnServer , message .EndpointV1 )
484484 if err != nil {
485485 return nil , err
486486 }
@@ -535,7 +535,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
535535 return nil , err
536536 }
537537
538- endpointV1URL , err := url . JoinPath (c .dnServer , message .EndpointV1 )
538+ endpointV1URL , err := urlPath (c .dnServer , message .EndpointV1 )
539539 if err != nil {
540540 return nil , err
541541 }
@@ -570,6 +570,57 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
570570 }
571571}
572572
573+ func callAPI [T any ](ctx context.Context , c * Client , method string , endpoint string , payload map [string ]any ) (* T , error ) {
574+ dest , err := urlPath (c .dnServer , endpoint )
575+ if err != nil {
576+ return nil , err
577+ }
578+
579+ var br io.Reader
580+ if payload != nil {
581+ b , err := json .Marshal (payload )
582+ if err != nil {
583+ return nil , fmt .Errorf ("failed to marshal payload: %s" , err )
584+ }
585+ br = bytes .NewReader (b )
586+ }
587+
588+ req , err := http .NewRequestWithContext (ctx , method , dest , br )
589+ if err != nil {
590+ return nil , err
591+ }
592+
593+ resp , err := c .client .Do (req )
594+ if err != nil {
595+ return nil , err
596+ }
597+ defer resp .Body .Close ()
598+
599+ reqID := resp .Header .Get ("X-Request-ID" )
600+
601+ r := message.APIResponse [T ]{}
602+ b , err := io .ReadAll (resp .Body )
603+ if err != nil {
604+ return nil , & APIError {e : fmt .Errorf ("error reading response body: %s" , err ), ReqID : reqID }
605+ }
606+
607+ if err := json .Unmarshal (b , & r ); err != nil {
608+ return nil , & APIError {e : fmt .Errorf ("error decoding JSON response: %s\n body: %s" , err , b ), ReqID : reqID }
609+ }
610+
611+ // Check for any errors returned by the API
612+ if err := r .Errors .ToError (); err != nil {
613+ return nil , & APIError {e : err , ReqID : reqID }
614+ }
615+
616+ // If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error
617+ if resp .StatusCode >= 400 {
618+ return nil , & APIError {e : fmt .Errorf ("received HTTP %d from API without error details\n body: %s" , resp .StatusCode , b ), ReqID : reqID }
619+ }
620+
621+ return & r .Data , nil
622+ }
623+
573624// StreamController is used for interacting with streaming requests to the API.
574625//
575626// When a streaming request is started in a background goroutine, a StreamController is returned to the caller to allow
@@ -643,89 +694,25 @@ func nonce() []byte {
643694}
644695
645696func (c * Client ) EndpointPreAuth (ctx context.Context ) (* message.PreAuthData , error ) {
646- dest , err := url .JoinPath (c .dnServer , message .PreAuthEndpoint )
647- if err != nil {
648- return nil , err
649- }
650-
651- req , err := http .NewRequestWithContext (ctx , "POST" , dest , nil )
652- if err != nil {
653- return nil , err
654- }
655-
656- resp , err := c .client .Do (req )
657- if err != nil {
658- return nil , err
659- }
660- defer resp .Body .Close ()
661-
662- reqID := resp .Header .Get ("X-Request-ID" )
663- respBody , err := io .ReadAll (resp .Body )
664- if err != nil {
665- return nil , & APIError {e : fmt .Errorf ("failed to read the response body: %s" , err ), ReqID : reqID }
666- }
667-
668- switch resp .StatusCode {
669- case http .StatusOK :
670- r := message.PreAuthResponse {}
671- if err = json .Unmarshal (respBody , & r ); err != nil {
672- return nil , & APIError {e : fmt .Errorf ("error decoding JSON response: %s\n body: %s" , err , respBody ), ReqID : reqID }
673- }
674-
675- if r .Data .PollToken == "" || r .Data .LoginURL == "" {
676- return nil , & APIError {e : fmt .Errorf ("missing pollToken or loginURL" ), ReqID : reqID }
677- }
678-
679- return & r .Data , nil
680- default :
681- var errors struct {
682- Errors message.APIErrors
683- }
684- if err := json .Unmarshal (respBody , & errors ); err != nil {
685- return nil , fmt .Errorf ("bad status code '%d', body: %s" , resp .StatusCode , respBody )
686- }
687- return nil , & APIError {e : errors .Errors .ToError (), ReqID : reqID }
688- }
697+ return callAPI [message.PreAuthData ](ctx , c , "POST" , message .PreAuthEndpoint , nil )
689698}
690699
691700func (c * Client ) EndpointAuthPoll (ctx context.Context , pollCode string ) (* message.EndpointAuthPollData , error ) {
692- pollURL , err := url .JoinPath (c .dnServer , message .EndpointAuthPoll )
693- if err != nil {
694- return nil , err
695- }
696- pollURL = fmt .Sprintf ("%s?pollToken=%s" , pollURL , url .QueryEscape (pollCode ))
697-
698- req , err := http .NewRequestWithContext (ctx , "GET" , pollURL , nil )
699- if err != nil {
700- return nil , err
701- }
701+ pollURL := fmt .Sprintf ("%s?pollToken=%s" , message .EndpointAuthPoll , url .QueryEscape (pollCode ))
702+ return callAPI [message.EndpointAuthPollData ](ctx , c , "GET" , pollURL , nil )
703+ }
702704
703- resp , err := c .client .Do (req )
705+ func urlPath (base , path string ) (string , error ) {
706+ baseURL , err := url .Parse (base )
704707 if err != nil {
705- return nil , err
708+ return "" , fmt . Errorf ( "invalid base: %s" , err )
706709 }
707- defer resp .Body .Close ()
708710
709- reqID := resp .Header .Get ("X-Request-ID" )
710- respBody , err := io .ReadAll (resp .Body )
711+ pathURL , err := url .Parse (path )
711712 if err != nil {
712- return nil , & APIError { e : fmt .Errorf ("failed to read the response body : %s" , err ), ReqID : reqID }
713+ return "" , fmt .Errorf ("invalid path : %s" , err )
713714 }
714715
715- switch resp .StatusCode {
716- case http .StatusOK :
717- r := message.EndpointAuthPollResponse {}
718- if err = json .Unmarshal (respBody , & r ); err != nil {
719- return nil , & APIError {e : fmt .Errorf ("error decoding JSON response: %s\n body: %s" , err , respBody ), ReqID : reqID }
720- }
721- return & r .Data , nil
722- default :
723- var errors struct {
724- Errors message.APIErrors
725- }
726- if err := json .Unmarshal (respBody , & errors ); err != nil {
727- return nil , fmt .Errorf ("bad status code '%d', body: %s" , resp .StatusCode , respBody )
728- }
729- return nil , & APIError {e : errors .Errors .ToError (), ReqID : reqID }
730- }
716+ finalURL := baseURL .ResolveReference (pathURL )
717+ return finalURL .String (), nil
731718}
0 commit comments