Skip to content

Commit bc14c6e

Browse files
authored
Improve ClientOptions.Validate to prevent overwriting validation errors. (#938)
1 parent 8f19bdb commit bc14c6e

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

mongo/options/clientoptions.go

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -160,62 +160,58 @@ func Client() *ClientOptions {
160160

161161
// Validate validates the client options. This method will return the first error found.
162162
func (c *ClientOptions) Validate() error {
163-
c.validateAndSetError()
164-
return c.err
165-
}
166-
167-
func (c *ClientOptions) validateAndSetError() {
168163
if c.err != nil {
169-
return
164+
return c.err
170165
}
166+
c.err = c.validate()
167+
return c.err
168+
}
171169

170+
func (c *ClientOptions) validate() error {
172171
// Direct connections cannot be made if multiple hosts are specified or an SRV URI is used.
173172
if c.Direct != nil && *c.Direct {
174173
if len(c.Hosts) > 1 {
175-
c.err = errors.New("a direct connection cannot be made if multiple hosts are specified")
176-
return
174+
return errors.New("a direct connection cannot be made if multiple hosts are specified")
177175
}
178176
if c.cs != nil && c.cs.Scheme == connstring.SchemeMongoDBSRV {
179-
c.err = errors.New("a direct connection cannot be made if an SRV URI is used")
180-
return
177+
return errors.New("a direct connection cannot be made if an SRV URI is used")
181178
}
182179
}
183180

184181
if c.MaxPoolSize != nil && c.MinPoolSize != nil && *c.MaxPoolSize != 0 && *c.MinPoolSize > *c.MaxPoolSize {
185-
c.err = fmt.Errorf("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=%d maxPoolSize=%d", *c.MinPoolSize, *c.MaxPoolSize)
186-
return
182+
return fmt.Errorf("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=%d maxPoolSize=%d", *c.MinPoolSize, *c.MaxPoolSize)
187183
}
188184

189185
// verify server API version if ServerAPIOptions are passed in.
190186
if c.ServerAPIOptions != nil {
191-
c.err = c.ServerAPIOptions.ServerAPIVersion.Validate()
187+
if err := c.ServerAPIOptions.ServerAPIVersion.Validate(); err != nil {
188+
return err
189+
}
192190
}
193191

194192
// Validation for load-balanced mode.
195193
if c.LoadBalanced != nil && *c.LoadBalanced {
196194
if len(c.Hosts) > 1 {
197-
c.err = internal.ErrLoadBalancedWithMultipleHosts
198-
return
195+
return internal.ErrLoadBalancedWithMultipleHosts
199196
}
200197
if c.ReplicaSet != nil {
201-
c.err = internal.ErrLoadBalancedWithReplicaSet
202-
return
198+
return internal.ErrLoadBalancedWithReplicaSet
203199
}
204200
if c.Direct != nil {
205-
c.err = internal.ErrLoadBalancedWithDirectConnection
206-
return
201+
return internal.ErrLoadBalancedWithDirectConnection
207202
}
208203
}
209204

210205
// Validation for srvMaxHosts.
211206
if c.SRVMaxHosts != nil && *c.SRVMaxHosts > 0 {
212207
if c.ReplicaSet != nil {
213-
c.err = internal.ErrSRVMaxHostsWithReplicaSet
208+
return internal.ErrSRVMaxHostsWithReplicaSet
214209
}
215210
if c.LoadBalanced != nil && *c.LoadBalanced {
216-
c.err = internal.ErrSRVMaxHostsWithLoadBalanced
211+
return internal.ErrSRVMaxHostsWithLoadBalanced
217212
}
218213
}
214+
return nil
219215
}
220216

221217
// GetURI returns the original URI used to configure the ClientOptions instance. If ApplyURI was not called during

mongo/options/clientoptions_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,41 @@ func TestClientOptions(t *testing.T) {
696696
})
697697
}
698698
})
699+
t.Run("srvMaxHosts validation", func(t *testing.T) {
700+
t.Parallel()
701+
702+
testCases := []struct {
703+
name string
704+
opts *ClientOptions
705+
err error
706+
}{
707+
{
708+
name: "valid ServerAPI",
709+
opts: Client().SetServerAPIOptions(ServerAPI(ServerAPIVersion1)),
710+
err: nil,
711+
},
712+
{
713+
name: "invalid ServerAPI",
714+
opts: Client().SetServerAPIOptions(ServerAPI("nope")),
715+
err: errors.New(`api version "nope" not supported; this driver version only supports API version "1"`),
716+
},
717+
{
718+
name: "invalid ServerAPI with other invalid options",
719+
opts: Client().SetServerAPIOptions(ServerAPI("nope")).SetSRVMaxHosts(1).SetReplicaSet("foo"),
720+
err: errors.New(`api version "nope" not supported; this driver version only supports API version "1"`),
721+
},
722+
}
723+
for _, tc := range testCases {
724+
tc := tc // Capture range variable.
725+
726+
t.Run(tc.name, func(t *testing.T) {
727+
t.Parallel()
728+
729+
err := tc.opts.Validate()
730+
assert.Equal(t, tc.err, err, "want error %v, got error %v", tc.err, err)
731+
})
732+
}
733+
})
699734
}
700735

701736
func createCertPool(t *testing.T, paths ...string) *x509.CertPool {

0 commit comments

Comments
 (0)