From 598039ee4b2a771c0d9fabdba83b2ccb348b1545 Mon Sep 17 00:00:00 2001 From: Christiaan Baaij Date: Fri, 7 Feb 2020 18:02:32 +0100 Subject: [PATCH] Keep all casts, and cast (Signal a ~ a) where appropriate --- clash-ghc/src-ghc/Clash/GHC/Evaluator.hs | 3 + clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs | 90 ++++++++----- clash-lib/src/Clash/Core/Evaluator.hs | 18 ++- clash-lib/src/Clash/Core/Evaluator/Types.hs | 3 + clash-lib/src/Clash/Core/Type.hs | 6 +- clash-lib/src/Clash/Core/Util.hs | 18 ++- clash-lib/src/Clash/Netlist.hs | 6 + clash-lib/src/Clash/Netlist/BlackBox.hs | 5 +- .../src/Clash/Normalize/Transformations.hs | 124 +++++++++++++----- clash-lib/src/Clash/Normalize/Util.hs | 1 + clash-lib/src/Clash/Rewrite/Util.hs | 44 ++++--- 11 files changed, 220 insertions(+), 98 deletions(-) diff --git a/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs b/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs index 2cd2b269bf..a1ba925d21 100644 --- a/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs +++ b/clash-ghc/src-ghc/Clash/GHC/Evaluator.hs @@ -3465,6 +3465,9 @@ naturalLiteral v = DC dc [Left (Literal (ByteArrayLiteral (Vector.Vector _ _ (ByteArray.ByteArray ba))))] | dcTag dc == 2 -> Just (Jp# (BN# ba)) + CastValue v0 _ _ + | Just n <- naturalLiteral v0 + -> Just n _ -> Nothing integerLiterals' :: [Value] -> [Integer] diff --git a/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs b/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs index 21f83e65e1..79a0e03ba3 100644 --- a/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs +++ b/clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs @@ -91,7 +91,7 @@ import TyCon (AlgTyConRhs (..), TyCon, tyConName, tyConArity, tyConDataCons, tyConKind, tyConName, tyConUnique, isClassTyCon) -import Type (mkTvSubstPrs, substTy, coreView) +import Type (mkTvSubstPrs, substTy, coreView, piResultTys) import TyCoRep (Coercion (..), TyLit (..), Type (..)) import Unique (Uniquable (..), Unique, getKey, hasKey) import Var (Id, TyVar, Var, idDetails, @@ -288,71 +288,101 @@ coreToTerm primMap unlocs = term , let (nm, _) = RWS.evalRWS (qualifiedNameString (varName x)) noSrcSpan emptyGHC2CoreState - = go nm args + = go nm (varType x) args | otherwise = term' e where -- Remove most Signal transformers - go "Clash.Signal.Internal.mapSignal#" args - | length args == 5 - = term (App (args!!3) (args!!4)) - go "Clash.Signal.Internal.signal#" args - | length args == 3 - = term (args!!2) - go "Clash.Signal.Internal.appSignal#" args - | length args == 5 - = term (App (args!!3) (args!!4)) - go "Clash.Signal.Internal.joinSignal#" args + go "Clash.Signal.Internal.mapSignal#" pTy args + | [Type aTy, Type bTy, Type domTy, fTm, aSigTm] <- args + = do + let aSigTy = piResultTys pTy [bTy,aTy,domTy,aTy,aTy] + bSigTy = piResultTys pTy [aTy,bTy,domTy,bTy,bTy] + aTyC <- coreToType aTy + bTyC <- coreToType bTy + aSigTyC <- coreToType aSigTy + bSigTyC <- coreToType bSigTy + C.Cast <$> (C.App <$> term fTm + <*> (C.Cast <$> term aSigTm + <*> pure aSigTyC + <*> pure aTyC)) + <*> pure bTyC + <*> pure bSigTyC + go "Clash.Signal.Internal.signal#" pty args + | [Type aTy, Type domTy, aTm] <- args + = let aSigTy = piResultTys pty [aTy,domTy,aTy] + in C.Cast <$> term aTm <*> coreToType aTy <*> coreToType aSigTy + go "Clash.Signal.Internal.appSignal#" pTy args + | [Type domTy, Type aTy, Type bTy, fSigTm, aSigTm] <- args + = do + let aSigTy = piResultTys pTy [domTy,bTy,aTy,aTy,aTy] + bSigTy = piResultTys pTy [domTy,aTy,bTy,bTy,bTy] + fSigTy = piResultTys pTy [domTy,aTy,FunTy aTy bTy,aTy,aTy] + aTyC <- coreToType aTy + bTyC <- coreToType bTy + aSigTyC <- coreToType aSigTy + bSigTyC <- coreToType bSigTy + fSigTyC <- coreToType fSigTy + let fTyC = C.mkFunTy aTyC bTyC + C.Cast <$> (C.App <$> (C.Cast <$> term fSigTm + <*> pure fSigTyC + <*> pure fTyC) + <*> (C.Cast <$> term aSigTm + <*> pure aSigTyC + <*> pure aTyC)) + <*> pure bTyC + <*> pure bSigTyC + go "Clash.Signal.Internal.joinSignal#" _ args | length args == 3 = term (args!!2) - go "Clash.Signal.Bundle.vecBundle#" args + go "Clash.Signal.Bundle.vecBundle#" _ args | length args == 4 = term (args!!3) --- Remove `$` - go "GHC.Base.$" args + go "GHC.Base.$" _ args | length args == 5 = term (App (args!!3) (args!!4)) - go "GHC.Magic.noinline" args -- noinline :: forall a. a -> a + go "GHC.Magic.noinline" _ args -- noinline :: forall a. a -> a | [_ty, x] <- args = term x -- Remove most CallStack logic - go "GHC.Stack.Types.PushCallStack" args = term (last args) - go "GHC.Stack.Types.FreezeCallStack" args = term (last args) - go "GHC.Stack.withFrozenCallStack" args + go "GHC.Stack.Types.PushCallStack" _ args = term (last args) + go "GHC.Stack.Types.FreezeCallStack" _ args = term (last args) + go "GHC.Stack.withFrozenCallStack" _ args | length args == 3 = term (App (args!!2) (args!!1)) - go "Clash.Class.BitPack.packXWith" args + go "Clash.Class.BitPack.packXWith" _ args | [_nTy,_aTy,_kn,f] <- args = term f - go "Clash.Sized.BitVector.Internal.checkUnpackUndef" args + go "Clash.Sized.BitVector.Internal.checkUnpackUndef" _ args | [_nTy,_aTy,_kn,_typ,f] <- args = term f - go "Clash.Magic.prefixName" args + go "Clash.Magic.prefixName" _ args | [Type nmTy,_aTy,f] <- args = C.Tick <$> (C.NameMod C.PrefixName <$> coreToType nmTy) <*> term f - go "Clash.Magic.suffixName" args + go "Clash.Magic.suffixName" _ args | [Type nmTy,_aTy,f] <- args = C.Tick <$> (C.NameMod C.SuffixName <$> coreToType nmTy) <*> term f - go "Clash.Magic.suffixNameFromNat" args + go "Clash.Magic.suffixNameFromNat" _ args | [Type nmTy,_aTy,f] <- args = C.Tick <$> (C.NameMod C.SuffixName <$> coreToType nmTy) <*> term f - go "Clash.Magic.suffixNameP" args + go "Clash.Magic.suffixNameP" _ args | [Type nmTy,_aTy,f] <- args = C.Tick <$> (C.NameMod C.SuffixNameP <$> coreToType nmTy) <*> term f - go "Clash.Magic.suffixNameFromNatP" args + go "Clash.Magic.suffixNameFromNatP" _ args | [Type nmTy,_aTy,f] <- args = C.Tick <$> (C.NameMod C.SuffixNameP <$> coreToType nmTy) <*> term f - go "Clash.Magic.setName" args + go "Clash.Magic.setName" _ args | [Type nmTy,_aTy,f] <- args = C.Tick <$> (C.NameMod C.SetName <$> coreToType nmTy) <*> term f - go "Clash.Magic.deDup" args + go "Clash.Magic.deDup" _ args | [_aTy,f] <- args = C.Tick C.DeDup <$> term f - go "Clash.Magic.noDeDup" args + go "Clash.Magic.noDeDup" _ args | [_aTy,f] <- args = C.Tick C.NoDeDup <$> term f - go _ _ = term' e + go _ _ _ = term' e term' (Var x) = var x term' (Lit l) = return $ C.Literal (coreToLiteral l) term' (App eFun (Type tyArg)) = C.TyApp <$> term eFun <*> coreToType tyArg @@ -405,7 +435,7 @@ coreToTerm primMap unlocs = term case hasPrimCoM of Just _ | ty1_I || ty2_I -> C.Cast <$> term e <*> coreToType ty1 <*> coreToType ty2 - _ -> term e + _ -> C.Cast <$> term e <*> coreToType ty1 <*> coreToType ty2 term' (Tick (SourceNote rsp _) e) = C.Tick (C.SrcSpan (RealSrcSpan rsp)) <$> addUsefull (RealSrcSpan rsp) (term e) term' (Tick _ e) = term e diff --git a/clash-lib/src/Clash/Core/Evaluator.hs b/clash-lib/src/Clash/Core/Evaluator.hs index 846cf439f1..d9557bd4fc 100644 --- a/clash-lib/src/Clash/Core/Evaluator.hs +++ b/clash-lib/src/Clash/Core/Evaluator.hs @@ -140,6 +140,10 @@ unwindStack m let term = Tick sp (getTerm m') in unwindStack (setTerm term m') + Castish ty1 ty2 -> + let term = Cast (getTerm m') ty1 ty2 + in unwindStack (setTerm term m') + -- | A single step in the partial evaluator. The result is the new heap and -- stack, and the next expression to be reduced. -- @@ -232,6 +236,8 @@ stepApp x y m tcm = GT -> let (m0, n) = newLetBinding tcm m y in Just . setTerm x $ stackPush (Apply n) m0 + Cast {} -> error "stepApp QQ" + _ -> let (m0, n) = newLetBinding tcm m y in Just . setTerm x $ stackPush (Apply n) m0 where @@ -264,6 +270,8 @@ stepTyApp x ty m tcm = LT -> newBinder tys' (TyApp x ty) m tcm GT -> Just . setTerm x $ stackPush (Instantiate ty) m + Cast {} -> error "stepTyApp QQ" + _ -> Just . setTerm x $ stackPush (Instantiate ty) m where (term, args, _) = collectArgsTicks (TyApp x ty) @@ -273,17 +281,14 @@ stepLetRec :: [LetBinding] -> Term -> Step stepLetRec bs x m _ = Just (allocate bs x m) stepCase :: Term -> Type -> [Alt] -> Step +stepCase (Cast {}) _ty _alts _m _ = error "stepCase QQ" stepCase scrut ty alts m _ = Just . setTerm scrut $ stackPush (Scrutinise ty alts) m -- TODO Support stepwise evaluation of casts. -- stepCast :: Term -> Type -> Type -> Step -stepCast _ _ _ _ _ = - flip trace Nothing $ unlines - [ "WARNING: " <> $(curLoc) <> "Clash can't symbolically evaluate casts" - , "Please file an issue at https://github.com/clash-lang/clash-compiler/issues" - ] +stepCast x ty1 ty2 m _ = Just . setTerm x $ stackPush (Castish ty1 ty2) m stepTick :: TickInfo -> Term -> Step stepTick tick x m _ = @@ -356,7 +361,8 @@ unwind tcm m v = do go (Instantiate ty) = return . instantiate v ty go (PrimApply p tys vs tms) = mPrimUnwind m tcm p tys vs v tms go (Scrutinise _ as) = return . scrutinise v as - go (Tickish _) = return . setTerm (valToTerm v) + go (Tickish t) = flip (unwind tcm) (TickValue t v) + go (Castish ty1 ty2) = flip (unwind tcm) (CastValue v ty1 ty2) -- | Update the Heap with the evaluated term update :: IdScope -> Id -> Value -> Machine -> Machine diff --git a/clash-lib/src/Clash/Core/Evaluator/Types.hs b/clash-lib/src/Clash/Core/Evaluator/Types.hs index 4c3200edcc..903ab5e2f0 100644 --- a/clash-lib/src/Clash/Core/Evaluator/Types.hs +++ b/clash-lib/src/Clash/Core/Evaluator/Types.hs @@ -117,6 +117,7 @@ data StackFrame | PrimApply PrimInfo [Type] [Value] [Term] | Scrutinise Type [Alt] | Tickish TickInfo + | Castish Type Type deriving Show instance ClashPretty StackFrame where @@ -134,6 +135,8 @@ instance ClashPretty StackFrame where fromPpr (Case (Literal (CharLiteral '_')) a b)] clashPretty (Tickish sp) = hsep ["Tick", fromPpr sp] + clashPretty (Castish ty1 ty2) = + hsep ["Cast", fromPpr ty1, fromPpr ty2] -- Values data Value diff --git a/clash-lib/src/Clash/Core/Type.hs b/clash-lib/src/Clash/Core/Type.hs index 4090136881..42022d4235 100644 --- a/clash-lib/src/Clash/Core/Type.hs +++ b/clash-lib/src/Clash/Core/Type.hs @@ -196,9 +196,9 @@ coreView1 tcMap ty = case tyView ty of | nameOcc tcNm == "Clash.Signal.BiSignal.BiSignalOut" , [_,_,_,elTy] <- args -> Just elTy - | nameOcc tcNm == "Clash.Signal.Internal.Signal" - , [_,elTy] <- args - -> Just elTy + -- | nameOcc tcNm == "Clash.Signal.Internal.Signal" + -- , [_,elTy] <- args + -- -> Just elTy | otherwise -> case tcMap `lookupUniqMap'` tcNm of AlgTyCon {algTcRhs = (NewTyCon _ nt)} diff --git a/clash-lib/src/Clash/Core/Util.hs b/clash-lib/src/Clash/Core/Util.hs index 811f6b0fa7..d6fa06e5ce 100644 --- a/clash-lib/src/Clash/Core/Util.hs +++ b/clash-lib/src/Clash/Core/Util.hs @@ -50,7 +50,7 @@ import Clash.Core.Type coreView, coreView1, isFunTy, isPolyFunCoreTy, mkFunTy, splitFunTy, tyView, undefinedTy, isTypeFamilyApplication) import Clash.Core.TyCon - (TyConMap, tyConDataCons) + (TyConMap, TyConName, tyConDataCons) import Clash.Core.TysPrim (typeNatKind) import Clash.Core.Var (Id, TyVar, Var (..), isLocalId, mkLocalId, mkTyVar) @@ -995,15 +995,17 @@ shouldSplit shouldSplit tcm (tyView -> TyConApp (nameOcc -> "Clash.Explicit.SimIO.SimIO") [tyArg]) = -- We also look through `SimIO` to find things like Files shouldSplit tcm tyArg -shouldSplit tcm ty = shouldSplit0 tcm (tyView (coreView tcm ty)) +shouldSplit tcm ty = shouldSplit0 emptyUniqSet tcm (tyView (coreView tcm ty)) -- | Worker of 'shouldSplit', works on 'TypeView' instead of 'Type' shouldSplit0 - :: TyConMap + :: UniqSet TyConName + -> TyConMap -> TypeView -> Maybe (Term,[Type]) -shouldSplit0 tcm (TyConApp tcNm tyArgs) - | Just tc <- lookupUniqMap tcNm tcm +shouldSplit0 seen tcm (TyConApp tcNm tyArgs) + | tcNm `notElemUniqSet` seen + , Just tc <- lookupUniqMap tcNm tcm , [dc] <- tyConDataCons tc , let dcArgs = substArgTys dc tyArgs , let dcArgVs = map (tyView . coreView tcm) dcArgs @@ -1012,8 +1014,10 @@ shouldSplit0 tcm (TyConApp tcNm tyArgs) else Nothing where + seen1 = extendUniqSet seen tcNm + shouldSplitTy :: TypeView -> Bool - shouldSplitTy ty = isJust (shouldSplit0 tcm ty) || splitTy ty + shouldSplitTy ty = isJust (shouldSplit0 seen1 tcm ty) || splitTy ty -- Hidden constructs (HiddenClock, HiddenReset, ..) don't need to be split -- because KnownDomain will be filtered anyway during netlist generation due @@ -1046,7 +1050,7 @@ shouldSplit0 tcm (TyConApp tcNm tyArgs) ] splitTy _ = False -shouldSplit0 _ _ = Nothing +shouldSplit0 _ _ _ = Nothing -- | Potentially split apart a list of function argument types. e.g. given: -- diff --git a/clash-lib/src/Clash/Netlist.hs b/clash-lib/src/Clash/Netlist.hs index fc1bf9c93e..1c2c84156e 100644 --- a/clash-lib/src/Clash/Netlist.hs +++ b/clash-lib/src/Clash/Netlist.hs @@ -351,6 +351,9 @@ mkDeclarations' -> Term -- ^ RHS of the let-binder -> NetlistMonad [Declaration] +mkDeclarations' declType bndr (collectTicks -> (Cast e _ _,ticks)) = + mkDeclarations' declType bndr (mkTicks e ticks) + mkDeclarations' _declType bndr (collectTicks -> (Var v,ticks)) = withTicks ticks $ \tickDecls -> do mkFunApp (id2identifier bndr) v [] tickDecls @@ -734,6 +737,9 @@ mkExpr bbEasD declType bndr app = decls <- concat <$> mapM (uncurry mkDeclarations) binders (bodyE,bodyDecls) <- mkExpr bbEasD declType bndr (mkApps (mkTicks body ticks) args) return (bodyE,netDecls ++ decls ++ bodyDecls) + + Cast e0 _ _ | null args -> + mkExpr bbEasD declType bndr (mkTicks e0 ticks) _ -> throw (ClashException sp ($(curLoc) ++ "Not in normal form: application of a Lambda-expression\n\n" ++ showPpr app) Nothing) -- | Generate an expression that projects a field out of a data-constructor. diff --git a/clash-lib/src/Clash/Netlist/BlackBox.hs b/clash-lib/src/Clash/Netlist/BlackBox.hs index 578024114e..def3167583 100644 --- a/clash-lib/src/Clash/Netlist/BlackBox.hs +++ b/clash-lib/src/Clash/Netlist/BlackBox.hs @@ -61,7 +61,8 @@ import Clash.Core.Type as C (Type (..), ConstTy (..), TypeView (..), mkFunTy, splitFunTys, splitFunTy, tyView) import Clash.Core.TyCon as C (TyConMap, tyConDataCons) import Clash.Core.Util - (collectBndrs, inverseTopSortLetBindings, isFun, mkApps, splitShouldSplit, termType) + (collectBndrs, inverseTopSortLetBindings, isFun, mkApps, splitShouldSplit, + termType, mkTicks) import Clash.Core.Var as V (Id, Var (..), mkLocalId, modifyVarName) import Clash.Core.VarEnv @@ -188,6 +189,7 @@ isLiteral e = case collectArgs e of (Data _, args) -> all (either isLiteral (const True)) args (Prim _, args) -> all (either isLiteral (const True)) args (C.Literal _,_) -> True + (Cast e0 _ _, args) -> all (either isLiteral (const True)) (Left e0:args) _ -> False mkArgument @@ -239,6 +241,7 @@ mkArgument bndr e = do (Case scrut ty' [alt],[],_) -> do (projection,decls) <- mkProjection False (NetlistId bndr ty) scrut ty' alt return ((projection,hwTy,False),decls) + (Cast e0 _ _,[],ticks) -> mkArgument bndr (mkTicks e0 ticks) _ -> return ((Identifier (error ($(curLoc) ++ "Forced to evaluate unexpected function argument: " ++ eTyMsg)) Nothing ,hwTy,False),[]) diff --git a/clash-lib/src/Clash/Normalize/Transformations.hs b/clash-lib/src/Clash/Normalize/Transformations.hs index 5f77297b62..379c20fa0b 100644 --- a/clash-lib/src/Clash/Normalize/Transformations.hs +++ b/clash-lib/src/Clash/Normalize/Transformations.hs @@ -16,6 +16,8 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} +#include "../../ClashDebug.h" + module Clash.Normalize.Transformations ( appProp , caseLet @@ -78,6 +80,7 @@ import qualified Data.Maybe as Maybe import qualified Data.Monoid as Monoid import qualified Data.Primitive.ByteArray as BA import qualified Data.Text as Text +import Data.Text.Prettyprint.Doc ((<+>)) import qualified Data.Vector.Primitive as PV import Debug.Trace import GHC.Integer.GMP.Internals (Integer (..), BigNat (..)) @@ -94,7 +97,7 @@ import Clash.Core.FreeVars (localIdOccursIn, localIdsDoNotOccurIn, freeLocalIds, termFreeTyVars, typeFreeVars, localVarsDoNotOccurIn, localIdDoesNotOccurIn) import Clash.Core.Literal (Literal (..)) -import Clash.Core.Pretty (showPpr) +import Clash.Core.Pretty (ppr, showPpr) import Clash.Core.Subst (substTm, mkSubst, extendIdSubst, extendIdSubstList, extendTvSubst, extendTvSubstList, freshenTm, substTyInVar, deShadowTerm, deShadowAlt) @@ -106,7 +109,7 @@ import Clash.Core.Term ) import Clash.Core.Type (Type, TypeView (..), applyFunTy, isPolyFunCoreTy, isClassTy, - normalizeType, splitFunForallTy, + splitFunForallTy, splitFunTy, tyView, mkPolyFunTy, coreView) import Clash.Core.TyCon (TyConMap, tyConDataCons) @@ -838,6 +841,7 @@ removeUnusedExpr _ e = return e bindConstantVar :: HasCallStack => NormRewrite bindConstantVar = inlineBinders test where + test b (i,stripTicks -> (Cast e _ _)) = test b (i,e) test _ (i,stripTicks -> e) = case isLocalVar e of -- Don't inline `let x = x in x`, it throws us in an infinite loop True -> return (i `localIdDoesNotOccurIn` e) @@ -849,9 +853,10 @@ bindConstantVar = inlineBinders test -- | Push a cast over a case into it's alternatives. caseCast :: HasCallStack => NormRewrite -caseCast _ (Cast (stripTicks -> Case subj ty alts) ty1 ty2) = do - let alts' = map (\(p,e) -> (p, Cast e ty1 ty2)) alts - changed (Case subj ty alts') +caseCast _ e0@(Cast (collectTicks -> (Case subj tyA alts,ticks)) ty1 ty2) = + WARN( tyA /= ty1, "Bad Cast of Case:" <+> ppr e0 ) do + let alts' = map (\(p,e) -> (p, Cast e ty1 ty2)) alts + changed (mkTicks (Case subj ty2 alts') ticks) caseCast _ e = return e -- | Push a cast over a Letrec into it's body @@ -879,19 +884,23 @@ letCast _ e = return e -- and expression where two casts are "back-to-back" after which we can -- eliminate them in 'eliminateCastCast'. argCastSpec :: HasCallStack => NormRewrite -argCastSpec ctx e@(App _ (stripTicks -> Cast e' _ _)) = - if isWorkFree e' then - go - else - warn go - where - go = specializeNorm ctx e - warn = trace (unwords - [ "WARNING:", $(curLoc), "specializing a function on a non work-free" - , "cast. Generated HDL implementation might contain duplicate work." - , "Please report this as a bug.", "\n\nExpression where this occured:" - , "\n\n" ++ showPpr e - ]) +argCastSpec ctx@(TransformContext is0 ctx0) e@(App e0 (collectTicks -> (Cast eA t1 t2,ticks))) + | (Var {}, _) <- collectArgs e0 = do + tcm <- Lens.view tcCache + eARep <- representableType <$> Lens.view typeTranslator + <*> Lens.view customReprs + <*> pure False + <*> pure tcm + <*> pure t1 + if isWorkFree eA || not eARep then + specializeNorm ctx e + else do + specId <- mkTmBinderFor is0 tcm (mkUnsafeSystemName "argCastSpec" 0) eA + let is1 = extendInScopeSet is0 specId + e1 <- specializeNorm (TransformContext is1 (LetBody [specId]:ctx0)) + (App e0 (Cast (Var specId) t1 t2)) + changed (Letrec [(specId,mkTicks eA ticks)] e1) + argCastSpec _ e = return e -- | Only inline casts that just contain a 'Var', because these are guaranteed work-free. @@ -908,20 +917,19 @@ inlineCast = inlineBinders test -- (cast :: b -> a) $ (cast :: a -> b) x ==> x -- @ eliminateCastCast :: HasCallStack => NormRewrite -eliminateCastCast _ c@(Cast (stripTicks -> Cast e tyA tyB) tyB' tyC) = do - tcm <- Lens.view tcCache - let ntyA = normalizeType tcm tyA - ntyB = normalizeType tcm tyB - ntyB' = normalizeType tcm tyB' - ntyC = normalizeType tcm tyC - if ntyB == ntyB' && ntyA == ntyC then changed e - else throwError - where throwError = do - (nm,sp) <- Lens.use curFun - throw (ClashException sp ($(curLoc) ++ showPpr nm - ++ ": Found 2 nested casts whose types don't line up:\n" - ++ showPpr c) - Nothing) +eliminateCastCast _ c@(Cast (collectTicks -> (Cast e tyA tyB, ticks)) tyB' tyC) + | tyB == tyB' + = if tyA == tyC then + changed (mkTicks e ticks) + else + changed (Cast (mkTicks e ticks) tyA tyC) + | otherwise + = do + (nm,sp) <- Lens.use curFun + throw (ClashException sp ($(curLoc) ++ showPpr nm + ++ ": Found 2 nested casts whose types don't line up:\n" + ++ showPpr c) + Nothing) eliminateCastCast _ e = return e @@ -1192,6 +1200,17 @@ appProp ctx@(TransformContext is0 _) (App (collectTicks -> (Case scrut ty alts,t let alts' = map (second (`App` (Var boundArg))) alts changed (Letrec [(boundArg, arg)] (mkTicks (Case scrut ty' alts') ticks)) +appProp _ e0@(App (collectTicks -> (Cast e tyA tyB,ticks)) arg) = + case (tyView tyA, tyView tyB) of + (FunTy tyAArg tyARes,FunTy tyBArg tyBRes) -> + changed (Cast (App (mkTicks e ticks) (Cast arg tyBArg tyAArg)) tyARes tyBRes) + _ -> do + (nm,sp) <- Lens.use curFun + throw (ClashException sp ($(curLoc) ++ showPpr nm + ++ ": AppProp expected FunTy:\n" + ++ showPpr e0) + Nothing) + appProp (TransformContext is0 _) (TyApp (collectTicks -> (TyLam tv e,ticks)) t) = do let subst = extendTvSubst (mkSubst is0) tv t changed $ mkTicks (substTm "appProp.TyAppTyLam" subst e) ticks @@ -1205,6 +1224,13 @@ appProp _ (TyApp (collectTicks -> (Case scrut altsTy alts,ticks)) ty) = do let ty' = piResultTy tcm altsTy ty changed (mkTicks (Case scrut ty' alts') ticks) +appProp _ e0@(TyApp (collectTicks -> (Cast {},_)) _) = do + (nm,sp) <- Lens.use curFun + throw (ClashException sp ($(curLoc) ++ showPpr nm + ++ ": AppProp TPush not implemented:\n" + ++ showPpr e0) + Nothing) + appProp _ e = return e -- | Unlike 'appProp', which propagates a single argument in an application one @@ -1269,6 +1295,40 @@ appPropFast ctx@(TransformContext is _) = \case setChanged go is0 e args (sp:ticks) + go is0 c@(Cast e tyA tyB) args0 ticks = case args0 of + [] -> Cast <$> go is0 e args0 ticks <*> pure tyA <*> pure tyB + (Left arg:args1) -> do + setChanged + case (tyView tyA,tyView tyB) of + (FunTy aa ar,FunTy ba br) + -> Cast <$> go is0 e (Left (Cast arg ba aa):args1) ticks + <*> pure ar + <*> pure br + (TyConApp sigTc [_domTy,fTy],FunTy ba br) + | nameOcc sigTc == "Clash.Signal.Internal.Signal" + , FunTy aa ar <- tyView fTy + , aa == ba + , ar == br + -> go is0 e (Left arg:args1) ticks + (FunTy aa ar,TyConApp sigTc [_domTy,fTy]) + | nameOcc sigTc == "Clash.Signal.Internal.Signal" + , FunTy ba br <- tyView fTy + , aa == ba + , ar == br + -> go is0 e (Left arg:args1) ticks + _ -> do + (nm,sp) <- Lens.use curFun + throw (ClashException sp ($(curLoc) ++ showPpr nm + ++ ": AppPropFast FunTy expected:\n" + ++ showPpr (App c arg)) + Nothing) + _ -> do + (nm,sp) <- Lens.use curFun + throw (ClashException sp ($(curLoc) ++ showPpr nm + ++ ": AppPropFast TPush unimplemented:\n" + ++ showPpr c) + Nothing) + go _ fun args ticks = return (mkApps (mkTicks fun ticks) args) goAlt is0 args0 (p,e) = do diff --git a/clash-lib/src/Clash/Normalize/Util.hs b/clash-lib/src/Clash/Normalize/Util.hs index e1fd950526..40dbfe9467 100644 --- a/clash-lib/src/Clash/Normalize/Util.hs +++ b/clash-lib/src/Clash/Normalize/Util.hs @@ -375,6 +375,7 @@ classifyFunction = go (TermClassification 0 0 0) (_:_:_) -> c & selection +~ 1 _ -> c go !c (Tick _ e) = go c e + go !c (Cast e _ _) = go c e go c _ = c -- | Determine whether a function adds a lot of hardware or not. diff --git a/clash-lib/src/Clash/Rewrite/Util.hs b/clash-lib/src/Clash/Rewrite/Util.hs index b719241256..535c4b9559 100644 --- a/clash-lib/src/Clash/Rewrite/Util.hs +++ b/clash-lib/src/Clash/Rewrite/Util.hs @@ -220,15 +220,15 @@ applyDebug lvl _transformations name exprOld hasChanged exprNew = , "substitution." ]) - traceIf (lvl >= DebugApplied && (not (beforeTy `aeqType` afterTy))) - ( concat [ $(curLoc) + Monad.when (lvl >= DebugApplied && (not (beforeTy `aeqType` afterTy))) $ + error ( concat [ $(curLoc) , "Error when applying rewrite ", name , " to:\n" , before , "\nResult:\n" ++ after ++ "\n" , "Changes type from:\n", showPpr beforeTy , "\nto:\n", showPpr afterTy ] - ) (return ()) + ) Monad.when (lvl >= DebugApplied && not hasChanged && not (exprOld `aeqTerm` exprNew)) $ error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++ "): before" @@ -397,6 +397,8 @@ tailCalls id_ = \case in case scrutTl of Just 0 | all (/= Nothing) altsTl -> Just (sum (catMaybes altsTl)) _ -> Nothing + Tick _ e -> tailCalls id_ e + Cast e _ _ -> tailCalls id_ e _ -> Just 0 -- | Determines whether a function has the following shape: @@ -456,7 +458,7 @@ isWorkFree (collectArgs -> (fun,args)) = case fun of Letrec bs e -> isWorkFree e && all (isWorkFree . snd) bs && all isWorkFreeArg args Case s _ [(_,a)] -> isWorkFree s && isWorkFree a && all isWorkFreeArg args - Cast e _ _ -> isWorkFree e && all isWorkFreeArg args + Cast e _ _ -> isWorkFree (mkApps e args) _ -> False where isWorkFreeArg = either isWorkFree (const True) @@ -476,6 +478,7 @@ isConstant e = case collectArgs e of (Prim _, args) -> all (either isConstant (const True)) args (Lam _ _, _) -> not (hasLocalFreeVars e) (Literal _,_) -> True + (Cast e0 _ _,args) -> all (either isConstant (const True)) (Left e0:args) _ -> False isConstantNotClockReset @@ -503,6 +506,7 @@ isWorkFreeClockOrResetOrEnable tcm e = (Var _, []) -> Just True (Data _, []) -> Just True -- For Enable True/False (Literal _,_) -> Just True + (Cast e0 _ _,[]) -> isWorkFreeClockOrResetOrEnable tcm e0 _ -> Just False else Nothing @@ -525,19 +529,21 @@ isWorkFreeIsh isWorkFreeIsh e = do tcm <- Lens.view tcCache case isWorkFreeClockOrResetOrEnable tcm e of - Just b -> pure b - Nothing -> - case collectArgs e of - (Data _, args) -> allM isWorkFreeIshArg args - (Prim pInfo, args) -> case primWorkInfo pInfo of - WorkAlways -> pure False -- Things like clock or reset generator always - -- perform work - WorkVariable -> pure (all isConstantArg args) - _ -> allM isWorkFreeIshArg args - (Lam _ _, _) -> pure (not (hasLocalFreeVars e)) - (Literal _,_) -> pure True - _ -> pure False + Just b -> pure b + Nothing -> go e where + go e0 = case collectArgs e0 of + (Data _, args) -> allM isWorkFreeIshArg args + (Prim pInfo, args) -> case primWorkInfo pInfo of + WorkAlways -> pure False -- Things like clock or reset generator always + -- perform work + WorkVariable -> pure (all isConstantArg args) + _ -> allM isWorkFreeIshArg args + (Lam _ _, _) -> pure (not (hasLocalFreeVars e)) + (Literal _,_) -> pure True + (Cast e1 _ _,args) -> go (mkApps e1 args) + _ -> pure False + isWorkFreeIshArg = either isWorkFreeIsh (pure . const True) isConstantArg = either isConstant (const True) @@ -878,9 +884,9 @@ specialise' specMapLbl specHistLbl specLimitLbl (TransformContext is0 _) e (Var -- are inlined, meaning the state-transition-function -- and the memory element will be in a single function. gTmM <- fmap (lookupUniqMap (varName g)) $ Lens.use bindings - return (g,maybe inl (^. _3) gTmM, maybe specArg (Left . (`mkApps` gArgs) . (^. _4)) gTmM) - else return (f,inl,specArg) - _ -> return (f,inl,specArg) + return (g,maybe inl (^. _3) gTmM, maybe specArgIn (Left . (`mkApps` gArgs) . (^. _4)) gTmM) + else return (f,inl,specArgIn) + _ -> return (f,inl,specArgIn) -- Create specialized functions let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs) newf <- mkFunction (varName fId) sp inl' newBody