Skip to content

Commit bd115dd

Browse files
authored
Simplify response types in messages with generics (#31)
1 parent 70e0ea7 commit bd115dd

File tree

4 files changed

+146
-164
lines changed

4 files changed

+146
-164
lines changed

client.go

Lines changed: 66 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
146146
return nil, nil, nil, nil, err
147147
}
148148

149-
enrollURL, err := url.JoinPath(c.dnServer, message.EnrollEndpoint)
149+
enrollURL, err := urlPath(c.dnServer, message.EnrollEndpoint)
150150
if err != nil {
151151
return nil, nil, nil, nil, err
152152
}
@@ -172,7 +172,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
172172
}
173173

174174
// Decode the response
175-
r := message.EnrollResponse{}
175+
r := message.APIResponse[message.EnrollResponseData]{}
176176
b, err := io.ReadAll(resp.Body)
177177
if err != nil {
178178
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
@@ -480,7 +480,7 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
480480
}
481481
pbb := bytes.NewBuffer(postBody)
482482

483-
endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1)
483+
endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1)
484484
if err != nil {
485485
return nil, err
486486
}
@@ -535,7 +535,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
535535
return nil, err
536536
}
537537

538-
endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1)
538+
endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1)
539539
if err != nil {
540540
return nil, err
541541
}
@@ -570,6 +570,57 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
570570
}
571571
}
572572

573+
func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload map[string]any) (*T, error) {
574+
dest, err := urlPath(c.dnServer, endpoint)
575+
if err != nil {
576+
return nil, err
577+
}
578+
579+
var br io.Reader
580+
if payload != nil {
581+
b, err := json.Marshal(payload)
582+
if err != nil {
583+
return nil, fmt.Errorf("failed to marshal payload: %s", err)
584+
}
585+
br = bytes.NewReader(b)
586+
}
587+
588+
req, err := http.NewRequestWithContext(ctx, method, dest, br)
589+
if err != nil {
590+
return nil, err
591+
}
592+
593+
resp, err := c.client.Do(req)
594+
if err != nil {
595+
return nil, err
596+
}
597+
defer resp.Body.Close()
598+
599+
reqID := resp.Header.Get("X-Request-ID")
600+
601+
r := message.APIResponse[T]{}
602+
b, err := io.ReadAll(resp.Body)
603+
if err != nil {
604+
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
605+
}
606+
607+
if err := json.Unmarshal(b, &r); err != nil {
608+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
609+
}
610+
611+
// Check for any errors returned by the API
612+
if err := r.Errors.ToError(); err != nil {
613+
return nil, &APIError{e: err, ReqID: reqID}
614+
}
615+
616+
// If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error
617+
if resp.StatusCode >= 400 {
618+
return nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID}
619+
}
620+
621+
return &r.Data, nil
622+
}
623+
573624
// StreamController is used for interacting with streaming requests to the API.
574625
//
575626
// When a streaming request is started in a background goroutine, a StreamController is returned to the caller to allow
@@ -643,89 +694,25 @@ func nonce() []byte {
643694
}
644695

645696
func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) {
646-
dest, err := url.JoinPath(c.dnServer, message.PreAuthEndpoint)
647-
if err != nil {
648-
return nil, err
649-
}
650-
651-
req, err := http.NewRequestWithContext(ctx, "POST", dest, nil)
652-
if err != nil {
653-
return nil, err
654-
}
655-
656-
resp, err := c.client.Do(req)
657-
if err != nil {
658-
return nil, err
659-
}
660-
defer resp.Body.Close()
661-
662-
reqID := resp.Header.Get("X-Request-ID")
663-
respBody, err := io.ReadAll(resp.Body)
664-
if err != nil {
665-
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
666-
}
667-
668-
switch resp.StatusCode {
669-
case http.StatusOK:
670-
r := message.PreAuthResponse{}
671-
if err = json.Unmarshal(respBody, &r); err != nil {
672-
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
673-
}
674-
675-
if r.Data.PollToken == "" || r.Data.LoginURL == "" {
676-
return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID}
677-
}
678-
679-
return &r.Data, nil
680-
default:
681-
var errors struct {
682-
Errors message.APIErrors
683-
}
684-
if err := json.Unmarshal(respBody, &errors); err != nil {
685-
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
686-
}
687-
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
688-
}
697+
return callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil)
689698
}
690699

691700
func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
692-
pollURL, err := url.JoinPath(c.dnServer, message.EndpointAuthPoll)
693-
if err != nil {
694-
return nil, err
695-
}
696-
pollURL = fmt.Sprintf("%s?pollToken=%s", pollURL, url.QueryEscape(pollCode))
697-
698-
req, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
699-
if err != nil {
700-
return nil, err
701-
}
701+
pollURL := fmt.Sprintf("%s?pollToken=%s", message.EndpointAuthPoll, url.QueryEscape(pollCode))
702+
return callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil)
703+
}
702704

703-
resp, err := c.client.Do(req)
705+
func urlPath(base, path string) (string, error) {
706+
baseURL, err := url.Parse(base)
704707
if err != nil {
705-
return nil, err
708+
return "", fmt.Errorf("invalid base: %s", err)
706709
}
707-
defer resp.Body.Close()
708710

709-
reqID := resp.Header.Get("X-Request-ID")
710-
respBody, err := io.ReadAll(resp.Body)
711+
pathURL, err := url.Parse(path)
711712
if err != nil {
712-
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
713+
return "", fmt.Errorf("invalid path: %s", err)
713714
}
714715

715-
switch resp.StatusCode {
716-
case http.StatusOK:
717-
r := message.EndpointAuthPollResponse{}
718-
if err = json.Unmarshal(respBody, &r); err != nil {
719-
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
720-
}
721-
return &r.Data, nil
722-
default:
723-
var errors struct {
724-
Errors message.APIErrors
725-
}
726-
if err := json.Unmarshal(respBody, &errors); err != nil {
727-
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
728-
}
729-
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
730-
}
716+
finalURL := baseURL.ResolveReference(pathURL)
717+
return finalURL.String(), nil
731718
}

0 commit comments

Comments
 (0)