@@ -83,14 +83,14 @@ fn hashpw<'p>(
8383 // salt here is not just the salt bytes, but rather an encoded value
8484 // containing a version number, number of rounds, and the salt.
8585 // Should be [prefix, cost, hash]. This logic is copied from `bcrypt`
86- let raw_parts : Vec < _ > = salt
86+ let [ raw_version , raw_cost , remainder ] : [ & [ u8 ] ; 3 ] = salt
8787 . split ( |& b| b == b'$' )
8888 . filter ( |s| !s. is_empty ( ) )
89- . collect ( ) ;
90- if raw_parts . len ( ) != 3 {
91- return Err ( pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ;
92- }
93- let version = match raw_parts [ 0 ] {
89+ . collect :: < Vec < _ > > ( )
90+ . try_into ( )
91+ . map_err ( |_| pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ? ;
92+
93+ let version = match raw_version {
9494 b"2y" => bcrypt:: Version :: TwoY ,
9595 b"2b" => bcrypt:: Version :: TwoB ,
9696 b"2a" => bcrypt:: Version :: TwoA ,
@@ -99,15 +99,20 @@ fn hashpw<'p>(
9999 return Err ( pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ;
100100 }
101101 } ;
102- let cost = std:: str:: from_utf8 ( raw_parts [ 1 ] )
102+ let cost = std:: str:: from_utf8 ( raw_cost )
103103 . map_err ( |_| pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ?
104104 . parse :: < u32 > ( )
105105 . map_err ( |_| pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ?;
106+
107+ if remainder. len ( ) < 22 {
108+ return Err ( pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ;
109+ }
110+
106111 // The last component can contain either just the salt, or the salt and
107112 // the result hash, depending on if the `salt` value come from `hashpw` or
108113 // `gensalt`.
109114 let raw_salt = BASE64_ENGINE
110- . decode ( & raw_parts [ 2 ] [ ..22 ] )
115+ . decode ( & remainder [ ..22 ] )
111116 . map_err ( |_| pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ?
112117 . try_into ( )
113118 . map_err ( |_| pyo3:: exceptions:: PyValueError :: new_err ( "Invalid salt" ) ) ?;
0 commit comments