Skip to content

Commit fbd9196

Browse files
committed
add simd decode; add thresholds for non-simd fallback in encode/decode
1 parent 913eda0 commit fbd9196

File tree

1 file changed

+74
-12
lines changed
  • src/Data/ByteString/Base64/Internal

1 file changed

+74
-12
lines changed

src/Data/ByteString/Base64/Internal/Head.hs

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{-# LANGUAGE BangPatterns #-}
22
{-# LANGUAGE CPP #-}
33
{-# LANGUAGE OverloadedStrings #-}
4+
{-# LANGUAGE ScopedTypeVariables #-}
45
-- |
56
-- Module : Data.ByteString.Base64.Internal.Head
67
-- Copyright : (c) 2019-2022 Emily Pillmore
@@ -42,15 +43,28 @@ import GHC.Word
4243
import System.IO.Unsafe
4344

4445
#ifdef SIMD
45-
import Foreign.C.Types (CSize, CChar)
46+
import Foreign.C.Types (CChar, CInt, CSize)
4647
import Foreign.Storable (peek)
4748
import qualified Foreign.Marshal.Utils as Foreign
49+
import qualified Data.Text as T
4850
import LibBase64Bindings
4951
#endif
5052

5153
encodeBase64_ :: EncodingTable -> ByteString -> ByteString
5254
#ifdef SIMD
53-
encodeBase64_ _ (PS !sfp !soff !slen) =
55+
-- Size-based dispatch: inputs of at least 'simdCutoff' bytes are handed to
-- 'encodeBase64Simd_'; anything shorter uses the table-driven
-- 'encodeBase64Loop_'.
--
-- NOTE(review): the SIMD branch is not given @table@, so its output
-- alphabet is whatever 'encodeBase64Simd_' produces. Confirm that every
-- caller's table agrees with it.
encodeBase64_ table bs@(PS _ _ !len) =
  if len >= simdCutoff
    then encodeBase64Simd_ bs
    else encodeBase64Loop_ table bs
  where
    -- Inputs shorter than this many bytes stay on the non-SIMD loop.
    !simdCutoff = 1000 -- 1k
#else
-- Built without SIMD: the table-driven loop is the only implementation.
encodeBase64_ table bs = encodeBase64Loop_ table bs
#endif
{-# inline encodeBase64_ #-}
64+
65+
#ifdef SIMD
66+
encodeBase64Simd_ :: ByteString -> ByteString
67+
encodeBase64Simd_ (PS !sfp !soff !slen) =
5468
unsafeDupablePerformIO $ do
5569
dfp <- mallocPlainForeignPtrBytes dlen
5670
dlenFinal <- do
@@ -68,14 +82,10 @@ encodeBase64_ _ (PS !sfp !soff !slen) =
6882
where
6983
!dlen = 4 * ((slen + 2) `div` 3)
7084
!base64Flags = 0
85+
#endif
7186

72-
intToCSize :: Int -> CSize
73-
intToCSize = fromIntegral
74-
75-
cSizeToInt :: CSize -> Int
76-
cSizeToInt = fromIntegral
77-
#else
78-
encodeBase64_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
87+
encodeBase64Loop_ :: EncodingTable -> ByteString -> ByteString
88+
encodeBase64Loop_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
7989
unsafeDupablePerformIO $ do
8090
dfp <- mallocPlainForeignPtrBytes dlen
8191
withForeignPtr dfp $ \dptr ->
@@ -90,7 +100,6 @@ encodeBase64_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
90100
(loopTail dfp aptr dptr (castPtr end))
91101
where
92102
!dlen = 4 * ((slen + 2) `div` 3)
93-
#endif
94103

95104
encodeBase64Nopad_ :: EncodingTable -> ByteString -> ByteString
96105
encodeBase64Nopad_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
@@ -109,6 +118,33 @@ encodeBase64Nopad_ (EncodingTable !aptr !efp) (PS !sfp !soff !slen) =
109118
where
110119
!dlen = 4 * ((slen + 2) `div` 3)
111120

121+
#ifdef SIMD
122+
-- | Decode a base64 'ByteString' by calling libbase64's @base64_decode@.
--
-- Returns 'Right' with the decoded bytes when @base64_decode@ returns @1@,
-- and 'Left' with an error message for every other return value.
--
-- The destination buffer is sized by rounding the input length /up/ to a
-- whole number of 4-character groups. Rounding down is not enough: for an
-- input whose length is not a multiple of 4, the decoder can emit bytes
-- for the trailing partial group before the call reports failure, which
-- would write past a buffer of @(slen `quot` 4) * 3@ bytes.
decodeBase64Simd_ :: ByteString -> IO (Either Text ByteString)
decodeBase64Simd_ (PS !sfp !soff !slen) =
  withForeignPtr sfp $ \src -> do
    dfp <- mallocPlainForeignPtrBytes dlen
    edlenFinal :: Either Text CSize <-
      withForeignPtr dfp $ \out ->
        -- outlen is an in/out cell: seeded with the buffer capacity and
        -- read back only on success.
        Foreign.with (intToCSize dlen) $ \outlen -> do
          decodeResult <- base64_decode
            (plusPtr (castPtr src :: Ptr CChar) soff)
            (intToCSize slen)
            out
            outlen
            base64Flags
          case decodeResult of
            1 -> Right <$> peek outlen
            0 -> pure (Left "SIMD: Invalid input")
            (-1) -> pure (Left "Invalid Codec")
            x -> pure (Left ("Unexpected result from libbase64 base64_decode: " <> T.pack (show (cIntToInt x))))
    -- Wrap the buffer without copying, trimmed to the reported length.
    pure $ fmap
      (\dlenFinal -> PS (castForeignPtr dfp) 0 (cSizeToInt dlenFinal))
      edlenFinal
  where
    -- Three output bytes per (possibly partial) group of four input
    -- characters; always >= the number of bytes the decoder may write.
    !dlen = ((slen + 3) `quot` 4) * 3
    -- NOTE(review): presumably 0 lets libbase64 choose the codec itself;
    -- confirm against the libbase64 flag documentation.
    !base64Flags = 0
146+
#endif
147+
112148
-- | The main decode function. Takes a padding flag, a decoding table, and
113149
-- the input value, producing either an error string on the left, or a
114150
-- decoded value.
@@ -123,7 +159,22 @@ decodeBase64_
123159
:: ForeignPtr Word8
124160
-> ByteString
125161
-> IO (Either Text ByteString)
126-
decodeBase64_ !dtfp (PS !sfp !soff !slen) =
162+
#ifdef SIMD
163+
-- Size-based dispatch: inputs of at least 'simdCutoff' bytes go to
-- 'decodeBase64Simd_'; shorter inputs use the table-driven
-- 'decodeBase64Loop_'.
--
-- NOTE(review): the SIMD branch does not use @dtfp@, so it decodes with
-- whatever alphabet 'decodeBase64Simd_' implements. Confirm that every
-- caller's decoding table agrees with it.
decodeBase64_ dtfp bs@(PS _ _ !len) =
  if len >= simdCutoff
    then decodeBase64Simd_ bs
    else decodeBase64Loop_ dtfp bs
  where
    -- Inputs shorter than this many bytes stay on the non-SIMD loop.
    !simdCutoff = 250
#else
-- Built without SIMD: the table-driven loop is the only implementation.
decodeBase64_ dtfp bs = decodeBase64Loop_ dtfp bs
#endif
{-# inline decodeBase64_ #-}
172+
173+
decodeBase64Loop_
174+
:: ForeignPtr Word8
175+
-> ByteString
176+
-> IO (Either Text ByteString)
177+
decodeBase64Loop_ !dtfp (PS !sfp !soff !slen) =
127178
withForeignPtr dtfp $ \dtable ->
128179
withForeignPtr sfp $ \sptr -> do
129180
dfp <- mallocPlainForeignPtrBytes dlen
@@ -134,7 +185,7 @@ decodeBase64_ !dtfp (PS !sfp !soff !slen) =
134185
dptr end dfp
135186
where
136187
!dlen = (slen `quot` 4) * 3
137-
{-# inline decodeBase64_ #-}
188+
{-# inline decodeBase64Loop_ #-}
138189

139190
decodeBase64Lenient_ :: ForeignPtr Word8 -> ByteString -> ByteString
140191
decodeBase64Lenient_ !dtfp (PS !sfp !soff !slen) = unsafeDupablePerformIO $
@@ -150,3 +201,14 @@ decodeBase64Lenient_ !dtfp (PS !sfp !soff !slen) = unsafeDupablePerformIO $
150201
dfp
151202
where
152203
!dlen = ((slen + 3) `div` 4) * 3
204+
205+
#ifdef SIMD
206+
-- | Convert an 'Int' to a 'CSize' for passing lengths to C.
-- Plain 'fromIntegral'; no range check is performed.
intToCSize :: Int -> CSize
intToCSize n = fromIntegral n

-- | Convert a 'CSize' returned from C back to an 'Int'.
-- Plain 'fromIntegral'; no range check is performed.
cSizeToInt :: CSize -> Int
cSizeToInt sz = fromIntegral sz

-- | Convert a 'CInt' (e.g. a C status code) to an 'Int'.
cIntToInt :: CInt -> Int
cIntToInt code = fromIntegral code
214+
#endif

0 commit comments

Comments
 (0)