Skip to content

Commit 1eb5637

Browse files
committed
Merge pull request #84 from dolio/master
Avoid giving out uninitialized memory from safe vector creation functions.
2 parents e1c01ca + 7134f75 commit 1eb5637

File tree

7 files changed

+93
-4
lines changed

7 files changed

+93
-4
lines changed

Data/Vector/Generic/Mutable.hs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ overlaps = basicOverlaps
582582
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
583583
{-# INLINE new #-}
584584
new n = BOUNDS_CHECK(checkLength) "new" n
585-
$ unsafeNew n
585+
$ unsafeNew n >>= \v -> basicInitialize v >> return v
586586

587587
-- | Create a mutable vector of the given length. The length is not checked.
588588
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
@@ -619,13 +619,17 @@ grow :: (PrimMonad m, MVector v a)
619619
=> v (PrimState m) a -> Int -> m (v (PrimState m) a)
620620
{-# INLINE grow #-}
621621
grow v by = BOUNDS_CHECK(checkLength) "grow" by
622-
$ unsafeGrow v by
622+
$ do vnew <- unsafeGrow v by
623+
basicInitialize $ basicUnsafeSlice (length v) by vnew
624+
return vnew
623625

624626
growFront :: (PrimMonad m, MVector v a)
625627
=> v (PrimState m) a -> Int -> m (v (PrimState m) a)
626628
{-# INLINE growFront #-}
627629
growFront v by = BOUNDS_CHECK(checkLength) "growFront" by
628-
$ unsafeGrowFront v by
630+
$ do vnew <- unsafeGrowFront v by
631+
basicInitialize $ basicUnsafeSlice 0 by vnew
632+
return vnew
629633

630634
enlarge_delta :: MVector v a => v s a -> Int
631635
enlarge_delta v = max (length v) 1
@@ -634,13 +638,18 @@ enlarge_delta v = max (length v) 1
634638
enlarge :: (PrimMonad m, MVector v a)
635639
=> v (PrimState m) a -> m (v (PrimState m) a)
636640
{-# INLINE enlarge #-}
637-
enlarge v = unsafeGrow v (enlarge_delta v)
641+
enlarge v = do vnew <- unsafeGrow v by
642+
basicInitialize $ basicUnsafeSlice (length v) by vnew
643+
return vnew
644+
where
645+
by = enlarge_delta v
638646

639647
enlargeFront :: (PrimMonad m, MVector v a)
640648
=> v (PrimState m) a -> m (v (PrimState m) a, Int)
641649
{-# INLINE enlargeFront #-}
642650
enlargeFront v = do
643651
v' <- unsafeGrowFront v by
652+
basicInitialize $ basicUnsafeSlice 0 by v'
644653
return (v', by)
645654
where
646655
by = enlarge_delta v

Data/Vector/Generic/Mutable/Base.hs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ class MVector v a where
4343
-- called directly, use 'unsafeNew' instead.
4444
basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)
4545

46+
-- | Initialize a vector to a standard value. This is intended to be called as
47+
-- part of the safe new operation (and similar operations), to properly blank
48+
-- the newly allocated memory if necessary.
49+
--
50+
-- Vectors that are necessarily initialized as part of creation may implement
51+
-- this as a no-op.
52+
basicInitialize :: PrimMonad m => v (PrimState m) a -> m ()
53+
4654
-- | Create a mutable vector of the given length and fill it with an
4755
-- initial value. This method should not be called directly, use
4856
-- 'replicate' instead.

Data/Vector/Mutable.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ instance G.MVector MVector a where
100100
arr <- newArray n uninitialised
101101
return (MVector 0 n arr)
102102

103+
{-# INLINE basicInitialize #-}
104+
-- initialization is unnecessary for boxed vectors
105+
basicInitialize _ = return ()
106+
103107
{-# INLINE basicUnsafeReplicate #-}
104108
basicUnsafeReplicate n x
105109
= do

Data/Vector/Primitive/Mutable.hs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ module Data.Vector.Primitive.Mutable (
5252
import qualified Data.Vector.Generic.Mutable as G
5353
import Data.Primitive.ByteArray
5454
import Data.Primitive ( Prim, sizeOf )
55+
import Data.Word ( Word8 )
5556
import Control.Monad.Primitive
5657
import Control.Monad ( liftM )
5758

@@ -99,6 +100,13 @@ instance Prim a => G.MVector MVector a where
99100
size = sizeOf (undefined :: a)
100101
mx = maxBound `div` size :: Int
101102

103+
{-# INLINE basicInitialize #-}
104+
basicInitialize (MVector off n v) =
105+
setByteArray v (off * size) (n * size) (0 :: Word8)
106+
where
107+
size = sizeOf (undefined :: a)
108+
109+
102110
{-# INLINE basicUnsafeRead #-}
103111
basicUnsafeRead (MVector i _ arr) j = readByteArray arr (i+j)
104112

Data/Vector/Storable/Mutable.hs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ instance Storable a => G.MVector MVector a where
122122
fp <- mallocVector n
123123
return $ MVector n fp
124124

125+
{-# INLINE basicInitialize #-}
126+
basicInitialize = storableZero
127+
125128
{-# INLINE basicUnsafeRead #-}
126129
basicUnsafeRead (MVector _ fp) i
127130
= unsafePrimToPrim
@@ -149,6 +152,18 @@ instance Storable a => G.MVector MVector a where
149152
withForeignPtr fq $ \q ->
150153
moveArray p q n
151154

155+
storableZero :: forall a m. (Storable a, PrimMonad m) => MVector (PrimState m) a -> m ()
156+
{-# INLINE storableZero #-}
157+
storableZero (MVector n fp) = unsafePrimToPrim . withForeignPtr fp $ \(Ptr p) -> do
158+
let q = Addr p
159+
setAddr q byteSize (0 :: Word8)
160+
where
161+
x :: a
162+
x = undefined
163+
164+
byteSize :: Int
165+
byteSize = n * sizeOf x
166+
152167
storableSet :: (Storable a, PrimMonad m) => MVector (PrimState m) a -> a -> m ()
153168
{-# INLINE storableSet #-}
154169
storableSet (MVector n fp) x

Data/Vector/Unboxed/Base.hs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ instance M.MVector MVector () where
106106
{-# INLINE basicUnsafeSlice #-}
107107
{-# INLINE basicOverlaps #-}
108108
{-# INLINE basicUnsafeNew #-}
109+
{-# INLINE basicInitialize #-}
109110
{-# INLINE basicUnsafeRead #-}
110111
{-# INLINE basicUnsafeWrite #-}
111112
{-# INLINE basicClear #-}
@@ -121,6 +122,9 @@ instance M.MVector MVector () where
121122

122123
basicUnsafeNew n = return (MV_Unit n)
123124

125+
-- Nothing to initialize
126+
basicInitialize _ = return ()
127+
124128
basicUnsafeRead (MV_Unit _) _ = return ()
125129

126130
basicUnsafeWrite (MV_Unit _) _ () = return ()
@@ -166,6 +170,7 @@ instance M.MVector MVector ty where { \
166170
; {-# INLINE basicUnsafeSlice #-} \
167171
; {-# INLINE basicOverlaps #-} \
168172
; {-# INLINE basicUnsafeNew #-} \
173+
; {-# INLINE basicInitialize #-} \
169174
; {-# INLINE basicUnsafeReplicate #-} \
170175
; {-# INLINE basicUnsafeRead #-} \
171176
; {-# INLINE basicUnsafeWrite #-} \
@@ -177,6 +182,7 @@ instance M.MVector MVector ty where { \
177182
; basicUnsafeSlice i n (con v) = con $ M.basicUnsafeSlice i n v \
178183
; basicOverlaps (con v1) (con v2) = M.basicOverlaps v1 v2 \
179184
; basicUnsafeNew n = con `liftM` M.basicUnsafeNew n \
185+
; basicInitialize (con v) = M.basicInitialize v \
180186
; basicUnsafeReplicate n x = con `liftM` M.basicUnsafeReplicate n x \
181187
; basicUnsafeRead (con v) i = M.basicUnsafeRead v i \
182188
; basicUnsafeWrite (con v) i x = M.basicUnsafeWrite v i x \
@@ -307,6 +313,7 @@ instance M.MVector MVector Bool where
307313
{-# INLINE basicUnsafeSlice #-}
308314
{-# INLINE basicOverlaps #-}
309315
{-# INLINE basicUnsafeNew #-}
316+
{-# INLINE basicInitialize #-}
310317
{-# INLINE basicUnsafeReplicate #-}
311318
{-# INLINE basicUnsafeRead #-}
312319
{-# INLINE basicUnsafeWrite #-}
@@ -318,6 +325,7 @@ instance M.MVector MVector Bool where
318325
basicUnsafeSlice i n (MV_Bool v) = MV_Bool $ M.basicUnsafeSlice i n v
319326
basicOverlaps (MV_Bool v1) (MV_Bool v2) = M.basicOverlaps v1 v2
320327
basicUnsafeNew n = MV_Bool `liftM` M.basicUnsafeNew n
328+
basicInitialize (MV_Bool v) = M.basicInitialize v
321329
basicUnsafeReplicate n x = MV_Bool `liftM` M.basicUnsafeReplicate n (fromBool x)
322330
basicUnsafeRead (MV_Bool v) i = toBool `liftM` M.basicUnsafeRead v i
323331
basicUnsafeWrite (MV_Bool v) i x = M.basicUnsafeWrite v i (fromBool x)
@@ -356,6 +364,7 @@ instance (RealFloat a, Unbox a) => M.MVector MVector (Complex a) where
356364
{-# INLINE basicUnsafeSlice #-}
357365
{-# INLINE basicOverlaps #-}
358366
{-# INLINE basicUnsafeNew #-}
367+
{-# INLINE basicInitialize #-}
359368
{-# INLINE basicUnsafeReplicate #-}
360369
{-# INLINE basicUnsafeRead #-}
361370
{-# INLINE basicUnsafeWrite #-}
@@ -367,6 +376,7 @@ instance (RealFloat a, Unbox a) => M.MVector MVector (Complex a) where
367376
basicUnsafeSlice i n (MV_Complex v) = MV_Complex $ M.basicUnsafeSlice i n v
368377
basicOverlaps (MV_Complex v1) (MV_Complex v2) = M.basicOverlaps v1 v2
369378
basicUnsafeNew n = MV_Complex `liftM` M.basicUnsafeNew n
379+
basicInitialize (MV_Complex v) = M.basicInitialize v
370380
basicUnsafeReplicate n (x :+ y) = MV_Complex `liftM` M.basicUnsafeReplicate n (x,y)
371381
basicUnsafeRead (MV_Complex v) i = uncurry (:+) `liftM` M.basicUnsafeRead v i
372382
basicUnsafeWrite (MV_Complex v) i (x :+ y) = M.basicUnsafeWrite v i (x,y)

internal/unbox-tuple-instances

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ instance (Unbox a, Unbox b) => M.MVector MVector (a, b) where
2323
as <- M.basicUnsafeNew n_
2424
bs <- M.basicUnsafeNew n_
2525
return $ MV_2 n_ as bs
26+
{-# INLINE basicInitialize #-}
27+
basicInitialize (MV_2 _ as bs)
28+
= do
29+
M.basicInitialize as
30+
M.basicInitialize bs
2631
{-# INLINE basicUnsafeReplicate #-}
2732
basicUnsafeReplicate n_ (a, b)
2833
= do
@@ -162,6 +167,12 @@ instance (Unbox a,
162167
bs <- M.basicUnsafeNew n_
163168
cs <- M.basicUnsafeNew n_
164169
return $ MV_3 n_ as bs cs
170+
{-# INLINE basicInitialize #-}
171+
basicInitialize (MV_3 _ as bs cs)
172+
= do
173+
M.basicInitialize as
174+
M.basicInitialize bs
175+
M.basicInitialize cs
165176
{-# INLINE basicUnsafeReplicate #-}
166177
basicUnsafeReplicate n_ (a, b, c)
167178
= do
@@ -337,6 +348,13 @@ instance (Unbox a,
337348
cs <- M.basicUnsafeNew n_
338349
ds <- M.basicUnsafeNew n_
339350
return $ MV_4 n_ as bs cs ds
351+
{-# INLINE basicInitialize #-}
352+
basicInitialize (MV_4 _ as bs cs ds)
353+
= do
354+
M.basicInitialize as
355+
M.basicInitialize bs
356+
M.basicInitialize cs
357+
M.basicInitialize ds
340358
{-# INLINE basicUnsafeReplicate #-}
341359
basicUnsafeReplicate n_ (a, b, c, d)
342360
= do
@@ -567,6 +585,14 @@ instance (Unbox a,
567585
ds <- M.basicUnsafeNew n_
568586
es <- M.basicUnsafeNew n_
569587
return $ MV_5 n_ as bs cs ds es
588+
{-# INLINE basicInitialize #-}
589+
basicInitialize (MV_5 _ as bs cs ds es)
590+
= do
591+
M.basicInitialize as
592+
M.basicInitialize bs
593+
M.basicInitialize cs
594+
M.basicInitialize ds
595+
M.basicInitialize es
570596
{-# INLINE basicUnsafeReplicate #-}
571597
basicUnsafeReplicate n_ (a, b, c, d, e)
572598
= do
@@ -846,6 +872,15 @@ instance (Unbox a,
846872
es <- M.basicUnsafeNew n_
847873
fs <- M.basicUnsafeNew n_
848874
return $ MV_6 n_ as bs cs ds es fs
875+
{-# INLINE basicInitialize #-}
876+
basicInitialize (MV_6 _ as bs cs ds es fs)
877+
= do
878+
M.basicInitialize as
879+
M.basicInitialize bs
880+
M.basicInitialize cs
881+
M.basicInitialize ds
882+
M.basicInitialize es
883+
M.basicInitialize fs
849884
{-# INLINE basicUnsafeReplicate #-}
850885
basicUnsafeReplicate n_ (a, b, c, d, e, f)
851886
= do

0 commit comments

Comments
 (0)