33
44using System ;
55using System . Buffers ;
6+ #if NET9_0_OR_GREATER
67using System . Buffers . Text ;
8+ #endif
79using System . Text ;
810using Microsoft . IdentityModel . Logging ;
911
@@ -21,6 +23,8 @@ public static class Base64UrlEncoder
2123 private const char Base64PadCharacter = '=' ;
2224 private const char Base64Character62 = '+' ;
2325 private const char Base64Character63 = '/' ;
26+ private const char Base64UrlCharacter62 = '-' ;
27+ private const char Base64UrlCharacter63 = '_' ;
2428
2529 /// <summary>
2630 /// Performs base64url encoding, which differs from regular base64 encoding as follows:
@@ -98,17 +102,86 @@ public static string Encode(byte[] inArray, int offset, int length)
98102 LogHelper . MarkAsNonPII ( inArray . Length ) ) ) ) ;
99103#pragma warning restore CA2208 // Instantiate argument exceptions correctly
100104
105+ #if NET9_0_OR_GREATER
101106 return Base64Url . EncodeToString ( inArray . AsSpan ( ) . Slice ( offset , length ) ) ;
107+ #else
108+ char [ ] destination = new char [ ( inArray . Length + 2 ) / 3 * 4 ] ;
109+ int j = Encode ( inArray . AsSpan < byte > ( ) . Slice ( offset , length ) , destination . AsSpan < char > ( ) ) ;
110+
111+ return new string ( destination , 0 , j ) ;
112+ #endif
102113 }
103114
115+ #if NET9_0_OR_GREATER
104116 /// <summary>
105117 /// Populates a <see cref="Span{T}"/> with the base64url encoded representation of a <see cref="ReadOnlySpan{T}"/> of bytes.
106118 /// </summary>
107119 /// <param name="inArray">A read-only span of bytes to encode.</param>
108120 /// <param name="output">The span of characters to write the encoded output.</param>
109121 /// <returns>The number of characters written to the output span.</returns>
110122 public static int Encode ( ReadOnlySpan < byte > inArray , Span < char > output ) => Base64Url . EncodeToChars ( inArray , output ) ;
123+ #else
124+ /// <summary>
125+ /// Populates a <see cref="Span{T}"/> with the base64url encoded representation of a <see cref="ReadOnlySpan{T}"/> of bytes.
126+ /// </summary>
127+ /// <param name="inArray">A read-only span of bytes to encode.</param>
128+ /// <param name="output">The span of characters to write the encoded output.</param>
129+ /// <returns>The number of characters written to the output span.</returns>
130+ public static int Encode ( ReadOnlySpan < byte > inArray , Span < char > output )
131+ {
132+ int lengthmod3 = inArray . Length % 3 ;
133+ int limit = ( inArray . Length - lengthmod3 ) ;
134+ ReadOnlySpan < byte > table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8 ;
135+
136+ int i , j = 0 ;
137+
138+ // takes 3 bytes from inArray and insert 4 bytes into output
139+ for ( i = 0 ; i < limit ; i += 3 )
140+ {
141+ byte d0 = inArray [ i ] ;
142+ byte d1 = inArray [ i + 1 ] ;
143+ byte d2 = inArray [ i + 2 ] ;
144+
145+ output [ j + 0 ] = ( char ) table [ d0 >> 2 ] ;
146+ output [ j + 1 ] = ( char ) table [ ( ( d0 & 0x03 ) << 4 ) | ( d1 >> 4 ) ] ;
147+ output [ j + 2 ] = ( char ) table [ ( ( d1 & 0x0f ) << 2 ) | ( d2 >> 6 ) ] ;
148+ output [ j + 3 ] = ( char ) table [ d2 & 0x3f ] ;
149+ j += 4 ;
150+ }
151+
152+ //Where we left off before
153+ i = limit ;
111154
155+ switch ( lengthmod3 )
156+ {
157+ case 2 :
158+ {
159+ byte d0 = inArray [ i ] ;
160+ byte d1 = inArray [ i + 1 ] ;
161+
162+ output [ j + 0 ] = ( char ) table [ d0 >> 2 ] ;
163+ output [ j + 1 ] = ( char ) table [ ( ( d0 & 0x03 ) << 4 ) | ( d1 >> 4 ) ] ;
164+ output [ j + 2 ] = ( char ) table [ ( d1 & 0x0f ) << 2 ] ;
165+ j += 3 ;
166+ }
167+ break ;
168+ case 1 :
169+ {
170+ byte d0 = inArray [ i ] ;
171+
172+ output [ j + 0 ] = ( char ) table [ d0 >> 2 ] ;
173+ output [ j + 1 ] = ( char ) table [ ( d0 & 0x03 ) << 4 ] ;
174+ j += 2 ;
175+ }
176+ break ;
177+
178+ //default or case 0: no further operations are needed.
179+ }
180+
181+ return j ;
182+ }
183+
184+ #endif
112185 /// <summary>
113186 /// Converts the specified base64url encoded string to UTF-8 bytes.
114187 /// </summary>
@@ -123,6 +196,8 @@ public static byte[] DecodeBytes(string str)
123196#if NETCOREAPP
124197 [ SkipLocalsInit ]
125198#endif
199+
200+ #if NET9_0_OR_GREATER
126201 internal static byte [ ] Decode ( ReadOnlySpan < char > strSpan )
127202 {
128203 int upperBound = Base64Url . GetMaxDecodedLength ( strSpan . Length ) ;
@@ -144,33 +219,37 @@ internal static byte[] Decode(ReadOnlySpan<char> strSpan)
144219 ArrayPool < byte > . Shared . Return ( rented , true ) ;
145220 }
146221 }
147-
148- #if ! NET8_0_OR_GREATER
149- private static bool IsOnlyValidBase64Chars ( ReadOnlySpan < char > strSpan )
222+ #else
223+ internal static byte [ ] Decode ( ReadOnlySpan < char > strSpan )
150224 {
151- foreach ( char c in strSpan )
152- if ( ! char . IsDigit ( c ) && ! char . IsLetter ( c ) && c != Base64Character62 && c != Base64Character63 && c != Base64PadCharacter )
153- return false ;
225+ int mod = strSpan . Length % 4 ;
226+ if ( mod == 1 )
227+ throw LogHelper . LogExceptionMessage ( new FormatException ( LogHelper . FormatInvariant ( LogMessages . IDX10400 , strSpan . ToString ( ) ) ) ) ;
154228
155- return true ;
229+ bool needReplace = strSpan . IndexOfAny ( Base64UrlCharacter62 , Base64UrlCharacter63 ) >= 0 ;
230+ int decodedLength = strSpan . Length + ( 4 - mod ) % 4 ;
231+ #if NET6_0_OR_GREATER
232+ Span < byte > output = new byte [ decodedLength ] ;
233+ int length = Decode ( strSpan , output , needReplace , decodedLength ) ;
234+ return output . Slice ( 0 , length ) . ToArray ( ) ;
235+ #else
236+ return UnsafeDecode ( strSpan , needReplace , decodedLength ) ;
237+ #endif
156238 }
157-
158239#endif
240+
159241#if NETCOREAPP
160242 [ SkipLocalsInit ]
161243#endif
244+
245+ #if NET9_0_OR_GREATER
162246 internal static int Decode ( ReadOnlySpan < char > strSpan , Span < byte > output )
163247 {
164248 OperationStatus status = Base64Url . DecodeFromChars ( strSpan , output , out _ , out int bytesWritten ) ;
165249 if ( status == OperationStatus . Done )
166250 return bytesWritten ;
167251
168- if ( status == OperationStatus . InvalidData &&
169- #if NET8_0_OR_GREATER
170- ! Base64 . IsValid ( strSpan ) )
171- #else
172- ! IsOnlyValidBase64Chars ( strSpan ) )
173- #endif
252+ if ( status == OperationStatus . InvalidData && ! Base64 . IsValid ( strSpan ) )
174253 throw LogHelper . LogExceptionMessage ( new FormatException ( LogHelper . FormatInvariant ( LogMessages . IDX10400 , strSpan . ToString ( ) ) ) ) ;
175254
176255 int mod = strSpan . Length % 4 ;
@@ -180,8 +259,24 @@ internal static int Decode(ReadOnlySpan<char> strSpan, Span<byte> output)
180259
181260 return Decode ( strSpan , output , decodedLength ) ;
182261 }
262+ #else
263+ internal static void Decode ( ReadOnlySpan < char > strSpan , Span < byte > output )
264+ {
265+ int mod = strSpan . Length % 4 ;
266+ if ( mod == 1 )
267+ throw LogHelper . LogExceptionMessage ( new FormatException ( LogHelper . FormatInvariant ( LogMessages . IDX10400 , strSpan . ToString ( ) ) ) ) ;
268+ bool needReplace = strSpan . IndexOfAny ( Base64UrlCharacter62 , Base64UrlCharacter63 ) >= 0 ;
269+ int decodedLength = strSpan . Length + ( 4 - mod ) % 4 ;
270+ #if NET6_0_OR_GREATER
271+ Decode ( strSpan , output , needReplace , decodedLength ) ;
272+ #else
273+ Decode ( strSpan , output , needReplace , decodedLength ) ;
274+ #endif
275+ }
183276
184- #if NETCOREAPP
277+ #endif
278+
279+ #if NET9_0_OR_GREATER
185280 [ SkipLocalsInit ]
186281 private static int Decode ( ReadOnlySpan < char > strSpan , Span < byte > output , int decodedLength )
187282 {
@@ -254,33 +349,144 @@ private static ReadOnlySpan<char> HandlePadding(ReadOnlySpan<char> source, Span<
254349
255350 return charsSpan ;
256351 }
257- #else
258- private static unsafe byte [ ] UnsafeDecode ( ReadOnlySpan < char > strSpan , int decodedLength )
352+ #elif NET6_0_OR_GREATER
353+ [ SkipLocalsInit ]
354+ private static int Decode ( ReadOnlySpan < char > strSpan , Span < byte > output , bool needReplace , int decodedLength )
259355 {
260- if ( decodedLength == strSpan . Length )
356+ // If the incoming chars don't contain any of the base64url characters that need to be replaced,
357+ // and if the incoming chars are of the exact right length, then we'll be able to just pass the
358+ // incoming chars directly to DecodeFromUtf8InPlace. Otherwise, rent an array, copy all the
359+ // data into it, and do whatever fixups are necessary on that copy, then pass that copy into
360+ // DecodeFromUtf8InPlace.
361+
362+ const int StackAllocThreshold = 512 ;
363+ char [ ] arrayPoolChars = null ;
364+ scoped Span < char > charsSpan = default ;
365+ scoped ReadOnlySpan < char > source = strSpan ;
366+
367+ if ( needReplace || decodedLength != source . Length )
368+ {
369+ charsSpan = decodedLength <= StackAllocThreshold ?
370+ stackalloc char [ StackAllocThreshold ] :
371+ arrayPoolChars = ArrayPool < char > . Shared . Rent ( decodedLength ) ;
372+ charsSpan = charsSpan . Slice ( 0 , decodedLength ) ;
373+
374+ source = HandlePaddingAndReplace ( source , charsSpan , needReplace ) ;
375+ }
376+
377+ byte [ ] arrayPoolBytes = null ;
378+ Span < byte > bytesSpan = decodedLength <= StackAllocThreshold ?
379+ stackalloc byte [ StackAllocThreshold ] :
380+ arrayPoolBytes = ArrayPool < byte > . Shared . Rent ( decodedLength ) ;
381+
382+ int length = Encoding . UTF8 . GetBytes ( source , bytesSpan ) ;
383+ Span < byte > utf8Span = bytesSpan . Slice ( 0 , length ) ;
384+ try
261385 {
262- return Convert. FromBase64CharArray ( strSpan . ToArray ( ) , 0 , strSpan . Length ) ;
386+ OperationStatus status = System . Buffers . Text . Base64 . DecodeFromUtf8InPlace ( utf8Span , out int bytesWritten ) ;
387+ if ( status != OperationStatus . Done )
388+ throw LogHelper . LogExceptionMessage ( new FormatException ( LogHelper . FormatInvariant ( LogMessages . IDX10400 , strSpan . ToString ( ) ) ) ) ;
389+
390+ utf8Span . Slice ( 0 , bytesWritten ) . CopyTo ( output ) ;
391+
392+ return bytesWritten ;
263393 }
394+ finally
395+ {
396+ if ( arrayPoolBytes is not null )
397+ {
398+ bytesSpan . Clear ( ) ;
399+ ArrayPool < byte > . Shared . Return ( arrayPoolBytes ) ;
400+ }
264401
265- string decodedString = new ( char . MinValue , decodedLength ) ;
266- fixed ( char * src = strSpan)
267- fixed ( char * dest = decodedString)
402+ if ( arrayPoolChars is not null )
403+ {
404+ charsSpan . Clear ( ) ;
405+ ArrayPool < char > . Shared . Return ( arrayPoolChars ) ;
406+ }
407+ }
408+ }
409+
410+ private static ReadOnlySpan < char > HandlePaddingAndReplace ( ReadOnlySpan < char > source , Span < char > charsSpan , bool needReplace )
411+ {
412+ source . CopyTo ( charsSpan ) ;
413+ if ( source . Length < charsSpan . Length )
268414 {
269- Buffer. MemoryCopy ( src , dest , strSpan . Length * 2 , strSpan . Length * 2 ) ;
415+ charsSpan [ source . Length ] = Base64PadCharacter ;
416+ if ( source . Length + 1 < charsSpan . Length )
417+ {
418+ charsSpan [ source . Length + 1 ] = Base64PadCharacter ;
419+ }
420+ }
270421
271- dest[ strSpan . Length ] = Base64PadCharacter ;
272- if ( strSpan . Length + 2 == decodedLength )
273- dest[ strSpan . Length + 1 ] = Base64PadCharacter ;
422+ if ( needReplace )
423+ {
424+ Span < char > remaining = charsSpan ;
425+ int pos ;
426+ while ( ( pos = remaining . IndexOfAny ( Base64UrlCharacter62 , Base64UrlCharacter63 ) ) >= 0 )
427+ {
428+ remaining [ pos ] = ( remaining [ pos ] == Base64UrlCharacter62 ) ? Base64Character62 : Base64Character63 ;
429+ remaining = remaining . Slice ( pos + 1 ) ;
430+ }
274431 }
275432
276- return Convert . FromBase64String ( decodedString ) ;
433+ return charsSpan ;
277434 }
278435
279- private static int Decode( ReadOnlySpan < char > strSpan , Span < byte > output , int decodedLength )
436+ #else
437+ private static unsafe byte [ ] UnsafeDecode ( ReadOnlySpan < char > strSpan , bool needReplace , int decodedLength )
438+ {
439+ if ( needReplace )
440+ {
441+ string decodedString = new ( char . MinValue , decodedLength ) ;
442+ fixed ( char * dest = decodedString )
443+ {
444+ int i = 0 ;
445+ for ( ; i < strSpan . Length ; i ++ )
446+ {
447+ if ( strSpan [ i ] == Base64UrlCharacter62 )
448+ dest [ i ] = Base64Character62 ;
449+ else if ( strSpan [ i ] == Base64UrlCharacter63 )
450+ dest [ i ] = Base64Character63 ;
451+ else
452+ dest [ i ] = strSpan [ i ] ;
453+ }
454+
455+ for ( ; i < decodedLength ; i ++ )
456+ dest [ i ] = Base64PadCharacter ;
457+ }
458+
459+ return Convert . FromBase64String ( decodedString ) ;
460+ }
461+ else
462+ {
463+ if ( decodedLength == strSpan . Length )
464+ {
465+ return Convert . FromBase64CharArray ( strSpan . ToArray ( ) , 0 , strSpan . Length ) ;
466+ }
467+ else
468+ {
469+ string decodedString = new ( char . MinValue , decodedLength ) ;
470+ fixed ( char * src = strSpan )
471+ fixed ( char * dest = decodedString )
472+ {
473+ Buffer . MemoryCopy ( src , dest , strSpan . Length * 2 , strSpan . Length * 2 ) ;
474+
475+ dest [ strSpan . Length ] = Base64PadCharacter ;
476+ if ( strSpan . Length + 2 == decodedLength )
477+ dest [ strSpan . Length + 1 ] = Base64PadCharacter ;
478+ }
479+
480+ return Convert . FromBase64String ( decodedString ) ;
481+ }
482+ }
483+ }
484+
485+ private static void Decode ( ReadOnlySpan < char > strSpan , Span < byte > output , bool needReplace , int decodedLength )
280486 {
281- byte [ ] result = UnsafeDecode( strSpan , decodedLength ) ;
487+ byte [ ] result = UnsafeDecode ( strSpan , needReplace , decodedLength ) ;
282488 result . CopyTo ( output ) ;
283- return result . Length ;
489+
284490 }
285491#endif
286492
0 commit comments