Skip to content

Commit 4f30b8a

Browse files
committed
Fix URL joining
1 parent 09ba9ef commit 4f30b8a

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

client.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
146146
return nil, nil, nil, nil, err
147147
}
148148

149-
enrollURL, err := url.JoinPath(c.dnServer, message.EnrollEndpoint)
149+
enrollURL, err := urlPath(c.dnServer, message.EnrollEndpoint)
150150
if err != nil {
151151
return nil, nil, nil, nil, err
152152
}
@@ -480,7 +480,7 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
480480
}
481481
pbb := bytes.NewBuffer(postBody)
482482

483-
endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1)
483+
endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1)
484484
if err != nil {
485485
return nil, err
486486
}
@@ -535,7 +535,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
535535
return nil, err
536536
}
537537

538-
endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1)
538+
endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1)
539539
if err != nil {
540540
return nil, err
541541
}
@@ -571,7 +571,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
571571
}
572572

573573
func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload map[string]any) (*T, error) {
574-
dest, err := url.JoinPath(c.dnServer, endpoint)
574+
dest, err := urlPath(c.dnServer, endpoint)
575575
if err != nil {
576576
return nil, err
577577
}
@@ -701,3 +701,18 @@ func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*messag
701701
pollURL := fmt.Sprintf("%s?pollToken=%s", message.EndpointAuthPoll, url.QueryEscape(pollCode))
702702
return callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil)
703703
}
704+
705+
func urlPath(base, path string) (string, error) {
706+
baseURL, err := url.Parse(base)
707+
if err != nil {
708+
return "", fmt.Errorf("invalid base: %s", err)
709+
}
710+
711+
pathURL, err := url.Parse(path)
712+
if err != nil {
713+
return "", fmt.Errorf("invalid path: %s", err)
714+
}
715+
716+
finalURL := baseURL.ResolveReference(pathURL)
717+
return finalURL.String(), nil
718+
}

client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,16 +1119,16 @@ func TestDoOidcPoll(t *testing.T) {
11191119
ts.ExpectAPIRequest(http.StatusOK, func(r any) []byte {
11201120
return jsonMarshal(message.APIResponse[message.EndpointAuthPollData]{Data: message.EndpointAuthPollData{
11211121
Status: message.EndpointAuthStarted,
1122-
EnrollmentCode: "",
1122+
EnrollmentCode: expectedCode,
11231123
}})
11241124
})
11251125

11261126
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
11271127
defer cancel()
11281128
resp, err := client.EndpointAuthPoll(ctx, expectedCode)
11291129
require.NoError(t, err)
1130-
assert.Equal(t, resp.Status, message.EndpointAuthStarted)
1131-
assert.Equal(t, resp.EnrollmentCode, "")
1130+
assert.Equal(t, message.EndpointAuthStarted, resp.Status)
1131+
assert.Equal(t, expectedCode, resp.EnrollmentCode)
11321132
assert.Empty(t, ts.Errors())
11331133
assert.Equal(t, 0, ts.RequestsRemaining())
11341134

0 commit comments

Comments
 (0)