Skip to content

Commit 470b6e3

Browse files
authored
Add a few more size-overflow-related checks (#599)
1 parent 88f16dc commit 470b6e3

File tree

4 files changed

+89
-83
lines changed

4 files changed

+89
-83
lines changed

Changelog.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
[0.12.0.0]Unreleased
1+
[0.12.0.0]July 2023
22

33
* __Breaking Changes__:
44
* [`readInt` returns `Nothing`, if the sequence of digits cannot be represented by an `Int`, instead of overflowing silently](https://github.com/haskell/bytestring/pull/309)
@@ -14,11 +14,8 @@
1414
* [`stimes @StrictByteString`](https://github.com/haskell/bytestring/pull/443)
1515
* [`Data.ByteString.Short.concat`](https://github.com/haskell/bytestring/pull/443)
1616
* [`Data.ByteString.Short.append`](https://github.com/haskell/bytestring/pull/443)
17-
<!-- TODO: Some other `ShortByteString` functions are probably still
18-
susceptible to bad behavior on `Int` overflow in edge cases;
19-
`D.B.Short.Internal.create` does not check for negative size,
20-
unlike its `StrictByteString` counterpart.
21-
-->
17+
* [`Data.ByteString.Short.snoc`](https://github.com/haskell/bytestring/pull/599)
18+
* [`Data.ByteString.Short.cons`](https://github.com/haskell/bytestring/pull/599)
2219
* API additions:
2320
* [New sized and/or unsigned variants of `readInt` and `readInteger`](https://github.com/haskell/bytestring/pull/438)
2421
* [`Data.ByteString.Internal` now provides `SizeOverflowException`, `overflowError`, and `checkedMultiply`](https://github.com/haskell/bytestring/pull/443)
@@ -34,7 +31,7 @@
3431

3532
[0.12.0.0]: https://github.com/haskell/bytestring/compare/0.11.5.0...0.12.0.0
3633

37-
[0.11.5.0]Unreleased
34+
[0.11.5.0]July 2023
3835

3936
* Bug fixes:
4037
* [Fix multiple bugs with ASCII blocks in the SIMD implementations for `isValidUtf8`](https://github.com/haskell/bytestring/pull/582)

Data/ByteString.hs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,16 +380,16 @@ infixl 5 `snoc`
380380
-- | /O(n)/ 'cons' is analogous to (:) for lists, but of different
381381
-- complexity, as it requires making a copy.
382382
cons :: Word8 -> ByteString -> ByteString
383-
cons c (BS x l) = unsafeCreateFp (l+1) $ \p -> do
383+
cons c (BS x len) = unsafeCreateFp (checkedAdd "cons" len 1) $ \p -> do
384384
pokeFp p c
385-
memcpyFp (p `plusForeignPtr` 1) x l
385+
memcpyFp (p `plusForeignPtr` 1) x len
386386
{-# INLINE cons #-}
387387

388388
-- | /O(n)/ Append a byte to the end of a 'ByteString'
389389
snoc :: ByteString -> Word8 -> ByteString
390-
snoc (BS x l) c = unsafeCreateFp (l+1) $ \p -> do
391-
memcpyFp p x l
392-
pokeFp (p `plusForeignPtr` l) c
390+
snoc (BS x len) c = unsafeCreateFp (checkedAdd "snoc" len 1) $ \p -> do
391+
memcpyFp p x len
392+
pokeFp (p `plusForeignPtr` len) c
393393
{-# INLINE snoc #-}
394394

395395
-- | /O(1)/ Extract the first element of a ByteString, which must be non-empty.
@@ -773,7 +773,7 @@ scanl
773773
-- ^ input of length n
774774
-> ByteString
775775
-- ^ output of length n+1
776-
scanl f v = \(BS a len) -> unsafeCreateFp (len+1) $ \q -> do
776+
scanl f v = \(BS a len) -> unsafeCreateFp (checkedAdd "scanl" len 1) $ \q -> do
777777
-- see fold inlining
778778
pokeFp q v
779779
let
@@ -817,7 +817,7 @@ scanr
817817
-- ^ input of length n
818818
-> ByteString
819819
-- ^ output of length n+1
820-
scanr f v = \(BS a len) -> unsafeCreateFp (len+1) $ \b -> do
820+
scanr f v = \(BS a len) -> unsafeCreateFp (checkedAdd "scanr" len 1) $ \b -> do
821821
-- see fold inlining
822822
pokeFpByteOff b len v
823823
let

Data/ByteString/Internal/Type.hs

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -653,30 +653,30 @@ unsafeCreateFpUptoN' l f = unsafeDupablePerformIO (createFpUptoN' l f)
653653

654654
-- | Create ByteString of size @l@ and use action @f@ to fill its contents.
655655
createFp :: Int -> (ForeignPtr Word8 -> IO ()) -> IO ByteString
656-
createFp l action = do
657-
fp <- mallocByteString l
656+
createFp len action = assert (len >= 0) $ do
657+
fp <- mallocByteString len
658658
action fp
659-
mkDeferredByteString fp l
659+
mkDeferredByteString fp len
660660
{-# INLINE createFp #-}
661661

662662
-- | Given a maximum size @l@ and an action @f@ that fills the 'ByteString'
663663
-- starting at the given 'Ptr' and returns the actual utilized length,
664664
-- @`createFpUptoN'` l f@ returns the filled 'ByteString'.
665665
createFpUptoN :: Int -> (ForeignPtr Word8 -> IO Int) -> IO ByteString
666-
createFpUptoN l action = do
667-
fp <- mallocByteString l
668-
l' <- action fp
669-
assert (l' <= l) $ mkDeferredByteString fp l'
666+
createFpUptoN maxLen action = assert (maxLen >= 0) $ do
667+
fp <- mallocByteString maxLen
668+
len <- action fp
669+
assert (0 <= len && len <= maxLen) $ mkDeferredByteString fp len
670670
{-# INLINE createFpUptoN #-}
671671

672672
-- | Like 'createFpUptoN', but also returns an additional value created by the
673673
-- action.
674674
createFpUptoN' :: Int -> (ForeignPtr Word8 -> IO (Int, a)) -> IO (ByteString, a)
675-
createFpUptoN' l action = do
676-
fp <- mallocByteString l
677-
(l', res) <- action fp
678-
bs <- mkDeferredByteString fp l'
679-
assert (l' <= l) $ pure (bs, res)
675+
createFpUptoN' maxLen action = assert (maxLen >= 0) $ do
676+
fp <- mallocByteString maxLen
677+
(len, res) <- action fp
678+
bs <- mkDeferredByteString fp len
679+
assert (0 <= len && len <= maxLen) $ pure (bs, res)
680680
{-# INLINE createFpUptoN' #-}
681681

682682
-- | Given the maximum size needed and a function to make the contents
@@ -688,22 +688,26 @@ createFpUptoN' l action = do
688688
-- ByteString functions, using Haskell or C functions to fill the space.
689689
--
690690
createFpAndTrim :: Int -> (ForeignPtr Word8 -> IO Int) -> IO ByteString
691-
createFpAndTrim l action = do
692-
fp <- mallocByteString l
693-
l' <- action fp
694-
if assert (0 <= l' && l' <= l) $ l' >= l
695-
then mkDeferredByteString fp l
696-
else createFp l' $ \dest -> memcpyFp dest fp l'
691+
createFpAndTrim maxLen action = assert (maxLen >= 0) $ do
692+
fp <- mallocByteString maxLen
693+
len <- action fp
694+
if assert (0 <= len && len <= maxLen) $ len >= maxLen
695+
then mkDeferredByteString fp maxLen
696+
else createFp len $ \dest -> memcpyFp dest fp len
697697
{-# INLINE createFpAndTrim #-}
698698

699699
createFpAndTrim' :: Int -> (ForeignPtr Word8 -> IO (Int, Int, a)) -> IO (ByteString, a)
700-
createFpAndTrim' l action = do
701-
fp <- mallocByteString l
702-
(off, l', res) <- action fp
703-
bs <- if assert (0 <= l' && l' <= l) $ l' >= l
704-
then mkDeferredByteString fp l -- entire buffer used => offset is zero
705-
else createFp l' $ \dest ->
706-
memcpyFp dest (fp `plusForeignPtr` off) l'
700+
createFpAndTrim' maxLen action = assert (maxLen >= 0) $ do
701+
fp <- mallocByteString maxLen
702+
(off, len, res) <- action fp
703+
assert (
704+
0 <= len && len <= maxLen && -- length OK
705+
(len == 0 || (0 <= off && off <= maxLen - len)) -- offset OK
706+
) $ pure ()
707+
bs <- if len >= maxLen
708+
then mkDeferredByteString fp maxLen -- entire buffer used => offset is zero
709+
else createFp len $ \dest ->
710+
memcpyFp dest (fp `plusForeignPtr` off) len
707711
return (bs, res)
708712
{-# INLINE createFpAndTrim' #-}
709713

@@ -971,8 +975,10 @@ overflowError fun = throw $ SizeOverflowException msg
971975
checkedAdd :: String -> Int -> Int -> Int
972976
{-# INLINE checkedAdd #-}
973977
checkedAdd fun x y
974-
| r >= 0 = r
975-
| otherwise = overflowError fun
978+
-- checking "r < 0" here matches the condition in mallocPlainForeignPtrBytes,
979+
-- helping the compiler see the latter is redundant in some places
980+
| r < 0 = overflowError fun
981+
| otherwise = r
976982
where r = assert (min x y >= 0) $ x + y
977983

978984
-- | Multiplies two non-negative numbers.

Data/ByteString/Short/Internal.hs

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ unSBS (ShortByteString (ByteArray ba#)) = ba#
400400

401401
create :: Int -> (forall s. MBA s -> ST s ()) -> ShortByteString
402402
create len fill =
403-
runST $ do
403+
assert (len >= 0) $ runST $ do
404404
mba <- newByteArray len
405405
fill mba
406406
BA# ba# <- unsafeFreezeByteArray mba
@@ -413,59 +413,60 @@ create len fill =
413413
-- (<= the maximum size) and the result value. The resulting byte array
414414
-- is realloced to this size.
415415
createAndTrim :: Int -> (forall s. MBA s -> ST s (Int, a)) -> (ShortByteString, a)
416-
createAndTrim l fill =
417-
runST $ do
418-
mba <- newByteArray l
419-
(l', res) <- fill mba
420-
if assert (l' <= l) $ l' >= l
416+
createAndTrim maxLen fill =
417+
assert (maxLen >= 0) $ runST $ do
418+
mba <- newByteArray maxLen
419+
(len, res) <- fill mba
420+
if assert (0 <= len && len <= maxLen) $ len >= maxLen
421421
then do
422422
BA# ba# <- unsafeFreezeByteArray mba
423423
return (SBS ba#, res)
424424
else do
425-
mba2 <- newByteArray l'
426-
copyMutableByteArray mba 0 mba2 0 l'
425+
mba2 <- newByteArray len
426+
copyMutableByteArray mba 0 mba2 0 len
427427
BA# ba# <- unsafeFreezeByteArray mba2
428428
return (SBS ba#, res)
429429
{-# INLINE createAndTrim #-}
430430

431431
createAndTrim' :: Int -> (forall s. MBA s -> ST s Int) -> ShortByteString
432-
createAndTrim' l fill =
433-
runST $ do
434-
mba <- newByteArray l
435-
l' <- fill mba
436-
if assert (l' <= l) $ l' >= l
432+
createAndTrim' maxLen fill =
433+
assert (maxLen >= 0) $ runST $ do
434+
mba <- newByteArray maxLen
435+
len <- fill mba
436+
if assert (0 <= len && len <= maxLen) $ len >= maxLen
437437
then do
438438
BA# ba# <- unsafeFreezeByteArray mba
439439
return (SBS ba#)
440440
else do
441-
mba2 <- newByteArray l'
442-
copyMutableByteArray mba 0 mba2 0 l'
441+
mba2 <- newByteArray len
442+
copyMutableByteArray mba 0 mba2 0 len
443443
BA# ba# <- unsafeFreezeByteArray mba2
444444
return (SBS ba#)
445445
{-# INLINE createAndTrim' #-}
446446

447-
createAndTrim'' :: Int -> (forall s. MBA s -> MBA s -> ST s (Int, Int)) -> (ShortByteString, ShortByteString)
448-
createAndTrim'' l fill =
447+
-- | Like createAndTrim, but with two buffers at once
448+
createAndTrim2 :: Int -> Int -> (forall s. MBA s -> MBA s -> ST s (Int, Int)) -> (ShortByteString, ShortByteString)
449+
createAndTrim2 maxLen1 maxLen2 fill =
449450
runST $ do
450-
mba1 <- newByteArray l
451-
mba2 <- newByteArray l
452-
(l1, l2) <- fill mba1 mba2
453-
sbs1 <- freeze' l1 mba1
454-
sbs2 <- freeze' l2 mba2
451+
mba1 <- newByteArray maxLen1
452+
mba2 <- newByteArray maxLen2
453+
(len1, len2) <- fill mba1 mba2
454+
sbs1 <- freeze' len1 maxLen1 mba1
455+
sbs2 <- freeze' len2 maxLen2 mba2
455456
pure (sbs1, sbs2)
456457
where
457-
freeze' :: Int -> MBA s -> ST s ShortByteString
458-
freeze' l' mba =
459-
if assert (l' <= l) $ l' >= l
458+
freeze' :: Int -> Int -> MBA s -> ST s ShortByteString
459+
freeze' len maxLen mba =
460+
if assert (0 <= len && len <= maxLen) $ len >= maxLen
460461
then do
461462
BA# ba# <- unsafeFreezeByteArray mba
462463
return (SBS ba#)
463464
else do
464-
mba2 <- newByteArray l'
465-
copyMutableByteArray mba 0 mba2 0 l'
465+
mba2 <- newByteArray len
466+
copyMutableByteArray mba 0 mba2 0 len
466467
BA# ba# <- unsafeFreezeByteArray mba2
467468
return (SBS ba#)
468-
{-# INLINE createAndTrim'' #-}
469+
{-# INLINE createAndTrim2 #-}
469470

470471
isPinned :: ByteArray# -> Bool
471472
#if MIN_VERSION_base(4,10,0)
@@ -676,23 +677,23 @@ infixl 5 `snoc`
676677
--
677678
-- @since 0.11.3.0
678679
snoc :: ShortByteString -> Word8 -> ShortByteString
679-
snoc = \sbs c -> let l = length sbs
680-
nl = l + 1
681-
in create nl $ \mba -> do
682-
copyByteArray (asBA sbs) 0 mba 0 l
683-
writeWord8Array mba l c
680+
snoc = \sbs c -> let len = length sbs
681+
newLen = checkedAdd "Short.snoc" len 1
682+
in create newLen $ \mba -> do
683+
copyByteArray (asBA sbs) 0 mba 0 len
684+
writeWord8Array mba len c
684685

685686
-- | /O(n)/ 'cons' is analogous to (:) for lists.
686687
--
687688
-- Note: copies the entire byte array
688689
--
689690
-- @since 0.11.3.0
690691
cons :: Word8 -> ShortByteString -> ShortByteString
691-
cons c = \sbs -> let l = length sbs
692-
nl = l + 1
693-
in create nl $ \mba -> do
692+
cons c = \sbs -> let len = length sbs
693+
newLen = checkedAdd "Short.cons" len 1
694+
in create newLen $ \mba -> do
694695
writeWord8Array mba 0 c
695-
copyByteArray (asBA sbs) 0 mba 1 l
696+
copyByteArray (asBA sbs) 0 mba 1 len
696697

697698
-- | /O(1)/ Extract the last element of a ShortByteString, which must be finite and non-empty.
698699
-- An exception will be thrown in the case of an empty ShortByteString.
@@ -1484,9 +1485,9 @@ find f = \sbs -> case findIndex f sbs of
14841485
--
14851486
-- @since 0.11.3.0
14861487
partition :: (Word8 -> Bool) -> ShortByteString -> (ShortByteString, ShortByteString)
1487-
partition k = \sbs -> let l = length sbs
1488-
in if | l <= 0 -> (sbs, sbs)
1489-
| otherwise -> createAndTrim'' l $ \mba1 mba2 -> go mba1 mba2 (asBA sbs) l
1488+
partition k = \sbs -> let len = length sbs
1489+
in if | len <= 0 -> (sbs, sbs)
1490+
| otherwise -> createAndTrim2 len len $ \mba1 mba2 -> go mba1 mba2 (asBA sbs) len
14901491
where
14911492
go :: forall s.
14921493
MBA s -- mutable output bytestring1
@@ -1614,12 +1615,14 @@ indexWord8ArrayAsWord64 (BA# ba#) (I# i#) = W64# (indexWord8ArrayAsWord64# ba# i
16141615
#endif
16151616

16161617
newByteArray :: Int -> ST s (MBA s)
1617-
newByteArray (I# len#) =
1618+
newByteArray len@(I# len#) =
1619+
assert (len >= 0) $
16181620
ST $ \s -> case newByteArray# len# s of
16191621
(# s', mba# #) -> (# s', MBA# mba# #)
16201622

16211623
newPinnedByteArray :: Int -> ST s (MBA s)
1622-
newPinnedByteArray (I# len#) =
1624+
newPinnedByteArray len@(I# len#) =
1625+
assert (len >= 0) $
16231626
ST $ \s -> case newPinnedByteArray# len# s of
16241627
(# s', mba# #) -> (# s', MBA# mba# #)
16251628

0 commit comments

Comments
 (0)