@@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) {
116116// Decode decodes a claim set from a JWS payload.
117117func Decode (payload string ) (* ClaimSet , error ) {
118118 // decode returned id token to get expiry
119- s := strings . Split (payload , "." )
120- if len ( s ) < 2 {
119+ _ , claims , _ , ok := parseToken (payload )
120+ if ! ok {
121121 // TODO(jbd): Provide more context about the error.
122122 return nil , errors .New ("jws: invalid token received" )
123123 }
124- decoded , err := base64 .RawURLEncoding .DecodeString (s [ 1 ] )
124+ decoded , err := base64 .RawURLEncoding .DecodeString (claims )
125125 if err != nil {
126126 return nil , err
127127 }
@@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
165165// Verify tests whether the provided JWT token's signature was produced by the private key
166166// associated with the supplied public key.
167167func Verify (token string , key * rsa.PublicKey ) error {
168- if strings .Count (token , "." ) != 2 {
168+ header , claims , sig , ok := parseToken (token )
169+ if ! ok {
169170 return errors .New ("jws: invalid token received, token must have 3 parts" )
170171 }
171-
172- parts := strings .SplitN (token , "." , 3 )
173- signedContent := parts [0 ] + "." + parts [1 ]
174- signatureString , err := base64 .RawURLEncoding .DecodeString (parts [2 ])
172+ signatureString , err := base64 .RawURLEncoding .DecodeString (sig )
175173 if err != nil {
176174 return err
177175 }
178176
179177 h := sha256 .New ()
180- h .Write ([]byte (signedContent ))
178+ h .Write ([]byte (header + tokenDelim + claims ))
181179 return rsa .VerifyPKCS1v15 (key , crypto .SHA256 , h .Sum (nil ), signatureString )
182180}
181+
182+ func parseToken (s string ) (header , claims , sig string , ok bool ) {
183+ header , s , ok = strings .Cut (s , tokenDelim )
184+ if ! ok { // no period found
185+ return "" , "" , "" , false
186+ }
187+ claims , s , ok = strings .Cut (s , tokenDelim )
188+ if ! ok { // only one period found
189+ return "" , "" , "" , false
190+ }
191+ sig , _ , ok = strings .Cut (s , tokenDelim )
192+ if ok { // three periods found
193+ return "" , "" , "" , false
194+ }
195+ return header , claims , sig , true
196+ }
197+
198+ const tokenDelim = "."
0 commit comments