Skip to content

Commit 75e5573

Browse files
authored
Improve common package test coverage (#1344)
* Make client error messages friendlier * Allow passing `io.Reader` as request body * Increase test coverage for `common/http.go` * Moved `TrimLeadingWhitespace` from `internal` to `commands`
1 parent a89b9d7 commit 75e5573

29 files changed

+468
-229
lines changed

aws/resource_service_principal_role_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package aws
22

33
import (
4-
"github.com/databrickslabs/terraform-provider-databricks/common"
54
"testing"
65

6+
"github.com/databrickslabs/terraform-provider-databricks/common"
7+
78
"github.com/databrickslabs/terraform-provider-databricks/scim"
89

910
"github.com/databrickslabs/terraform-provider-databricks/qa"

catalog/resource_external_location.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ func NewExternalLocationsAPI(ctx context.Context, m interface{}) ExternalLocatio
1818
}
1919

2020
type ExternalLocationInfo struct {
21-
Name string `json:"name" tf:"force_new"`
22-
URL string `json:"url"`
23-
CredentialName string `json:"credential_name"`
24-
Comment string `json:"comment,omitempty"`
25-
SkipValidation bool `json:"skip_validation,omitempty"`
26-
Owner string `json:"owner,omitempty" tf:"computed"`
27-
MetastoreID string `json:"metastore_id,omitempty" tf:"computed"`
21+
Name string `json:"name" tf:"force_new"`
22+
URL string `json:"url"`
23+
CredentialName string `json:"credential_name"`
24+
Comment string `json:"comment,omitempty"`
25+
SkipValidation bool `json:"skip_validation,omitempty"`
26+
Owner string `json:"owner,omitempty" tf:"computed"`
27+
MetastoreID string `json:"metastore_id,omitempty" tf:"computed"`
2828
}
2929

3030
func (a ExternalLocationsAPI) create(el *ExternalLocationInfo) error {

commands/commands.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88

99
"github.com/databrickslabs/terraform-provider-databricks/clusters"
1010
"github.com/databrickslabs/terraform-provider-databricks/common"
11-
"github.com/databrickslabs/terraform-provider-databricks/internal"
1211

1312
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
1413
)
@@ -52,7 +51,7 @@ func (a CommandsAPI) Execute(clusterID, language, commandStr string) common.Comm
5251
Summary: fmt.Sprintf("Cluster %s has to be running or resizing, but is %s", clusterID, cluster.State),
5352
}
5453
}
55-
commandStr = internal.TrimLeadingWhitespace(commandStr)
54+
commandStr = TrimLeadingWhitespace(commandStr)
5655
log.Printf("[INFO] Executing %s command on %s:\n%s", language, clusterID, commandStr)
5756
context, err := a.createContext(language, clusterID)
5857
if err != nil {

internal/utils.go renamed to commands/leading_whitespace.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
package internal
1+
package commands
22

33
import (
44
"strings"
55
)
66

7-
// TrimLeadingWhitespace removes leading whitespace
7+
// TrimLeadingWhitespace removes leading whitespace, so that Python code blocks
8+
// that are embedded into Go code still could be interpreted properly.
89
func TrimLeadingWhitespace(commandStr string) (newCommand string) {
910
lines := strings.Split(strings.ReplaceAll(commandStr, "\t", " "), "\n")
1011
leadingWhitespace := 1<<31 - 1

internal/utils_test.go renamed to commands/leading_whitespace_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package internal
1+
package commands
22

33
import (
44
"testing"

common/client.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ func (c *DatabricksClient) niceAuthError(message string) error {
320320
}
321321
info = ". " + strings.Join(infos, ". ")
322322
}
323+
info = strings.TrimSuffix(info, ".")
324+
message = strings.TrimSuffix(message, ".")
323325
docUrl := "https://registry.terraform.io/providers/databrickslabs/databricks/latest/docs#authentication"
324326
return fmt.Errorf("%s%s. Please check %s for details", message, info, docUrl)
325327
}

common/http.go

Lines changed: 82 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9-
"io/ioutil"
9+
"io"
1010
"log"
1111
"net/http"
1212
"net/url"
@@ -179,7 +179,7 @@ func (c *DatabricksClient) commonErrorClarity(resp *http.Response) *APIError {
179179
}
180180

181181
func (c *DatabricksClient) parseError(resp *http.Response) APIError {
182-
body, err := ioutil.ReadAll(resp.Body)
182+
body, err := io.ReadAll(resp.Body)
183183
if err != nil {
184184
return APIError{
185185
Message: err.Error(),
@@ -345,16 +345,21 @@ func (c *DatabricksClient) completeUrl(r *http.Request) error {
345345
return nil
346346
}
347347

348+
// scimPathVisitorFactory is a separate method for the sake of unit tests
349+
func (c *DatabricksClient) scimVisitor(r *http.Request) error {
350+
r.Header.Set("Content-Type", "application/scim+json; charset=utf-8")
351+
if c.isAccountsClient() && c.AccountID != "" {
352+
// until `/preview` is there for workspace scim,
353+
// `/api/2.0` is added by completeUrl visitor
354+
r.URL.Path = strings.ReplaceAll(r.URL.Path, "/api/2.0/preview",
355+
fmt.Sprintf("/api/2.0/accounts/%s", c.AccountID))
356+
}
357+
return nil
358+
}
359+
348360
// Scim sets SCIM headers
349361
func (c *DatabricksClient) Scim(ctx context.Context, method, path string, request interface{}, response interface{}) error {
350-
body, err := c.authenticatedQuery(ctx, method, path, request, c.completeUrl, func(r *http.Request) error {
351-
r.Header.Set("Content-Type", "application/scim+json; charset=utf-8")
352-
if c.isAccountsClient() && c.AccountID != "" {
353-
// until `/preview` is there for workspace scim
354-
r.URL.Path = strings.ReplaceAll(path, "/preview", fmt.Sprintf("/api/2.0/accounts/%s", c.AccountID))
355-
}
356-
return nil
357-
})
362+
body, err := c.authenticatedQuery(ctx, method, path, request, c.completeUrl, c.scimVisitor)
358363
if err != nil {
359364
return err
360365
}
@@ -402,7 +407,9 @@ func (c *DatabricksClient) redactedDump(body []byte) (res string) {
402407
if len(body) == 0 {
403408
return
404409
}
405-
410+
if body[0] != '{' {
411+
return fmt.Sprintf("[non-JSON document of %d bytes]", len(body))
412+
}
406413
var requestMap map[string]interface{}
407414
err := json.Unmarshal(body, &requestMap)
408415
if err != nil {
@@ -465,21 +472,21 @@ func (c *DatabricksClient) genericQuery(ctx context.Context, method, requestURL
465472
return nil, fmt.Errorf("DatabricksClient is not configured")
466473
}
467474
if err = c.rateLimiter.Wait(ctx); err != nil {
468-
return nil, err
475+
return nil, fmt.Errorf("rate limited: %w", err)
469476
}
470-
requestBody, err := makeRequestBody(method, &requestURL, data, true)
477+
requestBody, err := makeRequestBody(method, &requestURL, data)
471478
if err != nil {
472-
return nil, err
479+
return nil, fmt.Errorf("request marshal: %w", err)
473480
}
474481
request, err := http.NewRequestWithContext(ctx, method, requestURL, bytes.NewBuffer(requestBody))
475482
if err != nil {
476-
return nil, err
483+
return nil, fmt.Errorf("new request: %w", err)
477484
}
478485
request.Header.Set("User-Agent", c.userAgent(ctx))
479486
for _, requestVisitor := range visitors {
480487
err = requestVisitor(request)
481488
if err != nil {
482-
return nil, err
489+
return nil, fmt.Errorf("failed visitor: %w", err)
483490
}
484491
}
485492
headers := c.createDebugHeaders(request.Header, c.Host)
@@ -488,78 +495,93 @@ func (c *DatabricksClient) genericQuery(ctx context.Context, method, requestURL
488495

489496
r, err := retryablehttp.FromRequest(request)
490497
if err != nil {
491-
return nil, err
498+
return nil, err // no error invariants possible because of `makeRequestBody`
492499
}
493500
resp, err := c.httpClient.Do(r)
494501
// retryablehttp library now returns only wrapped errors
495502
var ae APIError
496503
if errors.As(err, &ae) {
504+
// don't re-wrap, as upper layers may depend on handling common.APIError
497505
return nil, ae
498506
}
499507
if err != nil {
500-
return nil, err
508+
// i don't even know which errors in the real world would end up here.
509+
// `retryablehttp` package nicely wraps _everything_ to `url.Error`.
510+
return nil, fmt.Errorf("failed request: %w", err)
501511
}
502512
defer func() {
503513
if ferr := resp.Body.Close(); ferr != nil {
504-
err = ferr
514+
err = fmt.Errorf("failed to close: %w", ferr)
505515
}
506516
}()
507-
body, err = ioutil.ReadAll(resp.Body)
517+
body, err = io.ReadAll(resp.Body)
508518
if err != nil {
509-
return nil, err
519+
return nil, fmt.Errorf("response body: %w", err)
510520
}
511521
headers = c.createDebugHeaders(resp.Header, "")
512522
log.Printf("[DEBUG] %s %s %s <- %s %s", resp.Status, headers, c.redactedDump(body), method, strings.ReplaceAll(request.URL.Path, "\n", ""))
513523
return body, nil
514524
}
515525

516-
func makeRequestBody(method string, requestURL *string, data interface{}, marshalJSON bool) ([]byte, error) {
526+
func makeQueryString(data interface{}) (string, error) {
527+
inputVal := reflect.ValueOf(data)
528+
inputType := reflect.TypeOf(data)
529+
if inputType.Kind() == reflect.Map {
530+
s := []string{}
531+
keys := inputVal.MapKeys()
532+
// sort map keys by their string repr, so that tests can be deterministic
533+
sort.Slice(keys, func(i, j int) bool {
534+
return keys[i].String() < keys[j].String()
535+
})
536+
for _, k := range keys {
537+
v := inputVal.MapIndex(k)
538+
if v.IsZero() {
539+
continue
540+
}
541+
s = append(s, fmt.Sprintf("%s=%s",
542+
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1),
543+
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1)))
544+
}
545+
return "?" + strings.Join(s, "&"), nil
546+
}
547+
if inputType.Kind() == reflect.Struct {
548+
params, err := query.Values(data)
549+
if err != nil {
550+
return "", fmt.Errorf("cannot create query string: %w", err)
551+
}
552+
return "?" + params.Encode(), nil
553+
}
554+
return "", fmt.Errorf("unsupported query string data: %#v", data)
555+
}
556+
557+
func makeRequestBody(method string, requestURL *string, data interface{}) ([]byte, error) {
517558
var requestBody []byte
518559
if data == nil && (method == "DELETE" || method == "GET") {
519560
return requestBody, nil
520561
}
521562
if method == "GET" {
522-
inputVal := reflect.ValueOf(data)
523-
inputType := reflect.TypeOf(data)
524-
switch inputType.Kind() {
525-
case reflect.Map:
526-
s := []string{}
527-
keys := inputVal.MapKeys()
528-
// sort map keys by their string repr, so that tests can be deterministic
529-
sort.Slice(keys, func(i, j int) bool {
530-
return keys[i].String() < keys[j].String()
531-
})
532-
for _, k := range keys {
533-
v := inputVal.MapIndex(k)
534-
if v.IsZero() {
535-
continue
536-
}
537-
s = append(s, fmt.Sprintf("%s=%s",
538-
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1),
539-
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1)))
540-
}
541-
*requestURL += "?" + strings.Join(s, "&")
542-
case reflect.Struct:
543-
params, err := query.Values(data)
544-
if err != nil {
545-
return nil, err
546-
}
547-
*requestURL += "?" + params.Encode()
548-
default:
549-
return requestBody, fmt.Errorf("unsupported request data: %#v", data)
563+
qs, err := makeQueryString(data)
564+
if err != nil {
565+
return nil, err
550566
}
551-
} else {
552-
if marshalJSON {
553-
bodyBytes, err := json.MarshalIndent(data, "", " ")
554-
if err != nil {
555-
return nil, err
556-
}
557-
requestBody = bodyBytes
558-
} else {
559-
requestBody = []byte(data.(string))
567+
*requestURL += qs
568+
return requestBody, nil
569+
}
570+
if reader, ok := data.(io.Reader); ok {
571+
raw, err := io.ReadAll(reader)
572+
if err != nil {
573+
return nil, fmt.Errorf("failed to read from reader: %w", err)
560574
}
575+
return raw, nil
576+
}
577+
if str, ok := data.(string); ok {
578+
return []byte(str), nil
579+
}
580+
bodyBytes, err := json.MarshalIndent(data, "", " ")
581+
if err != nil {
582+
return nil, fmt.Errorf("request marshal failure: %w", err)
561583
}
562-
return requestBody, nil
584+
return bodyBytes, nil
563585
}
564586

565587
func onlyNBytes(j string, numBytes int) string {

0 commit comments

Comments
 (0)