Skip to content

Commit d4fc09a

Browse files
committed
stuff for the OIDC auth flow
1 parent caa5a20 commit d4fc09a

File tree

2 files changed

+114
-19
lines changed

2 files changed

+114
-19
lines changed

client.go

Lines changed: 100 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,28 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
467467
return sc, nil
468468
}
469469

470+
func (c *Client) handleBody(resp *http.Response) ([]byte, error) {
471+
respBody, err := io.ReadAll(resp.Body)
472+
if err != nil {
473+
return nil, fmt.Errorf("failed to read the response body: %s", err)
474+
}
475+
476+
switch resp.StatusCode {
477+
case http.StatusOK:
478+
return respBody, nil
479+
case http.StatusUnauthorized:
480+
return nil, ErrInvalidCredentials
481+
default:
482+
var errors struct {
483+
Errors message.APIErrors
484+
}
485+
if err := json.Unmarshal(respBody, &errors); err != nil {
486+
return nil, fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody)
487+
}
488+
return nil, errors.Errors.ToError()
489+
}
490+
}
491+
470492
// postDNClient wraps and signs the given dnclientRequestWrapper message, and makes the API call.
471493
// On success, it returns the response message body. On error, the error is returned.
472494
func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey keys.PrivateKey) ([]byte, error) {
@@ -489,25 +511,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
489511
}
490512
defer resp.Body.Close()
491513

492-
respBody, err := io.ReadAll(resp.Body)
493-
if err != nil {
494-
return nil, fmt.Errorf("failed to read the response body: %s", err)
495-
}
496-
497-
switch resp.StatusCode {
498-
case http.StatusOK:
499-
return respBody, nil
500-
case http.StatusUnauthorized:
501-
return nil, ErrInvalidCredentials
502-
default:
503-
var errors struct {
504-
Errors message.APIErrors
505-
}
506-
if err := json.Unmarshal(respBody, &errors); err != nil {
507-
return nil, fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody)
508-
}
509-
return nil, errors.Errors.ToError()
510-
}
514+
return c.handleBody(resp)
511515
}
512516

513517
// StreamController is used for interacting with streaming requests to the API.
@@ -581,3 +585,80 @@ func nonce() []byte {
581585
}
582586
return nonce
583587
}
588+
589+
func (c *Client) GetOidcPollCode(ctx context.Context, logger logrus.FieldLogger) (string, error) {
590+
logger.WithFields(logrus.Fields{"server": c.dnServer}).Debug("Making GetOidcPollCode request to API")
591+
592+
enrollURL, err := url.JoinPath(c.dnServer, message.PreAuthEndpoint)
593+
if err != nil {
594+
return "", err
595+
}
596+
597+
req, err := http.NewRequestWithContext(ctx, "POST", enrollURL, nil)
598+
if err != nil {
599+
return "", err
600+
}
601+
602+
resp, err := c.client.Do(req)
603+
if err != nil {
604+
return "", err
605+
}
606+
defer resp.Body.Close()
607+
608+
// Log the request ID returned from the server
609+
reqID := resp.Header.Get("X-Request-ID")
610+
l := logger.WithFields(logrus.Fields{"statusCode": resp.StatusCode, "reqID": reqID})
611+
b, err := c.handleBody(resp)
612+
if err != nil {
613+
l.Error(err) //todo I don't like erroring and also logging?
614+
return "", err
615+
}
616+
617+
// Decode the response
618+
r := message.PreAuthResponse{}
619+
if err = json.Unmarshal(b, &r); err != nil {
620+
return "", &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
621+
}
622+
623+
return r.PollToken, nil
624+
}
625+
626+
func (c *Client) DoOidcPoll(ctx context.Context, logger logrus.FieldLogger, pollCode string) (*message.EnduserAuthPollResponse, error) {
627+
logger.WithFields(logrus.Fields{"server": c.dnServer}).Debug("Making DoOidcPoll request to API")
628+
629+
enrollURL, err := url.JoinPath(c.dnServer, message.EnduserAuthPoll)
630+
if err != nil {
631+
return nil, err
632+
}
633+
634+
req, err := http.NewRequestWithContext(ctx, "GET", enrollURL, nil)
635+
if err != nil {
636+
return nil, err
637+
}
638+
q := req.URL.Query()
639+
q.Add("token", pollCode)
640+
req.URL.RawQuery = q.Encode()
641+
642+
resp, err := c.client.Do(req)
643+
if err != nil {
644+
return nil, err
645+
}
646+
defer resp.Body.Close()
647+
648+
// Log the request ID returned from the server
649+
reqID := resp.Header.Get("X-Request-ID")
650+
l := logger.WithFields(logrus.Fields{"statusCode": resp.StatusCode, "reqID": reqID})
651+
b, err := c.handleBody(resp)
652+
if err != nil {
653+
l.Error(err) //todo I don't like erroring and also logging?
654+
return nil, err
655+
}
656+
657+
// Decode the response
658+
r := message.EnduserAuthPollResponse{}
659+
if err = json.Unmarshal(b, &r); err != nil {
660+
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
661+
}
662+
663+
return &r, nil
664+
}

message/message.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,17 @@ func (nc *NetworkCurve) UnmarshalJSON(b []byte) error {
218218

219219
return nil
220220
}
221+
222+
const PreAuthEndpoint = "/v1/enduser-auth/preauth"
223+
224+
type PreAuthResponse struct {
225+
PollToken string `json:"pollToken"`
226+
}
227+
228+
const EnduserAuthPoll = "/v1/enduser-auth/poll"
229+
230+
type EnduserAuthPollResponse struct {
231+
Status string `json:"status"`
232+
LoginUrl string `json:"loginUrl"`
233+
EnrollmentCode string `json:"enrollmentCode"`
234+
}

0 commit comments

Comments
 (0)