Skip to content

Commit 364a711

Browse files
committed
Simplify types in message package with generics
1 parent 70e0ea7 commit 364a711

File tree

3 files changed

+100
-119
lines changed

3 files changed

+100
-119
lines changed

client.go

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
172172
}
173173

174174
// Decode the response
175-
r := message.EnrollResponse{}
175+
r := message.APIResponse[message.EnrollResponseData]{}
176176
b, err := io.ReadAll(resp.Body)
177177
if err != nil {
178178
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
@@ -660,32 +660,23 @@ func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, err
660660
defer resp.Body.Close()
661661

662662
reqID := resp.Header.Get("X-Request-ID")
663-
respBody, err := io.ReadAll(resp.Body)
663+
664+
r := message.APIResponse[message.PreAuthData]{}
665+
b, err := io.ReadAll(resp.Body)
664666
if err != nil {
665-
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
667+
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
666668
}
667669

668-
switch resp.StatusCode {
669-
case http.StatusOK:
670-
r := message.PreAuthResponse{}
671-
if err = json.Unmarshal(respBody, &r); err != nil {
672-
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
673-
}
674-
675-
if r.Data.PollToken == "" || r.Data.LoginURL == "" {
676-
return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID}
677-
}
670+
if err := json.Unmarshal(b, &r); err != nil {
671+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
672+
}
678673

679-
return &r.Data, nil
680-
default:
681-
var errors struct {
682-
Errors message.APIErrors
683-
}
684-
if err := json.Unmarshal(respBody, &errors); err != nil {
685-
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
686-
}
687-
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
674+
// Check for any errors returned by the API
675+
if err := r.Errors.ToError(); err != nil {
676+
return nil, &APIError{e: err, ReqID: reqID}
688677
}
678+
679+
return &r.Data, nil
689680
}
690681

691682
func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
@@ -707,25 +698,21 @@ func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*messag
707698
defer resp.Body.Close()
708699

709700
reqID := resp.Header.Get("X-Request-ID")
710-
respBody, err := io.ReadAll(resp.Body)
701+
702+
r := message.APIResponse[message.EndpointAuthPollData]{}
703+
b, err := io.ReadAll(resp.Body)
711704
if err != nil {
712-
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
705+
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
713706
}
714707

715-
switch resp.StatusCode {
716-
case http.StatusOK:
717-
r := message.EndpointAuthPollResponse{}
718-
if err = json.Unmarshal(respBody, &r); err != nil {
719-
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
720-
}
721-
return &r.Data, nil
722-
default:
723-
var errors struct {
724-
Errors message.APIErrors
725-
}
726-
if err := json.Unmarshal(respBody, &errors); err != nil {
727-
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
728-
}
729-
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
708+
if err := json.Unmarshal(b, &r); err != nil {
709+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
730710
}
711+
712+
// Check for any errors returned by the API
713+
if err := r.Errors.ToError(); err != nil {
714+
return nil, &APIError{e: err, ReqID: reqID}
715+
}
716+
717+
return &r.Data, nil
731718
}

client_test.go

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ func TestEnroll(t *testing.T) {
6464
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
6565
})
6666
if err != nil {
67-
return jsonMarshal(message.EnrollResponse{
67+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
6868
Errors: message.APIErrors{{
6969
Code: "ERR_FAILED_TO_MARSHAL_YAML",
7070
Message: "failed to marshal test response config",
7171
}},
7272
})
7373
}
7474

