diff --git a/client.go b/client.go index f64d98b..b4cc441 100644 --- a/client.go +++ b/client.go @@ -146,7 +146,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str return nil, nil, nil, nil, err } - enrollURL, err := url.JoinPath(c.dnServer, message.EnrollEndpoint) + enrollURL, err := urlPath(c.dnServer, message.EnrollEndpoint) if err != nil { return nil, nil, nil, nil, err } @@ -172,7 +172,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str } // Decode the response - r := message.EnrollResponse{} + r := message.APIResponse[message.EnrollResponseData]{} b, err := io.ReadAll(resp.Body) if err != nil { return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID} @@ -480,7 +480,7 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu } pbb := bytes.NewBuffer(postBody) - endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1) + endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1) if err != nil { return nil, err } @@ -535,7 +535,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, return nil, err } - endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1) + endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1) if err != nil { return nil, err } @@ -570,6 +570,57 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, } } +func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload map[string]any) (*T, error) { + dest, err := urlPath(c.dnServer, endpoint) + if err != nil { + return nil, err + } + + var br io.Reader + if payload != nil { + b, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %s", err) + } + br = bytes.NewReader(b) + } + + req, err := http.NewRequestWithContext(ctx, method, dest, br) + if err != nil { + return nil, err + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + reqID := resp.Header.Get("X-Request-ID") + + r := message.APIResponse[T]{} + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID} + } + + if err := json.Unmarshal(b, &r); err != nil { + return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID} + } + + // Check for any errors returned by the API + if err := r.Errors.ToError(); err != nil { + return nil, &APIError{e: err, ReqID: reqID} + } + + // If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error + if resp.StatusCode >= 400 { + return nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID} + } + + return &r.Data, nil +} + // StreamController is used for interacting with streaming requests to the API. // // When a streaming request is started in a background goroutine, a StreamController is returned to the caller to allow @@ -643,89 +694,25 @@ func nonce() []byte { } func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) { - dest, err := url.JoinPath(c.dnServer, message.PreAuthEndpoint) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, "POST", dest, nil) - if err != nil { - return nil, err - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - reqID := resp.Header.Get("X-Request-ID") - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID} - } - - switch resp.StatusCode { - case http.StatusOK: - r := message.PreAuthResponse{} - if err = json.Unmarshal(respBody, &r); err != nil { - return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID} - } - - if r.Data.PollToken == "" || r.Data.LoginURL == "" { - return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID} - } - - return &r.Data, nil - default: - var errors struct { - Errors message.APIErrors - } - if err := json.Unmarshal(respBody, &errors); err != nil { - return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody) - } - return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID} - } + return callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil) } func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) { - pollURL, err := url.JoinPath(c.dnServer, message.EndpointAuthPoll) - if err != nil { - return nil, err - } - pollURL = fmt.Sprintf("%s?pollToken=%s", pollURL, url.QueryEscape(pollCode)) - - req, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil) - if err != nil { - return nil, err - } + pollURL := fmt.Sprintf("%s?pollToken=%s", message.EndpointAuthPoll, url.QueryEscape(pollCode)) + return callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil) +} - resp, err := c.client.Do(req) +func urlPath(base, path string) (string, error) { + baseURL, err := url.Parse(base) if err != nil { - return nil, err + return "", fmt.Errorf("invalid base: %s", err) } - defer resp.Body.Close() - reqID := resp.Header.Get("X-Request-ID") - respBody, err := io.ReadAll(resp.Body) + pathURL, err := url.Parse(path) if err != nil { - return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID} + return "", fmt.Errorf("invalid path: %s", err) } - switch resp.StatusCode { - case http.StatusOK: - r := message.EndpointAuthPollResponse{} - if err = json.Unmarshal(respBody, &r); err != nil { - return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID} - } - return &r.Data, nil - default: - var errors struct { - Errors message.APIErrors - } - if err := json.Unmarshal(respBody, &errors); err != nil { - return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody) - } - return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID} - } + finalURL := baseURL.ResolveReference(pathURL) + return finalURL.String(), nil } diff --git a/client_test.go b/client_test.go index 2ed70e4..311ea5a 100644 --- a/client_test.go +++ b/client_test.go @@ -64,7 +64,7 @@ func TestEnroll(t *testing.T) { "test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519}, }) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", @@ -72,7 +72,7 @@ func TestEnroll(t *testing.T) { }) } - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Data: message.EnrollResponseData{ HostID: hostID, Counter: counter, @@ -148,7 +148,7 @@ func TestEnroll(t *testing.T) { // Test error handling errorMsg := "invalid enrollment code" ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_INVALID_ENROLLMENT_CODE", Message: errorMsg, @@ -193,7 +193,7 @@ func TestDoUpdate(t *testing.T) { "test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519}, }) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", @@ -201,7 +201,7 @@ func TestDoUpdate(t *testing.T) { }) } - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Data: message.EnrollResponseData{ HostID: "foobar", Counter: 1, @@ -462,7 +462,7 @@ func TestDoUpdate_P256(t *testing.T) { "test": m{"code": req.Code, "p256Pubkey": req.NebulaPubkeyP256}, }) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", @@ -470,7 +470,7 @@ func TestDoUpdate_P256(t *testing.T) { }) } - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Data: message.EnrollResponseData{ HostID: "foobar", Counter: 1, @@ -556,7 +556,7 @@ func TestDoUpdate_P256(t *testing.T) { sig, err := nk.HostP256PrivateKey.Sign(rawRes) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_SIGN_MESSAGE", Message: "failed to sign message", @@ -600,7 +600,7 @@ func TestDoUpdate_P256(t *testing.T) { hashed := sha256.Sum256(rawRes) sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:]) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_SIGN_MESSAGE", Message: "failed to sign message", @@ -654,7 +654,7 @@ func TestDoUpdate_P256(t *testing.T) { hashed := sha256.Sum256(rawRes) sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:]) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_SIGN_MESSAGE", Message: "failed to sign message", @@ -702,7 +702,7 @@ func TestCommandResponse(t *testing.T) { "test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519}, }) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", @@ -710,7 +710,7 @@ func TestCommandResponse(t *testing.T) { }) } - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Data: message.EnrollResponseData{ HostID: "foobar", Counter: 1, @@ -773,7 +773,7 @@ func TestCommandResponse(t *testing.T) { // Test error handling errorMsg := "sample error" ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_INVALID_VALUE", Message: errorMsg, @@ -807,7 +807,7 @@ func TestStreamCommandResponse(t *testing.T) { "test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519}, }) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", @@ -815,7 +815,7 @@ func TestStreamCommandResponse(t *testing.T) { }) } - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Data: message.EnrollResponseData{ HostID: "foobar", Counter: 1, @@ -884,7 +884,7 @@ func TestStreamCommandResponse(t *testing.T) { // Test error handling errorMsg := "sample error" ts.ExpectStreamingRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_INVALID_VALUE", Message: errorMsg, @@ -933,7 +933,7 @@ func TestReauthenticate(t *testing.T) { "test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519}, }) if err != nil { - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Errors: message.APIErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", @@ -941,7 +941,7 @@ func TestReauthenticate(t *testing.T) { }) } - return jsonMarshal(message.EnrollResponse{ + return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ Data: message.EnrollResponseData{ HostID: "foobar", Counter: 1, @@ -1078,7 +1078,7 @@ func TestGetOidcPollCode(t *testing.T) { t.Cleanup(func() { ts.Close() }) const expectedCode = "123456" ts.ExpectAPIRequest(http.StatusOK, func(req any) []byte { - return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}}) + return jsonMarshal(message.APIResponse[message.PreAuthData]{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}}) }) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -1092,8 +1092,13 @@ func TestGetOidcPollCode(t *testing.T) { assert.Equal(t, 0, ts.RequestsRemaining()) //unhappy path - ts.ExpectAPIRequest(http.StatusBadGateway, func(req any) []byte { - return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}}) + ts.ExpectAPIRequest(http.StatusInternalServerError, func(req any) []byte { + return jsonMarshal(message.APIResponse[message.PreAuthData]{ + Errors: message.APIErrors{{ + Code: "ERR_INTERNAL_SERVER_ERROR", + Message: "internal server error", + }}, + }) }) resp, err = client.EndpointPreAuth(ctx) require.Error(t, err) @@ -1112,9 +1117,9 @@ func TestDoOidcPoll(t *testing.T) { t.Cleanup(func() { ts.Close() }) const expectedCode = "123456" ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte { - return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{ + return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{ Status: message.EndpointAuthStarted, - EnrollmentCode: "", + EnrollmentCode: expectedCode, }}) }) @@ -1122,8 +1127,8 @@ func TestDoOidcPoll(t *testing.T) { defer cancel() resp, err := client.EndpointAuthPoll(ctx, expectedCode) require.NoError(t, err) - assert.Equal(t, resp.Status, message.EndpointAuthStarted) - assert.Equal(t, resp.EnrollmentCode, "") + assert.Equal(t, message.EndpointAuthStarted, resp.Status) + assert.Equal(t, expectedCode, resp.EnrollmentCode) assert.Empty(t, ts.Errors()) assert.Equal(t, 0, ts.RequestsRemaining()) @@ -1139,7 +1144,7 @@ func TestDoOidcPoll(t *testing.T) { //complete path ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte { - return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{ + return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{ Status: message.EndpointAuthCompleted, EnrollmentCode: "deadbeef", }}) diff --git a/examples/simple/main.go b/examples/simple/main.go index a1889f2..8862767 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -28,12 +28,12 @@ func main() { // initial enrollment example config, pkey, creds, meta, err := c.Enroll(context.Background(), logger, *code) if err != nil { - logger.WithError(err).Error("Failed to enroll") + logger.WithError(err).Fatal("Failed to enroll") } config, err = dnapi.InsertConfigPrivateKey(config, pkey) if err != nil { - logger.WithError(err).Error("Failed to insert private key into config") + logger.WithError(err).Fatal("Failed to insert private key into config") } fmt.Printf( @@ -70,6 +70,7 @@ func main() { config, err = dnapi.InsertConfigPrivateKey(config, pkey) if err != nil { logger.WithError(err).Error("Failed to insert private key into config") + continue } creds = newCreds diff --git a/message/message.go b/message/message.go index cf953e1..b5d1d4e 100644 --- a/message/message.go +++ b/message/message.go @@ -130,6 +130,36 @@ type ReauthenticateResponse struct { LoginURL string `json:"loginURL"` } +// APIResponse is a standard format for the DN API. It does not apply to the DNClient API. +type APIResponse[T any] struct { + Data T `json:"data"` + Errors APIErrors `json:"errors"` +} + +// APIError represents a single error returned in an API error response. +type APIError struct { + Code string `json:"code"` + Message string `json:"message"` + Path string `json:"path"` // may or may not be present +} + +// APIErrors facilitates converting multiple API errors into a single Golang +// error to be returned to callers. +type APIErrors []APIError + +func (errs APIErrors) ToError() error { + if len(errs) == 0 { + return nil + } + + s := make([]string, len(errs)) + for i := range errs { + s[i] = errs[i].Message + } + + return errors.New(strings.Join(s, ", ")) +} + // EnrollEndpoint is the REST enrollment endpoint. const EnrollEndpoint = "/v2/enroll" @@ -143,14 +173,6 @@ type EnrollRequest struct { Timestamp time.Time `json:"timestamp"` } -// EnrollResponse represents a response from the enrollment endpoint. -type EnrollResponse struct { - // Only one of Data or Errors should be set in a response - Data EnrollResponseData `json:"data"` - - Errors APIErrors `json:"errors"` -} - // EnrollResponseData is included in the EnrollResponse. type EnrollResponseData struct { Config []byte `json:"config"` @@ -189,26 +211,27 @@ type HostEndpointOIDCMetadata struct { Email string `json:"email"` } -// APIError represents a single error returned in an API error response. -type APIError struct { - Code string `json:"code"` - Message string `json:"message"` - Path string `json:"path"` // may or may not be present +// PreAuthEndpoint is called when starting an OIDC auth flow. +const PreAuthEndpoint = "/v1/endpoint-auth/preauth" + +type PreAuthData struct { + PollToken string `json:"pollToken"` + LoginURL string `json:"loginURL"` } -type APIErrors []APIError +const EndpointAuthPoll = "/v1/endpoint-auth/poll" -func (errs APIErrors) ToError() error { - if len(errs) == 0 { - return nil - } +type EndpointAuthState string - s := make([]string, len(errs)) - for i := range errs { - s[i] = errs[i].Message - } +const ( + EndpointAuthWaiting EndpointAuthState = "WAITING" + EndpointAuthStarted EndpointAuthState = "STARTED" + EndpointAuthCompleted EndpointAuthState = "COMPLETED" +) - return errors.New(strings.Join(s, ", ")) +type EndpointAuthPollData struct { + Status EndpointAuthState `json:"state"` + EnrollmentCode string `json:"enrollmentCode"` } // NetworkCurve represents the network curve specified by the API. @@ -236,37 +259,3 @@ func (nc *NetworkCurve) UnmarshalJSON(b []byte) error { return nil } - -const PreAuthEndpoint = "/v1/endpoint-auth/preauth" - -type PreAuthResponse struct { - // Only one of Data or Errors should be set in a response - Data PreAuthData `json:"data"` - Errors APIErrors `json:"errors"` -} - -type PreAuthData struct { - PollToken string `json:"pollToken"` - LoginURL string `json:"loginURL"` -} - -const EndpointAuthPoll = "/v1/endpoint-auth/poll" - -type EndpointAuthState string - -const ( - EndpointAuthWaiting EndpointAuthState = "WAITING" - EndpointAuthStarted EndpointAuthState = "STARTED" - EndpointAuthCompleted EndpointAuthState = "COMPLETED" -) - -type EndpointAuthPollResponse struct { - // Only one of Data or Errors should be set in a response - Data EndpointAuthPollData `json:"data"` - Errors APIErrors `json:"errors"` -} - -type EndpointAuthPollData struct { - Status EndpointAuthState `json:"state"` - EnrollmentCode string `json:"enrollmentCode"` -}