@@ -10,6 +10,7 @@ use std::{
1010 convert:: TryFrom ,
1111 fmt:: { self , Display , Formatter , Write } ,
1212 hash:: { Hash , Hasher } ,
13+ net:: Ipv6Addr ,
1314 path:: PathBuf ,
1415 str:: FromStr ,
1516 sync:: Arc ,
@@ -124,9 +125,29 @@ impl<'de> Deserialize<'de> for ServerAddress {
124125 where
125126 D : Deserializer < ' de > ,
126127 {
127- let s: String = Deserialize :: deserialize ( deserializer) ?;
128- Self :: parse ( s. as_str ( ) )
129- . map_err ( |e| <D :: Error as serde:: de:: Error >:: custom ( format ! ( "{}" , e) ) )
128+ #[ derive( Deserialize ) ]
129+ #[ serde( untagged) ]
130+ enum ServerAddressHelper {
131+ String ( String ) ,
132+ Object { host : String , port : Option < u16 > } ,
133+ }
134+
135+ let helper = ServerAddressHelper :: deserialize ( deserializer) ?;
136+ match helper {
137+ ServerAddressHelper :: String ( string) => {
138+ Self :: parse ( string) . map_err ( serde:: de:: Error :: custom)
139+ }
140+ ServerAddressHelper :: Object { host, port } => {
141+ #[ cfg( unix) ]
142+ if host. ends_with ( "sock" ) {
143+ return Ok ( Self :: Unix {
144+ path : PathBuf :: from ( host) ,
145+ } ) ;
146+ }
147+
148+ Ok ( Self :: Tcp { host, port } )
149+ }
150+ }
130151 }
131152}
132153
@@ -184,71 +205,92 @@ impl ServerAddress {
184205 /// Parses an address string into a `ServerAddress`.
185206 pub fn parse ( address : impl AsRef < str > ) -> Result < Self > {
186207 let address = address. as_ref ( ) ;
187- // checks if the address is a unix domain socket
188- #[ cfg( unix) ]
189- {
190- if address. ends_with ( ".sock" ) {
191- return Ok ( ServerAddress :: Unix {
208+
209+ if address. ends_with ( ".sock" ) {
210+ #[ cfg( unix) ]
211+ {
212+ let address = percent_decode ( address, "unix domain sockets must be URL-encoded" ) ?;
213+ return Ok ( Self :: Unix {
192214 path : PathBuf :: from ( address) ,
193215 } ) ;
194216 }
217+ #[ cfg( not( unix) ) ]
218+ return Err ( ErrorKind :: InvalidArgument {
219+ message : "unix domain sockets are not supported on this platform" . to_string ( ) ,
220+ }
221+ . into ( ) ) ;
195222 }
196- let mut parts = address. split ( ':' ) ;
197- let hostname = match parts. next ( ) {
198- Some ( part) => {
199- if part. is_empty ( ) {
200- return Err ( ErrorKind :: InvalidArgument {
201- message : format ! (
202- "invalid server address: \" {}\" ; hostname cannot be empty" ,
203- address
204- ) ,
205- }
206- . into ( ) ) ;
223+
224+ let ( hostname, port) = if let Some ( ip_literal) = address. strip_prefix ( "[" ) {
225+ let Some ( ( hostname, port) ) = ip_literal. split_once ( "]" ) else {
226+ return Err ( ErrorKind :: InvalidArgument {
227+ message : format ! (
228+ "invalid server address {}: missing closing ']' in IP literal hostname" ,
229+ address
230+ ) ,
207231 }
208- part
209- }
210- None => {
232+ . into ( ) ) ;
233+ } ;
234+
235+ if let Err ( parse_error) = Ipv6Addr :: from_str ( hostname) {
211236 return Err ( ErrorKind :: InvalidArgument {
212- message : format ! ( "invalid server address: \" {} \" " , address) ,
237+ message : format ! ( "invalid server address {}: {}" , address, parse_error ) ,
213238 }
214- . into ( ) )
239+ . into ( ) ) ;
215240 }
216- } ;
217241
218- let port = match parts. next ( ) {
219- Some ( part) => {
220- let port = u16:: from_str ( part) . map_err ( |_| ErrorKind :: InvalidArgument {
242+ let port = if port. is_empty ( ) {
243+ None
244+ } else if let Some ( port) = port. strip_prefix ( ":" ) {
245+ Some ( port)
246+ } else {
247+ return Err ( ErrorKind :: InvalidArgument {
221248 message : format ! (
222- "port must be valid 16-bit unsigned integer, instead got: {}" ,
223- part
249+ "invalid server address {}: the hostname can only be followed by a port \
250+ prefixed with ':', got {}",
251+ address, port
224252 ) ,
225- } ) ?;
226-
227- if port == 0 {
228- return Err ( ErrorKind :: InvalidArgument {
229- message : format ! (
230- "invalid server address: \" {}\" ; port must be non-zero" ,
231- address
232- ) ,
233- }
234- . into ( ) ) ;
235253 }
236- if parts. next ( ) . is_some ( ) {
254+ . into ( ) ) ;
255+ } ;
256+
257+ ( hostname, port)
258+ } else {
259+ match address. split_once ( ":" ) {
260+ Some ( ( hostname, port) ) => ( hostname, Some ( port) ) ,
261+ None => ( address, None ) ,
262+ }
263+ } ;
264+
265+ if hostname. is_empty ( ) {
266+ return Err ( ErrorKind :: InvalidArgument {
267+ message : format ! (
268+ "invalid server address {}: the hostname cannot be empty" ,
269+ address
270+ ) ,
271+ }
272+ . into ( ) ) ;
273+ }
274+
275+ let port = if let Some ( port) = port {
276+ match u16:: from_str ( port) {
277+ Ok ( 0 ) | Err ( _) => {
237278 return Err ( ErrorKind :: InvalidArgument {
238279 message : format ! (
239- "address \" {}\" contains more than one unescaped ':'" ,
240- address
280+ "invalid server address {}: the port must be an integer between 1 and \
281+ 65535, got {}",
282+ address, port
241283 ) ,
242284 }
243- . into ( ) ) ;
285+ . into ( ) )
244286 }
245-
246- Some ( port)
287+ Ok ( port) => Some ( port) ,
247288 }
248- None => None ,
289+ } else {
290+ None
249291 } ;
250292
251- Ok ( ServerAddress :: Tcp {
293+ Ok ( Self :: Tcp {
252294 host : hostname. to_lowercase ( ) ,
253295 port,
254296 } )
@@ -1689,37 +1731,21 @@ impl ConnectionString {
16891731 None => ( None , None ) ,
16901732 } ;
16911733
1692- let mut host_list = Vec :: with_capacity ( hosts_section. len ( ) ) ;
1693- for host in hosts_section. split ( ',' ) {
1694- let address = if host. ends_with ( ".sock" ) {
1695- #[ cfg( unix) ]
1696- {
1697- ServerAddress :: parse ( percent_decode (
1698- host,
1699- "Unix domain sockets must be URL-encoded" ,
1700- ) ?)
1701- }
1702- #[ cfg( not( unix) ) ]
1703- return Err ( ErrorKind :: InvalidArgument {
1704- message : "Unix domain sockets are not supported on this platform" . to_string ( ) ,
1705- }
1706- . into ( ) ) ;
1707- } else {
1708- ServerAddress :: parse ( host)
1709- } ?;
1710- host_list. push ( address) ;
1711- }
1734+ let hosts = hosts_section
1735+ . split ( ',' )
1736+ . map ( ServerAddress :: parse)
1737+ . collect :: < Result < Vec < ServerAddress > > > ( ) ?;
17121738
1713- let hosts = if srv {
1714- if host_list . len ( ) != 1 {
1739+ let host_info = if srv {
1740+ if hosts . len ( ) != 1 {
17151741 return Err ( ErrorKind :: InvalidArgument {
17161742 message : "exactly one host must be specified with 'mongodb+srv'" . into ( ) ,
17171743 }
17181744 . into ( ) ) ;
17191745 }
17201746
17211747 // Unwrap safety: the `len` check above guarantees this can't fail.
1722- match host_list . into_iter ( ) . next ( ) . unwrap ( ) {
1748+ match hosts . into_iter ( ) . next ( ) . unwrap ( ) {
17231749 ServerAddress :: Tcp { host, port } => {
17241750 if port. is_some ( ) {
17251751 return Err ( ErrorKind :: InvalidArgument {
@@ -1738,11 +1764,11 @@ impl ConnectionString {
17381764 }
17391765 }
17401766 } else {
1741- HostInfo :: HostIdentifiers ( host_list )
1767+ HostInfo :: HostIdentifiers ( hosts )
17421768 } ;
17431769
17441770 let mut conn_str = ConnectionString {
1745- host_info : hosts ,
1771+ host_info,
17461772 #[ cfg( test) ]
17471773 original_uri : s. into ( ) ,
17481774 ..Default :: default ( )
0 commit comments