75-
return jsonMarshal(message.EnrollResponse{
75+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
7676
Data: message.EnrollResponseData{
7777
HostID: hostID,
7878
Counter: counter,
@@ -148,7 +148,7 @@ func TestEnroll(t *testing.T) {
148148
// Test error handling
149149
errorMsg := "invalid enrollment code"
150150
ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte {
151-
return jsonMarshal(message.EnrollResponse{
151+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
152152
Errors: message.APIErrors{{
153153
Code: "ERR_INVALID_ENROLLMENT_CODE",
154154
Message: errorMsg,
@@ -193,15 +193,15 @@ func TestDoUpdate(t *testing.T) {
193193
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
194194
})
195195
if err != nil {
196-
return jsonMarshal(message.EnrollResponse{
196+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
197197
Errors: message.APIErrors{{
198198
Code: "ERR_FAILED_TO_MARSHAL_YAML",
199199
Message: "failed to marshal test response config",
200200
}},
201201
})
202202
}
203203

204-
return jsonMarshal(message.EnrollResponse{
204+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
205205
Data: message.EnrollResponseData{
206206
HostID: "foobar",
207207
Counter: 1,
@@ -462,15 +462,15 @@ func TestDoUpdate_P256(t *testing.T) {
462462
"test": m{"code": req.Code, "p256Pubkey": req.NebulaPubkeyP256},
463463
})
464464
if err != nil {
465-
return jsonMarshal(message.EnrollResponse{
465+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
466466
Errors: message.APIErrors{{
467467
Code: "ERR_FAILED_TO_MARSHAL_YAML",
468468
Message: "failed to marshal test response config",
469469
}},
470470
})
471471
}
472472

473-
return jsonMarshal(message.EnrollResponse{
473+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
474474
Data: message.EnrollResponseData{
475475
HostID: "foobar",
476476
Counter: 1,
@@ -556,7 +556,7 @@ func TestDoUpdate_P256(t *testing.T) {
556556

557557
sig, err := nk.HostP256PrivateKey.Sign(rawRes)
558558
if err != nil {
559-
return jsonMarshal(message.EnrollResponse{
559+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
560560
Errors: message.APIErrors{{
561561
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
562562
Message: "failed to sign message",
@@ -600,7 +600,7 @@ func TestDoUpdate_P256(t *testing.T) {
600600
hashed := sha256.Sum256(rawRes)
601601
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
602602
if err != nil {
603-
return jsonMarshal(message.EnrollResponse{
603+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
604604
Errors: message.APIErrors{{
605605
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
606606
Message: "failed to sign message",
@@ -654,7 +654,7 @@ func TestDoUpdate_P256(t *testing.T) {
654654
hashed := sha256.Sum256(rawRes)
655655
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
656656
if err != nil {
657-
return jsonMarshal(message.EnrollResponse{
657+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
658658
Errors: message.APIErrors{{
659659
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
660660
Message: "failed to sign message",
@@ -702,15 +702,15 @@ func TestCommandResponse(t *testing.T) {
702702
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
703703
})
704704
if err != nil {
705-
return jsonMarshal(message.EnrollResponse{
705+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
706706
Errors: message.APIErrors{{
707707
Code: "ERR_FAILED_TO_MARSHAL_YAML",
708708
Message: "failed to marshal test response config",
709709
}},
710710
})
711711
}
712712

713-
return jsonMarshal(message.EnrollResponse{
713+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
714714
Data: message.EnrollResponseData{
715715
HostID: "foobar",
716716
Counter: 1,
@@ -773,7 +773,7 @@ func TestCommandResponse(t *testing.T) {
773773
// Test error handling
774774
errorMsg := "sample error"
775775
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
776-
return jsonMarshal(message.EnrollResponse{
776+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
777777
Errors: message.APIErrors{{
778778
Code: "ERR_INVALID_VALUE",
779779
Message: errorMsg,
@@ -807,15 +807,15 @@ func TestStreamCommandResponse(t *testing.T) {
807807
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
808808
})
809809
if err != nil {
810-
return jsonMarshal(message.EnrollResponse{
810+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
811811
Errors: message.APIErrors{{
812812
Code: "ERR_FAILED_TO_MARSHAL_YAML",
813813
Message: "failed to marshal test response config",
814814
}},
815815
})
816816
}
817817

818-
return jsonMarshal(message.EnrollResponse{
818+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
819819
Data: message.EnrollResponseData{
820820
HostID: "foobar",
821821
Counter: 1,
@@ -884,7 +884,7 @@ func TestStreamCommandResponse(t *testing.T) {
884884
// Test error handling
885885
errorMsg := "sample error"
886886
ts.ExpectStreamingRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
887-
return jsonMarshal(message.EnrollResponse{
887+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
888888
Errors: message.APIErrors{{
889889
Code: "ERR_INVALID_VALUE",
890890
Message: errorMsg,
@@ -933,15 +933,15 @@ func TestReauthenticate(t *testing.T) {
933933
"test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519},
934934
})
935935
if err != nil {
936-
return jsonMarshal(message.EnrollResponse{
936+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
937937
Errors: message.APIErrors{{
938938
Code: "ERR_FAILED_TO_MARSHAL_YAML",
939939
Message: "failed to marshal test response config",
940940
}},
941941
})
942942
}
943943

944-
return jsonMarshal(message.EnrollResponse{
944+
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
945945
Data: message.EnrollResponseData{
946946
HostID: "foobar",
947947
Counter: 1,
@@ -1078,7 +1078,7 @@ func TestGetOidcPollCode(t *testing.T) {
10781078
t.Cleanup(func() { ts.Close() })
10791079
const expectedCode = "123456"
10801080
ts.ExpectAPIRequest(http.StatusOK, func(req any) []byte {
1081-
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
1081+
return jsonMarshal(message.APIResponse[message.PreAuthData]{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
10821082
})
10831083

10841084
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
@@ -1092,8 +1092,13 @@ func TestGetOidcPollCode(t *testing.T) {
10921092
assert.Equal(t, 0, ts.RequestsRemaining())
10931093

10941094
//unhappy path
1095-
ts.ExpectAPIRequest(http.StatusBadGateway, func(req any) []byte {
1096-
return jsonMarshal(message.PreAuthResponse{Data: message.PreAuthData{PollToken: expectedCode, LoginURL: "https://example.com"}})
1095+
ts.ExpectAPIRequest(http.StatusInternalServerError, func(req any) []byte {
1096+
return jsonMarshal(message.APIResponse[message.PreAuthData]{
1097+
Errors: message.APIErrors{{
1098+
Code: "ERR_INTERNAL_SERVER_ERROR",
1099+
Message: "internal server error",
1100+
}},
1101+
})
10971102
})
10981103
resp, err = client.EndpointPreAuth(ctx)
10991104
require.Error(t, err)
@@ -1112,7 +1117,7 @@ func TestDoOidcPoll(t *testing.T) {
11121117
t.Cleanup(func() { ts.Close() })
11131118
const expectedCode = "123456"
11141119
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
1115-
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
1120+
return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{
11161121
Status: message.EndpointAuthStarted,
11171122
EnrollmentCode: "",
11181123
}})
@@ -1139,7 +1144,7 @@ func TestDoOidcPoll(t *testing.T) {
11391144

11401145
//complete path
11411146
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
1142-
return jsonMarshal(message.EndpointAuthPollResponse{Data: message.EndpointAuthPollData{
1147+
return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{
11431148
Status: message.EndpointAuthCompleted,
11441149
EnrollmentCode: "deadbeef",
11451150
}})

0 commit comments

Comments
 (0)