Skip to content

Commit ce99db4

Browse files
authored
Clean up internal/testutil/helpers package. (#1057)
1 parent a47c69f commit ce99db4

21 files changed

+250
-310
lines changed

internal/testutil/helpers/helpers.go

Lines changed: 16 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -4,63 +4,22 @@
44
// not use this file except in compliance with the License. You may obtain
55
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
66

7-
package testhelpers // import "go.mongodb.org/mongo-driver/internal/testutil/helpers"
7+
package helpers
88

99
import (
1010
"fmt"
1111
"io/ioutil"
12-
"math"
1312
"path"
14-
"strings"
15-
"time"
16-
1713
"testing"
1814

19-
"io"
20-
21-
"reflect"
22-
2315
"github.com/stretchr/testify/require"
2416
"go.mongodb.org/mongo-driver/bson"
25-
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
2617
)
2718

28-
// Test helpers
29-
30-
// IsNil returns true if the object is nil
31-
func IsNil(object interface{}) bool {
32-
if object == nil {
33-
return true
34-
}
35-
36-
value := reflect.ValueOf(object)
37-
kind := value.Kind()
38-
39-
// checking to see if type is Chan, Func, Interface, Map, Ptr, or Slice
40-
if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() {
41-
return true
42-
}
43-
44-
return false
45-
}
46-
47-
// RequireNotNil throws an error if var is nil
48-
func RequireNotNil(t *testing.T, variable interface{}, msgFormat string, msgVars ...interface{}) {
49-
if IsNil(variable) {
50-
t.Fatalf(msgFormat, msgVars...)
51-
}
52-
}
53-
54-
// RequireNil throws an error if var is not nil
55-
func RequireNil(t *testing.T, variable interface{}, msgFormat string, msgVars ...interface{}) {
56-
t.Helper()
57-
if !IsNil(variable) {
58-
t.Fatalf(msgFormat, msgVars...)
59-
}
60-
}
61-
6219
// FindJSONFilesInDir finds the JSON files in a directory.
6320
func FindJSONFilesInDir(t *testing.T, dir string) []string {
21+
t.Helper()
22+
6423
files := make([]string, 0)
6524

6625
entries, err := ioutil.ReadDir(dir)
@@ -77,202 +36,26 @@ func FindJSONFilesInDir(t *testing.T, dir string) []string {
7736
return files
7837
}
7938

80-
// RequireNoErrorOnClose ensures there is not an error when calling Close.
81-
func RequireNoErrorOnClose(t *testing.T, c io.Closer) {
82-
require.NoError(t, c.Close())
83-
}
84-
85-
// VerifyConnStringOptions verifies the options on the connection string.
86-
func VerifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map[string]interface{}) {
87-
// Check that all options are present.
88-
for key, value := range options {
89-
90-
key = strings.ToLower(key)
91-
switch key {
92-
case "appname":
93-
require.Equal(t, value, cs.AppName)
94-
case "authsource":
95-
require.Equal(t, value, cs.AuthSource)
96-
case "authmechanism":
97-
require.Equal(t, value, cs.AuthMechanism)
98-
case "authmechanismproperties":
99-
convertedMap := value.(map[string]interface{})
100-
require.Equal(t,
101-
mapInterfaceToString(convertedMap),
102-
cs.AuthMechanismProperties)
103-
case "compressors":
104-
require.Equal(t, convertToStringSlice(value), cs.Compressors)
105-
case "connecttimeoutms":
106-
require.Equal(t, value, float64(cs.ConnectTimeout/time.Millisecond))
107-
case "directconnection":
108-
require.True(t, cs.DirectConnectionSet)
109-
require.Equal(t, value, cs.DirectConnection)
110-
case "heartbeatfrequencyms":
111-
require.Equal(t, value, float64(cs.HeartbeatInterval/time.Millisecond))
112-
case "journal":
113-
require.True(t, cs.JSet)
114-
require.Equal(t, value, cs.J)
115-
case "loadbalanced":
116-
require.True(t, cs.LoadBalancedSet)
117-
require.Equal(t, value, cs.LoadBalanced)
118-
case "localthresholdms":
119-
require.True(t, cs.LocalThresholdSet)
120-
require.Equal(t, value, float64(cs.LocalThreshold/time.Millisecond))
121-
case "maxidletimems":
122-
require.Equal(t, value, float64(cs.MaxConnIdleTime/time.Millisecond))
123-
case "maxpoolsize":
124-
require.True(t, cs.MaxPoolSizeSet)
125-
require.Equal(t, value, cs.MaxPoolSize)
126-
case "maxstalenessseconds":
127-
require.True(t, cs.MaxStalenessSet)
128-
require.Equal(t, value, float64(cs.MaxStaleness/time.Second))
129-
case "minpoolsize":
130-
require.True(t, cs.MinPoolSizeSet)
131-
require.Equal(t, value, int64(cs.MinPoolSize))
132-
case "readpreference":
133-
require.Equal(t, value, cs.ReadPreference)
134-
case "readpreferencetags":
135-
sm, ok := value.([]interface{})
136-
require.True(t, ok)
137-
tags := make([]map[string]string, 0, len(sm))
138-
for _, i := range sm {
139-
m, ok := i.(map[string]interface{})
140-
require.True(t, ok)
141-
tags = append(tags, mapInterfaceToString(m))
142-
}
143-
require.Equal(t, tags, cs.ReadPreferenceTagSets)
144-
case "readconcernlevel":
145-
require.Equal(t, value, cs.ReadConcernLevel)
146-
case "replicaset":
147-
require.Equal(t, value, cs.ReplicaSet)
148-
case "retrywrites":
149-
require.True(t, cs.RetryWritesSet)
150-
require.Equal(t, value, cs.RetryWrites)
151-
case "serverselectiontimeoutms":
152-
require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond))
153-
case "srvmaxhosts":
154-
require.Equal(t, value, float64(cs.SRVMaxHosts))
155-
case "srvservicename":
156-
require.Equal(t, value, cs.SRVServiceName)
157-
case "ssl", "tls":
158-
require.Equal(t, value, cs.SSL)
159-
case "sockettimeoutms":
160-
require.Equal(t, value, float64(cs.SocketTimeout/time.Millisecond))
161-
case "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsinsecure":
162-
require.True(t, cs.SSLInsecureSet)
163-
require.Equal(t, value, cs.SSLInsecure)
164-
case "tlscafile":
165-
require.True(t, cs.SSLCaFileSet)
166-
require.Equal(t, value, cs.SSLCaFile)
167-
case "tlscertificatekeyfile":
168-
require.True(t, cs.SSLClientCertificateKeyFileSet)
169-
require.Equal(t, value, cs.SSLClientCertificateKeyFile)
170-
case "tlscertificatekeyfilepassword":
171-
require.True(t, cs.SSLClientCertificateKeyPasswordSet)
172-
require.Equal(t, value, cs.SSLClientCertificateKeyPassword())
173-
case "w":
174-
if cs.WNumberSet {
175-
valueInt := GetIntFromInterface(value)
176-
require.NotNil(t, valueInt)
177-
require.Equal(t, *valueInt, int64(cs.WNumber))
178-
} else {
179-
require.Equal(t, value, cs.WString)
180-
}
181-
case "wtimeoutms":
182-
require.Equal(t, value, float64(cs.WTimeout/time.Millisecond))
183-
case "waitqueuetimeoutms":
184-
case "zlibcompressionlevel":
185-
require.Equal(t, value, float64(cs.ZlibLevel))
186-
case "zstdcompressionlevel":
187-
require.Equal(t, value, float64(cs.ZstdLevel))
188-
case "tlsdisableocspendpointcheck":
189-
require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck)
190-
default:
191-
opt, ok := cs.UnknownOptions[key]
192-
require.True(t, ok)
193-
require.Contains(t, opt, fmt.Sprint(value))
194-
}
39+
// RawToDocuments converts a bson.Raw that is internally an array of documents to []bson.Raw.
40+
func RawToDocuments(doc bson.Raw) []bson.Raw {
41+
values, err := doc.Values()
42+
if err != nil {
43+
panic(fmt.Sprintf("error converting BSON document to values: %v", err))
19544
}
196-
}
19745

