Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 26 additions & 39 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
}

// Decode the response
r := message.EnrollResponse{}
r := message.APIResponse[message.EnrollResponseData]{}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
Expand Down Expand Up @@ -660,32 +660,23 @@ func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, err
defer resp.Body.Close()

reqID := resp.Header.Get("X-Request-ID")
respBody, err := io.ReadAll(resp.Body)

r := message.APIResponse[message.PreAuthData]{}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
}

switch resp.StatusCode {
case http.StatusOK:
r := message.PreAuthResponse{}
if err = json.Unmarshal(respBody, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
}

if r.Data.PollToken == "" || r.Data.LoginURL == "" {
return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID}
}
if err := json.Unmarshal(b, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
}

return &r.Data, nil
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
// Check for any errors returned by the API
if err := r.Errors.ToError(); err != nil {
return nil, &APIError{e: err, ReqID: reqID}
}

return &r.Data, nil
}

func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
Expand All @@ -707,25 +698,21 @@ func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*messag
defer resp.Body.Close()

reqID := resp.Header.Get("X-Request-ID")
respBody, err := io.ReadAll(resp.Body)

r := message.APIResponse[message.EndpointAuthPollData]{}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
}

switch resp.StatusCode {
case http.StatusOK:
r := message.EndpointAuthPollResponse{}
if err = json.Unmarshal(respBody, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
}
return &r.Data, nil
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
if err := json.Unmarshal(b, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
}

// Check for any errors returned by the API
if err := r.Errors.ToError(); err != nil {
return nil, &APIError{e: err, ReqID: reqID}
}

return &r.Data, nil
}
51 changes: 28 additions & 23 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ func TestEnroll(t *testing.T) {
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
})
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: hostID,
Counter: counter,
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestEnroll(t *testing.T) {
// Test error handling
errorMsg := "invalid enrollment code"
ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_INVALID_ENROLLMENT_CODE",
Message: errorMsg,
Expand Down Expand Up @@ -193,15 +193,15 @@ func TestDoUpdate(t *testing.T) {
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
})
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: "foobar",
Counter: 1,
Expand Down Expand Up @@ -462,15 +462,15 @@ func TestDoUpdate_P256(t *testing.T) {
"test": m{"code": req.Code, "p256Pubkey": req.NebulaPubkeyP256},
})
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: "foobar",
Counter: 1,
Expand Down Expand Up @@ -556,7 +556,7 @@ func TestDoUpdate_P256(t *testing.T) {

sig, err := nk.HostP256PrivateKey.Sign(rawRes)
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
Message: "failed to sign message",
Expand Down Expand Up @@ -600,7 +600,7 @@ func TestDoUpdate_P256(t *testing.T) {
hashed := sha256.Sum256(rawRes)
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
Message: "failed to sign message",
Expand Down Expand Up @@ -654,7 +654,7 @@ func TestDoUpdate_P256(t *testing.T) {
hashed := sha256.Sum256(rawRes)
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
Message: "failed to sign message",
Expand Down Expand Up @@ -702,15 +702,15 @@ func TestCommandResponse(t *testing.T) {
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
})
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: "foobar",
Counter: 1,
Expand Down Expand Up @@ -773,7 +773,7 @@ func TestCommandResponse(t *testing.T) {
// Test error handling
errorMsg := "sample error"
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_INVALID_VALUE",
Message: errorMsg,
Expand Down Expand Up @@ -807,15 +807,15 @@ func TestStreamCommandResponse(t *testing.T) {
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
})
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: "foobar",
Counter: 1,
Expand Down Expand Up @@ -884,7 +884,7 @@ func TestStreamCommandResponse(t *testing.T) {
// Test error handling
errorMsg := "sample error"
ts.ExpectStreamingRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_INVALID_VALUE",
Message: errorMsg,
Expand Down Expand Up @@ -933,15 +933,15 @@ func TestReauthenticate(t *testing.T) {
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
})
if err != nil {
return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
})
}

return jsonMarshal(message.EnrollResponse{
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Data: message.EnrollResponseData{
HostID: "foobar",
Counter: 1,
Expand Down Expand Up @@ -1078,7 +1078,7 @@ func TestGetOidcPollCode(t *testing.T) {
t.Cleanup(func() { ts.Close() })
const expectedCode = "123456"
ts.ExpectAPIRequest(http.StatusOK, func(req any) []byte {
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
return jsonMarshal(message.APIResponse[message.PreAuthData]{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
})

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
Expand All @@ -1092,8 +1092,13 @@ func TestGetOidcPollCode(t *testing.T) {
assert.Equal(t, 0, ts.RequestsRemaining())

//unhappy path
ts.ExpectAPIRequest(http.StatusBadGateway, func(req any) []byte {
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
ts.ExpectAPIRequest(http.StatusInternalServerError, func(req any) []byte {
return jsonMarshal(message.APIResponse[message.PreAuthData]{
Errors: message.APIErrors{{
Code: "ERR_INTERNAL_SERVER_ERROR",
Message: "internal server error",
}},
})
})
resp, err = client.EndpointPreAuth(ctx)
require.Error(t, err)
Expand All @@ -1112,7 +1117,7 @@ func TestDoOidcPoll(t *testing.T) {
t.Cleanup(func() { ts.Close() })
const expectedCode = "123456"
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{
Status: message.EndpointAuthStarted,
EnrollmentCode: "",
}})
Expand All @@ -1139,7 +1144,7 @@ func TestDoOidcPoll(t *testing.T) {

//complete path
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{
Status: message.EndpointAuthCompleted,
EnrollmentCode: "deadbeef",
}})
Expand Down
Loading
Loading