diff --git a/llvm-pretty-bc-parser.cabal b/llvm-pretty-bc-parser.cabal index ffead9ca..e2fe4a39 100644 --- a/llvm-pretty-bc-parser.cabal +++ b/llvm-pretty-bc-parser.cabal @@ -24,6 +24,10 @@ Flag regressions Description: Enable regression testing build Default: False +Flag quick + Description: Disable traces and validation of LLVM input; minimal parsing checking + Default: False + Source-repository head type: git location: http://github.com/galoisinc/llvm-pretty-bc-parser @@ -60,6 +64,9 @@ Library -O2 -funbox-strict-fields + if flag(quick) + CPP-Options: -DQUICK + Build-depends: array >= 0.3, base >= 4.8 && < 5, binary >= 0.8, diff --git a/src/Data/LLVM/BitCode/BitString.hs b/src/Data/LLVM/BitCode/BitString.hs index 430762f1..77f14604 100644 --- a/src/Data/LLVM/BitCode/BitString.hs +++ b/src/Data/LLVM/BitCode/BitString.hs @@ -1,6 +1,9 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} module Data.LLVM.BitCode.BitString ( @@ -15,17 +18,24 @@ module Data.LLVM.BitCode.BitString , NumBits, NumBytes, pattern Bits', pattern Bytes' , bitCount, bitCount# , bitsToBytes, bytesToBits + , bitMask , addBitCounts , subtractBitCounts ) where +#ifdef QUICK +import Data.Bits ( Bits ) +import Numeric ( showIntAtBase ) +#else import Data.Bits ( bit, bitSizeMaybe, Bits ) -import GHC.Exts import Numeric ( showIntAtBase, showHex ) +#endif +import GHC.Exts import Prelude hiding (take,drop,splitAt) + ---------------------------------------------------------------------- -- Define some convenience newtypes to clarify whether the count of bits or count -- of bytes is being referenced, and to convert between the two. 
@@ -69,26 +79,27 @@ bytesToBits (NumBytes (I# n#)) = NumBits (I# (n# `uncheckedIShiftL#` 3#)) data BitString = BitString { bsLength :: !NumBits - , bsData :: !Int + , bsData :: !Word -- Note: the bsData was originally an Integer, which allows an essentially -- unlimited size value. However, this adds some overhead to various -- computations, and since LLVM Bitcode is unlikely to ever represent values -- greater than the native size (64 bits) as discrete values. By changing - -- this to @Int@, the use of unboxed calculations is enabled for better - -- performance. - -- - -- The use of Int is potentially unsound because GHC only guarantees it's a - -- signed integer of at least 32-bits. However current implementations in - -- all environments where it's reasonable to use this parser have a 64-bit - -- Int implementation. This can be verified via: + -- this to @Word@ (which is verified to be 64 bits), the use of unboxed + -- calculations is enabled for better performance. -- - -- > import Data.Bits - -- > bitSizeMaybe (maxBound :: Int) >= Just 64 - -- - -- There's no good location here to automate this check (perhaps - -- GetBits.hs:runGetBits?), which is why it isn't verified at runtime. + -- Note that Word is used instead of Word64; in GHC pre 9.x, Word64 was + -- intended to represent a 64-bit value on a 32-bit system. } deriving (Show, Eq) +-- Verify a Word is 64-bits (at compile time) +$(return $ if isTrue# ((int2Word# 3#) `eqWord#` + (((int2Word# 0xF0#) `uncheckedShiftL#` 58#) + `uncheckedShiftRL#` 62#)) + then [] + else error "Word type must be 64-bits!" + ) + + -- | Create an empty BitString emptyBitString :: BitString @@ -100,24 +111,26 @@ emptyBitString = BitString (NumBits 0) 0 -- BitString. 
joinBitString :: BitString -> BitString -> BitString -joinBitString (BitString (Bits' (I# szA#)) (I# a#)) - (BitString (Bits' (I# szB#)) (I# b#)) = +joinBitString (BitString (Bits' (I# szA#)) (W# a#)) + (BitString (Bits' (I# szB#)) (W# b#)) = BitString { bsLength = NumBits (I# (szA# +# szB#)) - , bsData = I# (a# `orI#` (b# `uncheckedIShiftL#` szA#)) + , bsData = W# (a# `or#` (b# `uncheckedShiftL#` szA#)) } +bitMask :: NumBits -> Word# +bitMask (Bits' (I# len#)) = + ((int2Word# 1#) `uncheckedShiftL#` len#) `minusWord#` (int2Word# 1#) + -- | Given a number of bits to take, and an @Integer@, create a @BitString@. -toBitString :: NumBits -> Int -> BitString -toBitString len@(Bits' (I# len#)) (I# val#) = - let !mask# = (1# `uncheckedIShiftL#` len#) -# 1# - in BitString len (I# (val# `andI#` mask#)) +toBitString :: NumBits -> Word -> BitString +toBitString len (W# val#) = BitString len (W# (val# `and#` (bitMask len))) -- | Extract the referenced Integer value from a BitString -bitStringValue :: BitString -> Int +bitStringValue :: BitString -> Word bitStringValue = bsData @@ -125,6 +138,9 @@ bitStringValue = bsData -- fromInteger to perform the target type conversion). fromBitString :: (Num a, Bits a) => BitString -> a +#ifdef QUICK +fromBitString (BitString _ i) = x +#else fromBitString (BitString l i) = case bitSizeMaybe x of Nothing -> x @@ -137,6 +153,7 @@ fromBitString (BitString l i) = , "(mask=0x" <> showHex i ")" , "could not be parsed into type with only", show n, "bits" ]) +#endif where x = fromInteger ival -- use Num to convert the Integer to the target type ival = toInteger i -- convert input to an Integer for ^^ @@ -164,10 +181,10 @@ take n bs@(BitString l i) -- return the remaining as a smaller BitString. 
drop :: NumBits -> BitString -> BitString -drop !n !(BitString l i) +drop !n !(BitString l v) | n >= l = emptyBitString | otherwise = let !(I# n#) = bitCount n !(I# l#) = bitCount l - !(I# i#) = i - in BitString (NumBits (I# (l# -# n#))) (I# (i# `uncheckedIShiftRL#` n#)) + !(W# v#) = v + in BitString (NumBits (I# (l# -# n#))) (W# (v# `uncheckedShiftRL#` n#)) diff --git a/src/Data/LLVM/BitCode/Bitstream.hs b/src/Data/LLVM/BitCode/Bitstream.hs index c90ff218..6054de26 100644 --- a/src/Data/LLVM/BitCode/Bitstream.hs +++ b/src/Data/LLVM/BitCode/Bitstream.hs @@ -1,4 +1,7 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MagicHash #-} {-# LANGUAGE PatternGuards #-} module Data.LLVM.BitCode.Bitstream ( @@ -16,27 +19,34 @@ module Data.LLVM.BitCode.Bitstream ( , parseMetadataStringLengths ) where -import Data.LLVM.BitCode.BitString as BS -import Data.LLVM.BitCode.GetBits - import Control.Monad ( unless, replicateM, guard ) import Data.Bits ( Bits ) import qualified Data.ByteString as S import qualified Data.ByteString.Lazy as L +import Data.LLVM.BitCode.BitString as BS +import Data.LLVM.BitCode.GetBits import qualified Data.Map as Map -import Data.Word ( Word8, Word16, Word32 ) +#ifdef QUICK +import GHC.Exts +#endif +import GHC.Word -- Primitive Reads ------------------------------------------------------------- -- | Parse a @Bool@ out of a single bit. boolean :: GetBits Bool -boolean = ((1 :: Word8) ==) . fromBitString <$> fixed (Bits' 1) +boolean = do i <- fixedWord (Bits' 1) + return $ 1 == i -- | Parse a Num type out of n-bits. numeric :: (Num a, Bits a) => NumBits -> GetBits a +#ifdef QUICK +numeric n = fromInteger . toInteger <$> fixedWord n +#else numeric n = fromBitString <$> fixed n +#endif -- | Get a @BitString@ formatted as vbr. 
@@ -52,13 +62,52 @@ vbr n = loop emptyBitString then loop acc' else return acc' +#ifdef QUICK +vbrWord :: NumBits -> GetBits Word +vbrWord n@(Bits' (I# n#)) = + let !contBitMask# = (int2Word# 1#) `uncheckedShiftL#` (n# -# 1#) + loop = do ic <- fixedWord n + let !(W# ic#) = ic + if isTrue# ((ic# `and#` contBitMask#) `eqWord#` (int2Word# 0#)) + then return ic + else do nxt <- loop + let !(W# nxt#) = nxt + let nxtshft# = nxt# `uncheckedShiftL#` (n# -# 1#) + return (W# ((ic# `xor#` contBitMask#)`or#` nxtshft#)) + in loop +#endif + + -- | Process a variable-bit encoded integer. vbrNum :: (Num a, Bits a) => NumBits -> GetBits a +#ifdef QUICK +vbrNum = fmap (fromInteger . toInteger) . vbrWord +#else vbrNum n = fromBitString <$> vbr n +#endif -- | Decode a 6-bit encoded character. char6 :: GetBits Word8 char6 = do +#ifdef QUICK + (W# w#) <- fixedWord (Bits' 6) + let !i# = word2Int# w# +#if MIN_VERSION_base(4,16,0) + let wordToWord8 = wordToWord8# +#else + let wordToWord8 :: Word# -> Word# + wordToWord8 !a# = a# +#endif + if isTrue# (i# <=# 25#) + then return (W8# (wordToWord8 (w# `plusWord#` (int2Word# 97#)))) + else if isTrue# (i# <=# 51#) + then return (W8# (wordToWord8 (w# `plusWord#` (int2Word# 39#)))) + else if isTrue# (i# <=# 61#) + then return (W8# (wordToWord8 (w# `minusWord#` (int2Word# 4#)))) + else if isTrue# (i# ==# 62#) + then return (fromIntegral (fromEnum '.')) + else return (fromIntegral (fromEnum '_')) +#else word <- numeric $ Bits' 6 case word of n | 0 <= n && n <= 25 -> return (n + 97) @@ -67,6 +116,7 @@ char6 = do 62 -> return (fromIntegral (fromEnum '.')) 63 -> return (fromIntegral (fromEnum '_')) _ -> fail "invalid char6" +#endif -- Bitstream Parsing ----------------------------------------------------------- @@ -86,10 +136,9 @@ parseBitCodeBitstreamLazy :: L.ByteString -> Either String Bitstream parseBitCodeBitstreamLazy = runGetBits getBitCodeBitstream . L.toStrict -- | The magic constant at the beginning of all llvm-bitcode files. 
-bcMagicConst :: BitString -bcMagicConst = toBitString (Bits' 8) 0x42 - `joinBitString` - toBitString (Bits' 8) 0x43 + +bcMagicConst :: Word +bcMagicConst = 0x4342 -- | Parse a @Bitstream@ from either a normal bitcode file, or a wrapped -- bitcode. @@ -108,21 +157,19 @@ getBitCodeBitstream = label "llvm-bitstream" $ do skip $ Bits' 32 -- CPUType isolate size getBitstream -bcWrapperMagicConst :: BitString -bcWrapperMagicConst = - foldr1 joinBitString [ byte 0xDE, byte 0xC0, byte 0x17, byte 0x0B] - where - byte = toBitString (Bits' 8) +-- | Wrapper magic: the bytes 0xDE 0xC0 0x17 0x0B read as one 32-bit value, low byte first. +bcWrapperMagicConst :: Word +bcWrapperMagicConst = 0x0b17c0de guardWrapperMagic :: GetBits () guardWrapperMagic = do - magic <- fixed (Bits' 32) + magic <- fixedWord (Bits' 32) guard (magic == bcWrapperMagicConst) -- | Parse a @Bitstream@. getBitstream :: GetBits Bitstream getBitstream = label "bitstream" $ do - bc <- fixed $ Bits' 16 + bc <- fixedWord $ Bits' 16 unless (bc == bcMagicConst) (fail "Invalid magic number") appMagic <- numeric $ Bits' 16 entries <- getTopLevelEntries diff --git a/src/Data/LLVM/BitCode/GetBits.hs b/src/Data/LLVM/BitCode/GetBits.hs index 1440646a..f831e5f6 100644 --- a/src/Data/LLVM/BitCode/GetBits.hs +++ b/src/Data/LLVM/BitCode/GetBits.hs @@ -6,7 +6,7 @@ module Data.LLVM.BitCode.GetBits ( GetBits , runGetBits - , fixed, align32bits + , fixed, fixedWord, align32bits , bytestring , label , isolate @@ -159,144 +159,151 @@ extractFromByteString :: Int# {-^ the last bit accessible in the ByteString -} -> Int# {-^ the bit to start extraction at -} -> Int# {-^ the number of bits to extract -} -> ByteString {-^ the ByteString to extract from -} - -> Either String (() -> (# Int#, Int# #)) + -> Either String (() -> (# Word#, Int# #)) extractFromByteString !bitLim# !sBit# !nbits# bs = - if isTrue# ((1# `uncheckedIShiftL#` (nbits#)) /=# 0#) +#ifndef QUICK + if isTrue# ((1# `uncheckedIShiftL#` (nbits#)) ==# 0#) -- (nbits# -# 1#) above would allow 64-bit value extraction, but this -- function cannot actually support a 
size of 64, because Int# is signed, -- so it doesn't properly use the high bit in numeric operations. This -- seems to be OK at this point because LLVM bitcode does not attempt to -- encode actual 64-bit values. then + -- BitString stores an Int, but number of extracted bits is larger than + -- an Int can represent. + Left "Attempt to extracted large value" + else +#endif let !updPos# = sBit# +# nbits# in if isTrue# (updPos# <=# bitLim#) then let !s8# = sBit# `uncheckedIShiftRL#` 3# !hop# = sBit# `andI#` 7# !r8# = ((hop# +# nbits# +# 7#) `uncheckedIShiftRL#` 3#) - !mask# = (1# `uncheckedIShiftL#` nbits#) -# 1# + !mask# = bitMask (Bits' (I# nbits#)) -- Here, s8# is the size in 8-bit bytes, hop# is the number of -- bits shifted from the byte boundary, r8# is the rounded number -- of bytes actually needed to retrieve to get the value to -- account for shifting, and mask# is the mask for the final -- target set of bits after shifting. #if MIN_VERSION_base(4,16,0) - word8ToInt !w8# = word2Int# (word8ToWord# w8#) + word8ToWord :: Word8# -> Word# + word8ToWord = word8ToWord# #else - -- technically #if !MIN_VERSION_ghc_prim(0,8,0), for GHC 9.2, but - -- since ghc_prim isn't a direct dependency and is re-exported - -- from base, this define needs to reference the base version. - word8ToInt = word2Int# + word8ToWord :: Word# -> Word# + word8ToWord !w# = w# #endif -- getB# gets a value from a byte starting at bit0 of the byte - getB# :: Int# -> Int# + getB# :: Int# -> Word# getB# !i# = case i# of 0# -> let !(W8# w#) = bs `BS.index` (I# s8#) - in word8ToInt w# + in word8ToWord w# _ -> let !(W8# w#) = (bs `BS.index` (I# (s8# +# i#))) - in (word8ToInt w#) `uncheckedIShiftL#` (8# *# i#) + in (word8ToWord w#) `uncheckedShiftL#` (8# *# i#) -- getSB# gets a value from a byte shifting from a non-zero start -- bit within the byte. 
- getSB# :: Int# -> Int# + getSB# :: Int# -> Word# getSB# !i# = case i# of 0# -> let !(W8# w#) = bs `BS.index` (I# s8#) - in (word8ToInt w#) `uncheckedIShiftRL#` hop# + in (word8ToWord w#) `uncheckedShiftRL#` hop# _ -> let !(W8# w#) = bs `BS.index` (I# (s8# +# i#)) !shft# = (8# *# i#) -# hop# - in (word8ToInt w#) `uncheckedIShiftL#` shft# - !vi# = mask# `andI#` + in (word8ToWord w#) `uncheckedShiftL#` shft# + !vi# = mask# `and#` (case hop# of 0# -> case r8# of 1# -> getB# 0# - 2# -> getB# 0# `orI#` getB# 1# - 3# -> getB# 0# `orI#` getB# 1# `orI#` + 2# -> getB# 0# `or#` getB# 1# + 3# -> getB# 0# `or#` getB# 1# `or#` getB# 2# - 4# -> getB# 0# `orI#` getB# 1# `orI#` - getB# 2# `orI#` getB# 3# - 5# -> getB# 0# `orI#` getB# 1# `orI#` - getB# 2# `orI#` getB# 3# `orI#` + 4# -> getB# 0# `or#` getB# 1# `or#` + getB# 2# `or#` getB# 3# + 5# -> getB# 0# `or#` getB# 1# `or#` + getB# 2# `or#` getB# 3# `or#` getB# 4# - 6# -> getB# 0# `orI#` getB# 1# `orI#` - getB# 2# `orI#` getB# 3# `orI#` - getB# 4# `orI#` getB# 5# - 7# -> getB# 0# `orI#` getB# 1# `orI#` - getB# 2# `orI#` getB# 3# `orI#` - getB# 4# `orI#` getB# 5# `orI#` + 6# -> getB# 0# `or#` getB# 1# `or#` + getB# 2# `or#` getB# 3# `or#` + getB# 4# `or#` getB# 5# + 7# -> getB# 0# `or#` getB# 1# `or#` + getB# 2# `or#` getB# 3# `or#` + getB# 4# `or#` getB# 5# `or#` getB# 6# - 8# -> getB# 0# `orI#` getB# 1# `orI#` - getB# 2# `orI#` getB# 3# `orI#` - getB# 4# `orI#` getB# 5# `orI#` - getB# 6# `orI#` getB# 7# + 8# -> getB# 0# `or#` getB# 1# `or#` + getB# 2# `or#` getB# 3# `or#` + getB# 4# `or#` getB# 5# `or#` + getB# 6# `or#` getB# 7# -- This is the catch-all loop for other sizes -- not addressed above. 
- _ -> let join !(W8# w#) !(I# a#) = - I# ((a# `uncheckedIShiftL#` 8#) - `orI#` (word8ToInt w#)) + _ -> let join !(W8# w#) !(W# a#) = + W# ((a# `uncheckedShiftL#` 8#) + `or#` (word8ToWord w#)) bs' = BS.take (I# (r8# +# 2#)) $ BS.drop (I# s8#) bs - !(I# v#) = BS.foldr join (0::Int) bs' - in mask# `andI#` (v# `uncheckedIShiftRL#` hop#) + !(W# v#) = BS.foldr join (0::Word) bs' + in v# `uncheckedShiftRL#` hop# _ -> case r8# of 1# -> getSB# 0# - 2# -> getSB# 0# `orI#` getSB# 1# - 3# -> getSB# 0# `orI#` getSB# 1# `orI#` + 2# -> getSB# 0# `or#` getSB# 1# + 3# -> getSB# 0# `or#` getSB# 1# `or#` getSB# 2# - 4# -> getSB# 0# `orI#` getSB# 1# `orI#` - getSB# 2# `orI#` getSB# 3# - 5# -> getSB# 0# `orI#` getSB# 1# `orI#` - getSB# 2# `orI#` getSB# 3# `orI#` + 4# -> getSB# 0# `or#` getSB# 1# `or#` + getSB# 2# `or#` getSB# 3# + 5# -> getSB# 0# `or#` getSB# 1# `or#` + getSB# 2# `or#` getSB# 3# `or#` getSB# 4# - 6# -> getSB# 0# `orI#` getSB# 1# `orI#` - getSB# 2# `orI#` getSB# 3# `orI#` - getSB# 4# `orI#` getSB# 5# - 7# -> getSB# 0# `orI#` getSB# 1# `orI#` - getSB# 2# `orI#` getSB# 3# `orI#` - getSB# 4# `orI#` getSB# 5# `orI#` + 6# -> getSB# 0# `or#` getSB# 1# `or#` + getSB# 2# `or#` getSB# 3# `or#` + getSB# 4# `or#` getSB# 5# + 7# -> getSB# 0# `or#` getSB# 1# `or#` + getSB# 2# `or#` getSB# 3# `or#` + getSB# 4# `or#` getSB# 5# `or#` getSB# 6# - 8# -> getSB# 0# `orI#` getSB# 1# `orI#` - getSB# 2# `orI#` getSB# 3# `orI#` - getSB# 4# `orI#` getSB# 5# `orI#` - getSB# 6# `orI#` getSB# 7# + 8# -> getSB# 0# `or#` getSB# 1# `or#` + getSB# 2# `or#` getSB# 3# `or#` + getSB# 4# `or#` getSB# 5# `or#` + getSB# 6# `or#` getSB# 7# -- n.b. these are hand-unrolled cases for common - -- sizes this is called for. - 9# -> getSB# 0# `orI#` getSB# 1# `orI#` - getSB# 2# `orI#` getSB# 3# `orI#` - getSB# 4# `orI#` getSB# 5# `orI#` - getSB# 6# `orI#` getSB# 7# `orI#` + -- sizes this function is called for. 
+ 9# -> getSB# 0# `or#` getSB# 1# `or#` + getSB# 2# `or#` getSB# 3# `or#` + getSB# 4# `or#` getSB# 5# `or#` + getSB# 6# `or#` getSB# 7# `or#` getSB# 8# - 18# -> getSB# 0# `orI#` getSB# 1# `orI#` - getSB# 2# `orI#` getSB# 3# `orI#` - getSB# 4# `orI#` getSB# 5# `orI#` - getSB# 6# `orI#` getSB# 7# `orI#` - getSB# 8# `orI#` getSB# 9# `orI#` - getSB# 10# `orI#` getSB# 11# `orI#` - getSB# 12# `orI#` getSB# 13# `orI#` - getSB# 14# `orI#` getSB# 15# `orI#` - getSB# 16# `orI#` getSB# 17# + 18# -> getSB# 0# `or#` getSB# 1# `or#` + getSB# 2# `or#` getSB# 3# `or#` + getSB# 4# `or#` getSB# 5# `or#` + getSB# 6# `or#` getSB# 7# `or#` + getSB# 8# `or#` getSB# 9# `or#` + getSB# 10# `or#` getSB# 11# `or#` + getSB# 12# `or#` getSB# 13# `or#` + getSB# 14# `or#` getSB# 15# `or#` + getSB# 16# `or#` getSB# 17# -- This is the catch-all loop for other sizes -- not addressed above. - _ -> let join !(W8# w#) !(I# a#) = - I# ((a# `uncheckedIShiftL#` 8#) - `orI#` (word8ToInt w#)) + _ -> let join !(W8# w#) !(W# a#) = + W# ((a# `uncheckedShiftL#` 8#) + `or#` (word8ToWord w#)) bs' = BS.take (I# (r8# +# 2#)) $ BS.drop (I# s8#) bs - !(I# v#) = BS.foldr join (0::Int) bs' - in mask# `andI#` (v# `uncheckedIShiftRL#` hop#) + !(W# v#) = BS.foldr join (0::Word) bs' + in v# `uncheckedShiftRL#` hop# ) in Right $ \_ -> (# vi#, updPos# #) else Left "Attempt to read bits past limit" - else - -- BitString stores an Int, but number of extracted bits is larger than - -- an Int can represent. - Left "Attempt to extracted large value" -- Basic Interface ------------------------------------------------------------- -- | Read zeros up to an alignment of 32-bits. 
align32bits :: GetBits () +#ifdef QUICK +align32bits = GetBits $ \ !pos# _ -> + let !(# curBit#, ttlBits# #) = pos# + !cadj# = curBit# +# 0x1f# + in (# Right (), (# cadj# `xorI#` (cadj# `andI#` 0x1f#), ttlBits# #) #) +#else align32bits = GetBits $ \ !pos# inp -> let !(# curBit#, ttlBits# #) = pos# !s32# = curBit# `andI#` 31# @@ -308,10 +315,11 @@ align32bits = GetBits $ \ !pos# inp -> else case extractFromByteString ttlBits# curBit# r32# inp of Right getRes -> let !(# vi#, newPos# #) = getRes () - in if isTrue# (vi# ==# 0#) + in if isTrue# (vi# `eqWord#` (int2Word# 0#)) then (# Right (), (# newPos#, ttlBits# #) #) else (# Left nonZero, pos# #) Left e -> (# Left e, pos# #) +#endif -- | Read out n bits as a @BitString@. @@ -322,11 +330,21 @@ fixed !(Bits' (I# n#)) = GetBits case extractFromByteString lim# cur# n# inp of Right getRes -> let !(# v#, p# #) = getRes () - in (# pure $ toBitString (Bits' (I# n#)) (I# v#) + in (# pure $ toBitString (Bits' (I# n#)) (W# v#) , (# p#, lim# #) #) Left e -> (# Left e, s #) +fixedWord :: NumBits -> GetBits Word +fixedWord !(Bits' (I# n#)) = GetBits + $ \ !s@(# cur#, lim# #) -> + \inp -> + case extractFromByteString lim# cur# n# inp of + Right getRes -> + let !(# v#, p# #) = getRes () + in (# pure (W# v#) , (# p#, lim# #) #) + Left e -> (# Left e, s #) + -- | Read out n bytes as a @ByteString@, aligning to a 32-bit boundary before and -- after. @@ -349,11 +367,15 @@ bytestring n@(Bytes' nbytes) = do -- | Add a label to the error tag stack. label :: String -> GetBits a -> GetBits a +#ifdef QUICK +label _ m = m +#else label l m = GetBits $ \ !pos# inp -> let !(# j, n# #) = unGetBits m pos# inp in case j of Left e -> (# Left $ e <> "\n " <> l, n# #) Right r -> (# Right r, n# #) +#endif -- | Isolate input to a sub-span of the specified byte length. 
@@ -382,5 +404,5 @@ skip !(Bits' (I# n#)) = in if isTrue# (newLoc# ># lim#) then \_ -> (# Left "skipped past end of bytestring" , newPos# - #) + #) else \_ -> (# Right (), newPos# #) diff --git a/src/Data/LLVM/BitCode/IR.hs b/src/Data/LLVM/BitCode/IR.hs index 038dd423..11545975 100644 --- a/src/Data/LLVM/BitCode/IR.hs +++ b/src/Data/LLVM/BitCode/IR.hs @@ -5,7 +5,6 @@ module Data.LLVM.BitCode.IR where import Data.LLVM.BitCode.Bitstream -import Data.LLVM.BitCode.BitString import Data.LLVM.BitCode.IR.Blocks import Data.LLVM.BitCode.IR.Module (parseModuleBlock) import Data.LLVM.BitCode.Match @@ -21,9 +20,7 @@ import Data.Word (Word16) -- | The magic number that identifies a @Bitstream@ structure as LLVM IR. llvmIrMagic :: Word16 -llvmIrMagic = fromBitString (toBitString (Bits' 8) 0xc0 - `joinBitString` - toBitString (Bits' 8) 0xde) +llvmIrMagic = 0xdec0 -- | Parse an LLVM Module out of a Bitstream object. parseModule :: Bitstream -> Parse Module diff --git a/src/Data/LLVM/BitCode/IR/Constants.hs b/src/Data/LLVM/BitCode/IR/Constants.hs index 581ee527..dd3969e4 100644 --- a/src/Data/LLVM/BitCode/IR/Constants.hs +++ b/src/Data/LLVM/BitCode/IR/Constants.hs @@ -425,10 +425,14 @@ parseConstantEntry t (getTy,cs) (fromEntry -> Just r) = _asmDialect = mask `shiftR` 2 asmStrSize <- field 1 numeric +#ifndef QUICK Assert.recordSizeGreater r (1 + asmStrSize) +#endif constStrSize <- field (2 + asmStrSize) numeric +#ifndef QUICK Assert.recordSizeGreater r (2 + asmStrSize + constStrSize) +#endif asmStr <- fmap UTF8.decode $ parseSlice r 2 asmStrSize char constStr <- fmap UTF8.decode $ parseSlice r (3 + asmStrSize) constStrSize char diff --git a/src/Data/LLVM/BitCode/IR/Function.hs b/src/Data/LLVM/BitCode/IR/Function.hs index a7696faf..9fbb65fd 100644 --- a/src/Data/LLVM/BitCode/IR/Function.hs +++ b/src/Data/LLVM/BitCode/IR/Function.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE RecursiveDo #-} @@ -16,13 +17,15 @@ import 
Data.LLVM.BitCode.IR.Attrs import Data.LLVM.BitCode.Match import Data.LLVM.BitCode.Parse import Data.LLVM.BitCode.Record -import Data.LLVM.BitCode.Record import Text.LLVM.AST import Text.LLVM.Labels import Text.LLVM.PP -import Control.Monad (when,unless,mplus,mzero,foldM,(<=<)) +import Control.Monad ( when, mplus, mzero, foldM, (<=<) ) +#ifndef QUICK +import Control.Monad ( unless ) +#endif import Data.Bits (shiftR,bit,shiftL,testBit,(.&.),(.|.),complement,Bits) import Data.Int (Int32) import Data.Word (Word32) @@ -384,7 +387,11 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of -- [n] 1 -> label "FUNC_CODE_DECLARE_BLOCKS" - (Assert.recordSizeGreater r 0 >> return d) + ( +#ifndef QUICK + Assert.recordSizeGreater r 0 >> +#endif + return d) -- [opval,ty,opval,opcode] 2 -> label "FUNC_CODE_INST_BINOP" $ do @@ -413,7 +420,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of 3 -> label "FUNC_CODE_INST_CAST" $ do let field = parseField r (tv,ix) <- getValueTypePair t r 0 +#ifndef QUICK Assert.recordSizeIn r [ix + 2] +#endif resty <- getType =<< field ix numeric cast' <- field (ix+1) castOp result resty (cast' tv resty) d @@ -464,13 +473,19 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of 10 -> label "FUNC_CODE_INST_RET" $ case length (recordFields r) of 0 -> effect RetVoid d _ -> do +#ifdef QUICK + (tv, _) <- getValueTypePair t r 0 +#else (tv, ix) <- getValueTypePair t r 0 Assert.recordSizeIn r [ix] +#endif effect (Ret tv) d -- [bb#,bb#,cond] or [bb#] 11 -> label "FUNC_CODE_INST_BR" $ do +#ifndef QUICK Assert.recordSizeIn r [1, 3] +#endif let field = parseField r bb1 <- field 0 numeric let jump = effect (Jump bb1) d @@ -525,7 +540,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of -- [attrs,cc,normBB,unwindBB,fnty,op0,op1..] 
13 -> label "FUNC_CODE_INST_INVOKE" $ do +#ifndef QUICK Assert.recordSizeGreater r 3 +#endif let field = parseField r ccinfo <- field 1 unsigned normal <- field 2 numeric @@ -578,7 +595,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of -- [instty,opty,op,align] 19 -> label "FUNC_CODE_INST_ALLOCA" $ do +#ifndef QUICK Assert.recordSizeIn r [4] +#endif let field = parseField r @@ -598,16 +617,22 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of explicitType = testBit align 6 ity = if explicitType then PtrTo instty else instty +#ifdef QUICK + ret <- return instty +#else ret <- if explicitType then return instty else Assert.elimPtrTo "In return type:" instty +#endif result ity (Alloca ret sval (Just aval)) d -- [opty,op,align,vol] 20 -> label "FUNC_CODE_INST_LOAD" $ do (tv,ix) <- getValueTypePair t r 0 +#ifndef QUICK Assert.recordSizeIn r [ix + 2, ix + 3] +#endif (ret,ix') <- if length (recordFields r) == ix + 3 @@ -626,7 +651,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of -- 22 is unused 23 -> label "FUNC_CODE_INST_VAARG" $ do +#ifndef QUICK Assert.recordSizeGreater r 2 +#endif let field = parseField r ty <- getType =<< field 0 numeric op <- getValue t ty =<< field 1 numeric @@ -650,8 +677,10 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of 26 -> label "FUNC_CODE_INST_EXTRACTVAL" $ do (tv, ix) <- getValueTypePair t r 0 +#ifndef QUICK when (length (recordFields r) == ix) $ fail "`extractval` instruction had zero indices" +#endif ixs <- parseIndexes r ix let instr = ExtractValue tv ixs @@ -664,9 +693,11 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of 27 -> label "FUNC_CODE_INST_INSERTVAL" $ do (tv,ix) <- getValueTypePair t r 0 +#ifndef QUICK -- See comment in FUNC_CODE_INST_EXTRACTVAL when (length (recordFields r) == ix) $ fail "Invalid instruction with zero indices" +#endif (elt,ix') <- getValueTypePair t r ix ixs <- 
parseIndexes r ix' @@ -699,32 +730,37 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of -- [paramattrs, cc, mb fmf, mb fnty, fnid, arg0 .. arg n, varargs] 34 -> label "FUNC_CODE_INST_CALL" $ do +#ifndef QUICK Assert.recordSizeGreater r 2 +#endif let field = parseField r -- pal <- field 0 numeric -- N.B. skipping param attributes ccinfo <- field 1 numeric let ix0 = if testBit ccinfo 17 then 3 else 2 -- N.B. skipping fast-math flags - (mbFnTy, ix1) <- if testBit (ccinfo :: Word32) callExplicitTypeBit + r1 <- if testBit (ccinfo :: Word32) callExplicitTypeBit then do fnTy <- getType =<< field ix0 numeric return (Just fnTy, ix0+1) else return (Nothing, ix0) + let ix1 = snd r1 (Typed opTy fn, ix2) <- getValueTypePair t r ix1 `mplus` fail "Invalid record" op <- Assert.elimPtrTo "Callee is not a pointer type" opTy - fnty <- case mbFnTy of +#ifdef QUICK + let fnty = op +#else + fnty <- case fst r1 of Just ty | ty == op -> return op - | otherwise -> fail "Explicit call type does not match \ - \pointee type of callee operand" - + | otherwise -> fail ("Explicit call type does not match " + <> "pointee type of callee operand") Nothing -> case op of FunTy{} -> return op _ -> fail "Callee is not of pointer to function type" - +#endif label (show fn) $ do (ret,as,va) <- elimFunTy fnty `mplus` fail "invalid CALL record" @@ -733,7 +769,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of -- [Line,Col,ScopeVal, IAVal] 35 -> label "FUNC_CODE_DEBUG_LOC" $ do +#ifndef QUICK Assert.recordSizeGreater r 3 +#endif let field = parseField r line <- field 0 numeric @@ -761,7 +799,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of -- [ordering, synchscope] 36 -> label "FUNC_CODE_INST_FENCE" $ do +#ifndef QUICK Assert.recordSizeIn r [2] +#endif mordval <- getDecodedOrdering =<< parseField r 0 unsigned -- TODO: parse scope case mordval of @@ -785,7 +825,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = 
case recordCode r of -- [ty,val,val,num,id0,val0...] 40 -> label "FUNC_CODE_LANDINGPAD_OLD" $ do +#ifndef QUICK Assert.recordSizeGreater r 3 +#endif let field = parseField r ty <- getType =<< field 0 numeric (persFn,ix) <- getValueTypePair t r 1 @@ -799,7 +841,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of 41 -> label "FUNC_CODE_LOADATOMIC" $ do (tv,ix) <- getValueTypePair t r 0 +#ifndef QUICK Assert.recordSizeIn r [ix + 4, ix+ 5] +#endif (ret,ix') <- if length (recordFields r) == ix + 5 @@ -809,18 +853,24 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of else do ty <- Assert.elimPtrTo "" (typedType tv) return (ty, ix) +#ifndef QUICK Assert.ptrTo "load atomic : , *" tv (Typed ret ()) +#endif ordval <- getDecodedOrdering =<< parseField r (ix' + 2) unsigned +#ifndef QUICK when (ordval `elem` Nothing:map Just [Release, AcqRel]) $ fail $ "Invalid atomic ordering: " ++ show ordval +#endif aval <- parseField r ix' numeric let align | aval > 0 = Just (bit aval `shiftR` 1) | otherwise = Nothing +#ifndef QUICK when (ordval /= Nothing && align == Nothing) (fail "Invalid record") +#endif result ret (Load (tv { typedType = PtrTo ret }) ordval align) d @@ -836,8 +886,10 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of (ptr,ix) <- getValueTypePair t r 0 (val,ix') <- getValueTypePair t r ix +#ifndef QUICK Assert.recordSizeIn r [ix' + 2] Assert.ptrTo "store : , * " ptr val +#endif aval <- field ix' numeric let align | aval > 0 = Just (bit aval `shiftR` 1) @@ -849,13 +901,17 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of (ptr, ix) <- getValueTypePair t r 0 (val, ix') <- getValueTypePair t r ix +#ifndef QUICK Assert.recordSizeIn r [ix' + 4] Assert.ptrTo "store atomic : , * " ptr val +#endif -- TODO: There's no spot in the AST for this ordering. Should there be? 
ordering <- getDecodedOrdering =<< parseField r (ix' + 2) unsigned +#ifndef QUICK when (ordering `elem` Nothing:map Just [Acquire, AcqRel]) $ fail $ "Invalid atomic ordering: " ++ show ordering +#endif -- TODO: parse sync scope (ssid) @@ -874,6 +930,7 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of new <- getValue t (typedType val) =<< parseField r ix' numeric let ix'' = ix' + 1 -- TODO: is this right? +#ifndef QUICK -- TODO: record size assertion -- Assert.recordSizeGreater r (ix'' + 5) @@ -885,6 +942,7 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of , "cmp type: " ++ show (typedType val) , "new type: " ++ show (typedType new) ] +#endif volatile <- parseField r ix'' boolean @@ -926,7 +984,9 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of result ty (CmpXchg weak volatile ptr val new Nothing successOrdering failureOrdering) d 47 -> label "FUNC_CODE_LANDINGPAD" $ do +#ifndef QUICK Assert.recordSizeGreater r 2 +#endif let field = parseField r ty <- getType =<< field 0 numeric isCleanup <- (/=(0::Int)) <$> field 1 numeric @@ -935,22 +995,30 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of result ty (LandingPad ty Nothing isCleanup clauses) d 48 -> label "FUNC_CODE_CLEANUPRET" $ do +#ifndef QUICK -- Assert.recordSizeIn r [1, 2] +#endif notImplemented 49 -> label "FUNC_CODE_CATCHRET" $ do +#ifndef QUICK -- Assert.recordSizeIn r [2] +#endif notImplemented 50 -> label "FUNC_CODE_CATCHPAD" $ do notImplemented 51 -> label "FUNC_CODE_CLEANUPPAD" $ do +#ifndef QUICK -- Assert.recordSizeGreater r [1] +#endif notImplemented 52 -> label "FUNC_CODE_CATCHSWITCH" $ do +#ifndef QUICK -- Assert.recordSizeGreater r [1] +#endif notImplemented -- 53 is unused @@ -993,8 +1061,8 @@ parseFunctionBlockEntry _ t d (fromEntry -> Just r) = case recordCode r of fnty <- case mbFnTy of Just ty | ty == op -> return op - | otherwise -> fail "Explicit call type does not match \ - 
\pointee type of callee operand" + | otherwise -> fail ("Explicit call type does not match " + <> "pointee type of callee operand") Nothing -> case op of @@ -1058,12 +1126,16 @@ parseFunctionBlockEntry globals t d (metadataBlockId -> Just es) = do else return d -- silently drop unexpected local unnamed metadata parseFunctionBlockEntry globals t d (metadataAttachmentBlockId -> Just es) = do +#ifdef QUICK + (_,_,instrAtt,fnAtt,_) <- parseMetadataBlock globals t es +#else (_,(globalUnnamedMds, localUnnamedMds),instrAtt,fnAtt,_) <- parseMetadataBlock globals t es unless (null localUnnamedMds) (fail "parseFunctionBlockEntry PANIC: unexpected local unnamed metadata") unless (null globalUnnamedMds) (fail "parseFunctionBlockEntry PANIC: unexpected global unnamed metadata") +#endif return d { partialBody = addInstrAttachments instrAtt (partialBody d) , partialMetadata = Map.union fnAtt (partialMetadata d) } @@ -1164,14 +1236,17 @@ parseAtomicRMW old t r d = do (ptr, ix0) <- getValueTypePair t r 0 (val, ix1) <- case typedType ptr of - PtrTo ty@(PrimType prim) -> do + PtrTo ty -> do +#ifndef QUICK -- Catch pointers of the wrong type - when (case prim of - Integer _ -> False - FloatType _ -> False + -- (match on the whole pointee type so a non-primitive pointee reaches 'fail' instead of a pattern-match error) + when (case ty of + PrimType (Integer _) -> False + PrimType (FloatType _) -> False _ -> True) $ fail $ "Expected pointer to integer or float, found " ++ show ty +#endif if old then -- FUNC_CODE_INST_ATOMICRMW_OLD @@ -1364,12 +1439,13 @@ parseNewSwitchLabels :: Word32 -> Record -> Int -> Int -> Parse [(Integer,Int)] parseNewSwitchLabels width r = loop where field = parseField r - len = length (recordFields r) -- parse each group of cases as one or more numbers, and a basic block. 
loop numCases n | numCases <= 0 = return [] - | n >= len = fail "invalid SWITCH record" +#ifndef QUICK + | n >= length (recordFields r) = fail "invalid SWITCH record" +#endif | otherwise = do numItems <- field n numeric (ls,n') <- parseItems numItems (n + 1) diff --git a/src/Data/LLVM/BitCode/IR/Metadata.hs b/src/Data/LLVM/BitCode/IR/Metadata.hs index 88343dff..b6222408 100644 --- a/src/Data/LLVM/BitCode/IR/Metadata.hs +++ b/src/Data/LLVM/BitCode/IR/Metadata.hs @@ -1,9 +1,11 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} @@ -31,7 +33,10 @@ import Text.LLVM.Labels import qualified Codec.Binary.UTF8.String as UTF8 (decode) import Control.Applicative ((<|>)) import Control.Exception (throw) -import Control.Monad (foldM, guard, mplus, when) +import Control.Monad ( foldM, guard, mplus ) +#ifndef QUICK +import Control.Monad ( when ) +#endif import Data.Bits (shiftR, testBit, shiftL, (.&.), (.|.), bit, complement) import Data.Data (Data) import Data.Typeable (Typeable) @@ -402,6 +407,11 @@ parseMetadataBlock globals vt es = label "METADATA_BLOCK" $ do parseMetadataEntry :: ValueTable -> MetadataTable -> PartialMetadata -> Entry -> Parse PartialMetadata parseMetadataEntry vt mt pm (fromEntry -> Just r) = +#ifdef QUICK + let assertRecordSizeBetween (_ :: Integer) (_ :: Integer) = return () + assertRecordSizeIn (_ :: [Integer]) = return () + assertRecordSizeAtLeast (_ :: Integer) = return () +#else let msg = [ "Are you sure you're using a supported version of LLVM/Clang?" 
, "Check here: https://github.com/GaloisInc/llvm-pretty-bc-parser" ] @@ -424,6 +434,7 @@ parseMetadataEntry vt mt pm (fromEntry -> Just r) = fail $ unlines $ [ "Invalid record size: " ++ show len , "Expected size of " ++ show lb ++ " or greater" ] ++ msg +#endif in case recordCode r of -- [values] 1 -> label "METADATA_STRING" $ do @@ -434,9 +445,13 @@ parseMetadataEntry vt mt pm (fromEntry -> Just r) = 2 -> label "METADATA_VALUE" $ do assertRecordSizeIn [2] let field = parseField r +#ifdef QUICK + (_ :: Int) <- field 0 numeric +#else ty <- getType =<< field 0 numeric when (ty == PrimType Metadata || ty == PrimType Void) (fail "invalid record") +#endif cxt <- getContext ix <- field 1 numeric @@ -499,8 +514,10 @@ parseMetadataEntry vt mt pm (fromEntry -> Just r) = -- [m x [value, [n x [id, mdnode]]] 11 -> label "METADATA_ATTACHMENT" $ do let recordSize = length (recordFields r) +#ifndef QUICK when (recordSize == 0) (fail "Invalid record") +#endif if recordSize `mod` 2 == 0 then label "function attachment" $ do att <- Map.fromList <$> parseAttachment r 0 @@ -774,12 +791,14 @@ parseMetadataEntry vt mt pm (fromEntry -> Just r) = | not hasSPFlags = recordSize >= 21 | otherwise = True +#ifndef QUICK -- Some additional sanity checking when (not hasSPFlags && hasUnit) (assertRecordSizeBetween 19 21) when (hasSPFlags && not hasUnit) (fail "DISubprogram record has subprogram flags, but does not have unit. 
Invalid record.") +#endif ctx <- getContext @@ -940,9 +959,11 @@ parseMetadataEntry vt mt pm (fromEntry -> Just r) = _alignInBits <- if hasAlignment then do n <- parseField r (adj 8) numeric +#ifndef QUICK when ((n :: Word64) > fromIntegral (maxBound :: Word32)) (fail "Alignment value is too large") - return (fromIntegral n :: Word32) +#endif + return (fromIntegral (n :: Word64) :: Word32) else return 0 @@ -1020,14 +1041,18 @@ parseMetadataEntry vt mt pm (fromEntry -> Just r) = count <- parseField r 0 numeric offset <- parseField r 1 numeric bs <- parseField r 2 fieldBlob +#ifndef QUICK when (count == 0) (fail "Invalid record: metadata strings with no strings") when (offset > S.length bs) (fail "Invalid record: metadata strings corrupt offset") +#endif let (bsLengths, bsStrings) = S.splitAt offset bs lengths <- either fail return $ parseMetadataStringLengths count bsLengths +#ifndef QUICK when (sum lengths > S.length bsStrings) (fail "Invalid record: metadata strings truncated") +#endif let strings = snd (mapAccumL f bsStrings lengths) where f s i = case S.splitAt i s of (str, rest) -> (rest, Char8.unpack str) @@ -1036,9 +1061,11 @@ parseMetadataEntry vt mt pm (fromEntry -> Just r) = -- [ valueid, n x [id, mdnode] ] 36 -> label "METADATA_GLOBAL_DECL_ATTACHMENT" $ do +#ifndef QUICK -- the record will always be of odd length when (mod (length (recordFields r)) 2 == 0) (fail "Invalid record") +#endif valueId <- parseField r 0 numeric sym <- case lookupValueTableAbs valueId vt of diff --git a/src/Data/LLVM/BitCode/Parse.hs b/src/Data/LLVM/BitCode/Parse.hs index 9b81d528..15e51cd1 100644 --- a/src/Data/LLVM/BitCode/Parse.hs +++ b/src/Data/LLVM/BitCode/Parse.hs @@ -540,8 +540,12 @@ getTypeSymtab = Parse (symTypeSymtab . envSymtab <$> ask) -- | Label a sub-computation with its context. label :: String -> Parse a -> Parse a +#ifdef QUICK +label _ m = m +#else label l m = Parse $ do local (addLabel l) (unParse m) +#endif -- | Fail, taking into account the current context. 
failWithContext :: String -> Parse a