198-
// RawSliceToInterfaceSlice converts a []bson.Raw to []interface{}.
199-
func RawSliceToInterfaceSlice(elems []bson.Raw) []interface{} {
200-
out := make([]interface{}, 0, len(elems))
201-
for _, elem := range elems {
202-
out = append(out, elem)
203-
}
204-
return out
205-
}
206-
207-
// RawToInterfaceSlice converts a bson.Raw that is internally an array to []interface{}.
208-
func RawToInterfaceSlice(doc bson.Raw) []interface{} {
209-
values, _ := doc.Values()
210-
211-
out := make([]interface{}, 0, len(values))
212-
for _, val := range values {
213-
out = append(out, val.Document())
46+
out := make([]bson.Raw, len(values))
47+
for i := range values {
48+
out[i] = values[i].Document()
21449
}
21550

21651
return out
21752
}
21853

219-
// Convert each interface{} value in the map to a string.
220-
func mapInterfaceToString(m map[string]interface{}) map[string]string {
221-
out := make(map[string]string)
222-
223-
for key, value := range m {
224-
out[key] = fmt.Sprint(value)
54+
// RawToInterfaces takes one or many bson.Raw documents and returns them as a []interface{}.
55+
func RawToInterfaces(docs ...bson.Raw) []interface{} {
56+
out := make([]interface{}, len(docs))
57+
for i := range docs {
58+
out[i] = docs[i]
22559
}
226-
22760
return out
22861
}
229-
230-
func convertToStringSlice(i interface{}) []string {
231-
s, ok := i.([]interface{})
232-
if !ok {
233-
return nil
234-
}
235-
ret := make([]string, 0, len(s))
236-
for _, v := range s {
237-
str, ok := v.(string)
238-
if !ok {
239-
continue
240-
}
241-
ret = append(ret, str)
242-
}
243-
return ret
244-
}
245-
246-
// GetIntFromInterface attempts to convert an empty interface value to an integer.
247-
//
248-
// Returns nil if it is not possible.
249-
func GetIntFromInterface(i interface{}) *int64 {
250-
var out int64
251-
252-
switch v := i.(type) {
253-
case int:
254-
out = int64(v)
255-
case int32:
256-
out = int64(v)
257-
case int64:
258-
out = v
259-
case float32:
260-
f := float64(v)
261-
if math.Floor(f) != f || f > float64(math.MaxInt64) {
262-
break
263-
}
264-
265-
out = int64(f)
266-
267-
case float64:
268-
if math.Floor(v) != v || v > float64(math.MaxInt64) {
269-
break
270-
}
271-
272-
out = int64(v)
273-
default:
274-
return nil
275-
}
276-
277-
return &out
278-
}

mongo/description/max_staleness_spec_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"path"
1111
"testing"
1212

13-
testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers"
13+
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
1414
)
1515

1616
const maxStalenessTestsDir = "../../data/max-staleness"
@@ -24,7 +24,7 @@ func TestMaxStalenessSpec(t *testing.T) {
2424
"Single",
2525
"Unknown",
2626
} {
27-
for _, file := range testhelpers.FindJSONFilesInDir(t,
27+
for _, file := range helpers.FindJSONFilesInDir(t,
2828
path.Join(maxStalenessTestsDir, topology)) {
2929

3030
runTest(t, maxStalenessTestsDir, topology, file)

mongo/description/selector_spec_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"path"
1111
"testing"
1212

13-
testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers"
13+
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
1414
)
1515

1616
const selectorTestsDir = "../../data/server-selection/server_selection"
@@ -28,7 +28,7 @@ func TestServerSelectionSpec(t *testing.T) {
2828
for _, subdir := range [...]string{"read", "write"} {
2929
subdirPath := path.Join(topology, subdir)
3030

31-
for _, file := range testhelpers.FindJSONFilesInDir(t,
31+
for _, file := range helpers.FindJSONFilesInDir(t,
3232
path.Join(selectorTestsDir, subdirPath)) {
3333

3434
runTest(t, selectorTestsDir, subdirPath, file)

mongo/integration/crud_helpers_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"go.mongodb.org/mongo-driver/bson/primitive"
2121
"go.mongodb.org/mongo-driver/internal/testutil"
2222
"go.mongodb.org/mongo-driver/internal/testutil/assert"
23+
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
2324
"go.mongodb.org/mongo-driver/mongo"
2425
"go.mongodb.org/mongo-driver/mongo/gridfs"
2526
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
@@ -171,7 +172,7 @@ func executeAggregate(mt *mtest.T, agg aggregator, sess mongo.Session, args bson
171172

172173
switch key {
173174
case "pipeline":
174-
pipeline = rawArrayToInterfaceSlice(val.Array())
175+
pipeline = helpers.RawToInterfaces(helpers.RawToDocuments(val.Array())...)
175176
case "batchSize":
176177
opts.SetBatchSize(val.Int32())
177178
case "collation":
@@ -209,7 +210,7 @@ func executeWatch(mt *mtest.T, w watcher, sess mongo.Session, args bson.Raw) (*m
209210

210211
switch key {
211212
case "pipeline":
212-
pipeline = rawArrayToInterfaceSlice(val.Array())
213+
pipeline = helpers.RawToInterfaces(helpers.RawToDocuments(val.Array())...)
213214
default:
214215
mt.Fatalf("unrecognized watch option: %v", key)
215216
}
@@ -312,7 +313,7 @@ func executeInsertMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.I
312313

313314
switch key {
314315
case "documents":
315-
docs = rawArrayToInterfaceSlice(val.Array())
316+
docs = helpers.RawToInterfaces(helpers.RawToDocuments(val.Array())...)
316317
case "options":
317318
// Some of the older tests use this to set the "ordered" option
318319
optsDoc := val.Document()
@@ -699,7 +700,7 @@ func executeFindOneAndUpdate(mt *mtest.T, sess mongo.Session, args bson.Raw) *mo
699700
update = createUpdate(mt, val)
700701
case "arrayFilters":
701702
opts = opts.SetArrayFilters(options.ArrayFilters{
702-
Filters: rawArrayToInterfaceSlice(val.Array()),
703+
Filters: helpers.RawToInterfaces(helpers.RawToDocuments(val.Array())...),
703704
})
704705
case "sort":
705706
opts = opts.SetSort(val.Document())
@@ -881,7 +882,7 @@ func executeUpdateOne(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.Up
881882
update = createUpdate(mt, val)
882883
case "arrayFilters":
883884
opts = opts.SetArrayFilters(options.ArrayFilters{
884-
Filters: rawArrayToInterfaceSlice(val.Array()),
885+
Filters: helpers.RawToInterfaces(helpers.RawToDocuments(val.Array())...),
885886
})
886887
case "upsert":
887888
opts = opts.SetUpsert(val.Boolean())
@@ -929,7 +930,7 @@ func executeUpdateMany(mt *mtest.T, sess mongo.Session, args bson.Raw) (*mongo.U
929930
update = createUpdate(mt, val)
930931
case "arrayFilters":
931932
opts = opts.SetArrayFilters(options.ArrayFilters{
932-
Filters: rawArrayToInterfaceSlice(val.Array()),
933+
Filters: helpers.RawToInterfaces(helpers.RawToDocuments(val.Array())...),
933934
})
934935
case "upsert":
935936
opts = opts.SetUpsert(val.Boolean())
@@ -1115,7 +1116,7 @@ func createBulkWriteModel(mt *mtest.T, rawModel bson.Raw) mongo.WriteModel {
11151116
}
11161117
if arrayFilters, err := args.LookupErr("arrayFilters"); err == nil {
11171118
uom.SetArrayFilters(options.ArrayFilters{
1118-
Filters: rawArrayToInterfaceSlice(arrayFilters.Array()),
1119+
Filters: helpers.RawToInterfaces(helpers.RawToDocuments(arrayFilters.Array())...),
11191120
})
11201121
}
11211122
if hintVal, err := args.LookupErr("hint"); err == nil {
@@ -1138,7 +1139,7 @@ func createBulkWriteModel(mt *mtest.T, rawModel bson.Raw) mongo.WriteModel {
11381139
}
11391140
if arrayFilters, err := args.LookupErr("arrayFilters"); err == nil {
11401141
umm.SetArrayFilters(options.ArrayFilters{
1141-
Filters: rawArrayToInterfaceSlice(arrayFilters.Array()),
1142+
Filters: helpers.RawToInterfaces(helpers.RawToDocuments(arrayFilters.Array())...),
11421143
})
11431144
}
11441145
if hintVal, err := args.LookupErr("hint"); err == nil {

0 commit comments

Comments
 (0)