Skip to content

Commit b274115

Browse files
committed
Avoid some uses of :> and @>.
1 parent 75eacbf commit b274115

File tree

10 files changed

+61
-76
lines changed

10 files changed

+61
-76
lines changed

src/lib/Builder.hs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,8 @@ buildMap :: (Emits n, ScopableBuilder r m)
786786
-> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l))
787787
-> m n (Atom r n)
788788
buildMap xs f = do
789-
TabTy d (_:>t) _ <- return $ getType xs
790-
buildFor noHint Fwd (IxType t d) \i ->
789+
TabPi t <- return $ getType xs
790+
buildFor noHint Fwd (tabIxType t) \i ->
791791
tabApp (sink xs) (Var i) >>= f
792792

793793
unzipTab :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n, Atom r n)
@@ -857,8 +857,8 @@ zeroAt ty = liftEmitBuilder $ go ty where
857857
go = \case
858858
BaseTy bt -> return $ Con $ Lit $ zeroLit bt
859859
ProdTy tys -> ProdVal <$> mapM go tys
860-
TabTy d (b:>t) bodyTy -> buildFor (getNameHint b) Fwd (IxType t d) \i ->
861-
go =<< applySubst (b @> SubstVal (Var i)) bodyTy
860+
TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i ->
861+
go =<< instantiateTabPiTy (sink tabPi) (Var i)
862862
_ -> unreachable
863863
zeroLit bt = case bt of
864864
Scalar Float64Type -> Float64Lit 0.0
@@ -902,8 +902,8 @@ tangentBaseMonoidFor ty = do
902902
addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n)
903903
addTangent x y = do
904904
case getType x of
905-
TabTy d (b:>t) _ ->
906-
liftEmitBuilder $ buildFor (getNameHint b) Fwd (IxType t d) \i -> do
905+
TabPi t ->
906+
liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do
907907
bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i))
908908
TC con -> case con of
909909
BaseType (Scalar _) -> emitOp $ BinOp FAdd x y

src/lib/CheapReduction.hs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ module CheapReduction
1515
, unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType
1616
, liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..)
1717
, visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2
18-
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy
18+
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy
1919
, bindersToVars, bindersToAtoms)
2020
where
2121

@@ -474,6 +474,10 @@ instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (Ef
474474
instantiatePiTy (PiType bs effTy) xs = do
475475
applySubst (bs @@> (SubstVal <$> xs)) effTy
476476

477+
instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n)
478+
instantiateTabPiTy (TabPiType _ b resultTy) x = do
479+
applySubst (b @> SubstVal x) resultTy
480+
477481
-- Returns a representation type (type of an TypeCon-typed Newtype payload)
478482
-- given a list of instantiated DataConDefs.
479483
dataDefRep :: DataConDefs n -> CType n

src/lib/CheckType.hs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -788,11 +788,7 @@ checkTabApp ty (i:rest) = do
788788
resultTy' <- applySubst (b@>SubstVal i') resultTy
789789
checkTabApp resultTy' rest
790790

791-
checkArgTys
792-
:: (Typer m r, SubstB AtomSubstVal b, BindsNames b, BindsOneAtomName r b, IRRep r)
793-
=> Nest b o o'
794-
-> [Atom r o]
795-
-> m i o ()
791+
checkArgTys :: (Typer m r, IRRep r) => Nest (Binder r) o o' -> [Atom r o] -> m i o ()
796792
checkArgTys Empty [] = return ()
797793
checkArgTys (Nest b bs) (x:xs) = do
798794
dropSubst $ x |: binderType b
@@ -930,15 +926,14 @@ checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do
930926
checkedApplyNaryAbs (Abs bs cons) xs
931927

932928
checkedApplyNaryAbs
933-
:: forall b r e o m
934-
. ( BindsOneAtomName r b, EnvReader m, Fallible1 m, SinkableE e
935-
, SubstE AtomSubstVal e, IRRep r, SubstB AtomSubstVal b)
936-
=> Abs (Nest b) e o -> [Atom r o] -> m o (e o)
929+
:: forall r e o m
930+
. ( EnvReader m, Fallible1 m, SinkableE e , SubstE AtomSubstVal e, IRRep r)
931+
=> Abs (Nest (Binder r)) e o -> [Atom r o] -> m o (e o)
937932
checkedApplyNaryAbs (Abs bsTop e) xsTop = do
938933
go (EmptyAbs bsTop) xsTop
939934
applySubst (bsTop@@>(SubstVal<$>xsTop)) e
940935
where
941-
go :: EmptyAbs (Nest b) o -> [Atom r o] -> m o ()
936+
go :: EmptyAbs (Nest (Binder r)) o -> [Atom r o] -> m o ()
942937
go (Abs Empty UnitE) [] = return ()
943938
go (Abs (Nest b bs) UnitE) (x:xs) = do
944939
checkAlphaEq (binderType b) (getType x)

src/lib/Export.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ parseTabTy = go []
175175
NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape
176176
TabTy d (b:>ixty) a -> do
177177
maybeN <- case IxType ixty d of
178-
(IxType (NewtypeTyCon (Fin n)) _) -> return $ Just n
179-
(IxType _ (IxDictRawFin n)) -> return $ Just n
178+
IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n
179+
IxType _ (IxDictRawFin n) -> return $ Just n
180180
_ -> return Nothing
181181
maybeDim <- case maybeN of
182182
Just (Var v) -> do

src/lib/Imp.hs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -366,16 +366,16 @@ toImpRefOp refDest' m = do
366366
ans <- liftBuilderImp $ emitBlock (sink body')
367367
storeAtom accDest ans
368368
False -> case accTy of
369-
TabTy d (b:>t) eltTy -> do
370-
let ixTy = IxType t d
369+
TabPi t -> do
370+
let ixTy = tabIxType t
371371
n <- indexSetSizeImp ixTy
372372
emitLoop noHint Fwd n \i -> do
373373
idx <- unsafeFromOrdinalImp (sink ixTy) i
374374
xElt <- liftBuilderImp $ tabApp (sink x) (sink idx)
375375
yElt <- liftBuilderImp $ tabApp (sink y) (sink idx)
376-
eltTy' <- applySubst (b@>SubstVal idx) eltTy
376+
eltTy <- instantiateTabPiTy (sink t) idx
377377
ithDest <- indexDest (sink accDest) idx
378-
liftMonoidCombine ithDest eltTy' (sink bc) xElt yElt
378+
liftMonoidCombine ithDest eltTy (sink bc) xElt yElt
379379
_ -> error $ "Base monoid type mismatch: can't lift " ++
380380
pprint baseTy ++ " to " ++ pprint accTy
381381

@@ -578,15 +578,15 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do
578578
alphaEq xTy accTy >>= \case
579579
True -> storeAtom accDest x
580580
False -> case accTy of
581-
TabTy d (b:>t) eltTy -> do
582-
let ixTy = IxType t d
581+
TabPi t -> do
582+
let ixTy = tabIxType t
583583
n <- indexSetSizeImp ixTy
584584
emitLoop noHint Fwd n \i -> do
585585
idx <- unsafeFromOrdinalImp (sink ixTy) i
586586
x' <- sinkM x
587-
eltTy' <- applySubst (b@>SubstVal idx) eltTy
587+
eltTy <- instantiateTabPiTy (sink t) idx
588588
ithDest <- indexDest (sink accDest) idx
589-
liftMonoidEmpty ithDest eltTy' x'
589+
liftMonoidEmpty ithDest eltTy x'
590590
_ -> error $ "Base monoid type mismatch: can't lift " ++
591591
pprint xTy ++ " to " ++ pprint accTy
592592

@@ -1002,11 +1002,11 @@ buildGarbageVal ty =
10021002
-- === Operations on dests ===
10031003

10041004
indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n)
1005-
indexDest (Dest destValTy@(TabTy d (b:>t) eltTy) tree) i = do
1006-
eltTy' <- applySubst (b@>SubstVal i) eltTy
1007-
ord <- ordinalImp (IxType t d) i
1008-
leafTys <- typeToTree destValTy
1009-
Dest eltTy' <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do
1005+
indexDest (Dest (TabPi tabTy) tree) i = do
1006+
eltTy <- instantiateTabPiTy tabTy i
1007+
ord <- ordinalImp (tabIxType tabTy) i
1008+
leafTys <- typeToTree $ TabPi tabTy
1009+
Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do
10101010
BufferType ixStruct _ <- return $ getRefBufferType leafTy
10111011
offset <- computeOffsetImp ixStruct ord
10121012
impOffset ptr offset
@@ -1026,10 +1026,10 @@ indexRepValParam :: Emits n
10261026
=> SRepVal n -> SAtom n -> (SType n -> SType n)
10271027
-> (IExpr n -> SubstImpM i n (IExpr n))
10281028
-> SubstImpM i n (SRepVal n)
1029-
indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc func = do
1030-
eltTy' <- applySubst (b@>SubstVal i) eltTy
1031-
ord <- ordinalImp (IxType t d) i
1032-
leafTys <- typeToTree tabTy
1029+
indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do
1030+
eltTy <- instantiateTabPiTy tabTy i
1031+
ord <- ordinalImp (tabIxType tabTy) i
1032+
leafTys <- typeToTree (TabPi tabTy)
10331033
vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do
10341034
BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy
10351035
offset <- computeOffsetImp ixStruct ord
@@ -1041,7 +1041,7 @@ indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc
10411041
_ -> func ptr'
10421042
-- `func` may have changed the types of the `vals'`. The caller must also
10431043
-- supply `tyFunc` to reflect that change in the SType.
1044-
return $ RepVal (tyFunc eltTy') vals'
1044+
return $ RepVal (tyFunc eltTy) vals'
10451045
indexRepValParam _ _ _ _ = error "expected table type"
10461046
{-# INLINE indexRepValParam #-}
10471047

src/lib/Lower.hs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,12 @@ lowerFor _ _ _ _ _ = error "expected a unary lambda expression"
153153
lowerTabCon :: forall i o. Emits o
154154
=> Maybe (Dest SimpIR o) -> SType i -> [SAtom i] -> LowerM i o (SExpr o)
155155
lowerTabCon maybeDest tabTy elems = do
156-
tabTy'@(TabPi (TabPiType dict (_:>t) _)) <- substM tabTy
156+
TabPi tabTy' <- substM tabTy
157157
dest <- case maybeDest of
158158
Just d -> return d
159-
Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest tabTy'
159+
Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest $ TabPi tabTy'
160160
Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do
161-
buildBlock $ unsafeFromOrdinal (sink $ IxType t dict) $ Var $ sink ord
161+
buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord
162162
-- This is emitting a chain of RememberDest ops to force `dest` to be used
163163
-- linearly, and to force reads of the `Freeze dest'` result not to be
164164
-- reordered in front of the writes.

src/lib/QueryTypePure.hs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ instance IRRep r => HasType r (Con r) where
119119

120120
getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n
121121
getSuperclassType _ Empty = error "bad index"
122-
getSuperclassType bsAbove (Nest b bs) = \case
123-
0 -> ignoreHoistFailure $ hoist bsAbove $ binderType b
122+
getSuperclassType bsAbove (Nest b@(_:>t) bs) = \case
123+
0 -> ignoreHoistFailure $ hoist bsAbove t
124124
i -> getSuperclassType (RNest bsAbove b) bs (i-1)
125125

126126
instance IRRep r => HasType r (Expr r) where
@@ -213,6 +213,9 @@ rawStrType = case newName "n" of
213213
rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n
214214
rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy
215215

216+
tabIxType :: TabPiType r n -> IxType r n
217+
tabIxType (TabPiType d (_:>t) _) = IxType t d
218+
216219
typesAsBinderNest
217220
:: (SinkableE e, HoistableE e, IRRep r)
218221
=> [Type r n] -> e n -> Abs (Nest (Binder r)) e n

src/lib/RuntimePrint.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ bufferTy h = do
185185
extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n ()
186186
extendBuffer buf tab = do
187187
RefTy h _ <- return $ getType buf
188-
TabTy d (_:>t) _ <- return $ getType tab
189-
n <- applyIxMethodCore Size (IxType t d) []
188+
TabPi t <- return $ getType tab
189+
n <- applyIxMethodCore Size (tabIxType t) []
190190
void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab]
191191

192192
-- argument has type `Word8`
@@ -237,8 +237,8 @@ forEachTabElt
237237
-> (forall l. (Emits l, DExt n l) => CAtom l -> CAtom l -> m l ())
238238
-> m n ()
239239
forEachTabElt tab cont = do
240-
TabTy d (_:>t) _ <- return $ getType tab
241-
let ixTy = IxType t d
240+
TabPi t <- return $ getType tab
241+
let ixTy = tabIxType t
242242
void $ buildFor "i" Fwd ixTy \i -> do
243243
x <- tabApp (sink tab) (Var i)
244244
i' <- applyIxMethodCore Ordinal (sink ixTy) [Var i]

src/lib/Simplify.hs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,12 @@ getRepType ty = go ty where
153153
x <- liftSimpAtom (sink l) (Var $ binderVar b')
154154
r' <- go =<< applySubst (b@>SubstVal x) r
155155
return $ DepPairTy $ DepPairType expl b' r'
156-
TabPi (TabPiType d (b:>t) bodyTy) -> do
157-
let ixTy = IxType t d
156+
TabPi tabTy -> do
157+
let ixTy = tabIxType tabTy
158158
IxType t' d' <- simplifyIxType ixTy
159-
withFreshBinder (getNameHint b) t' \b' -> do
159+
withFreshBinder (getNameHint tabTy) t' \b' -> do
160160
x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b')
161-
bodyTy' <- go =<< applySubst (b@>SubstVal x) bodyTy
161+
bodyTy' <- go =<< instantiateTabPiTy (sink tabTy) x
162162
return $ TabPi $ TabPiType d' b' bodyTy'
163163
NewtypeTyCon con -> do
164164
(_, ty') <- unwrapNewtypeType con
@@ -1025,7 +1025,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do
10251025
return $ activeArg':rest
10261026
buildTangentArgs _ _ _ = error "zip error"
10271027

1028-
fromNonDepNest :: (HoistableB b, BindsOneAtomName CoreIR b) => Nest b n l -> [CType n]
1028+
fromNonDepNest :: Nest CBinder n l -> [CType n]
10291029
fromNonDepNest Empty = []
10301030
fromNonDepNest (Nest b bs) =
10311031
case ignoreHoistFailure $ hoist b (Abs bs UnitE) of

src/lib/Types/Core.hs

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -897,17 +897,13 @@ data LinearizationSpec (n::S) =
897897
LinearizationSpec (TopFunName n) [Active]
898898
deriving (Show, Generic)
899899

900-
-- === BindsOneAtomName ===
900+
-- === Binder utils ===
901901

902-
class BindsOneName b (AtomNameC r) => BindsOneAtomName (r::IR) (b::B) | b -> r where
903-
binderType :: b n l -> Type r n
904-
binderVar :: DExt n l => b n l -> AtomVar r l
902+
binderType :: Binder r n l -> Type r n
903+
binderType (_:>ty) = ty
905904

906-
bindersTypes :: (IRRep r, Distinct l, ProvesExt b, BindsNames b, BindsOneAtomName r b)
907-
=> Nest b n l -> [Type r l]
908-
bindersTypes Empty = []
909-
bindersTypes n@(Nest b bs) = ty : bindersTypes bs
910-
where ty = withExtEvidence n $ sink (binderType b)
905+
binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l
906+
binderVar (b:>ty) = AtomVar (binderName b) (sink ty)
911907

912908
nestToAtomVars :: (Distinct l, Ext n l, IRRep r)
913909
=> Nest (Binder r) n l -> [AtomVar r l]
@@ -916,14 +912,6 @@ nestToAtomVars = \case
916912
Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $
917913
sink (binderVar b) : nestToAtomVars bs
918914

919-
instance IRRep r => BindsOneAtomName r (BinderP (AtomNameC r) (Type r)) where
920-
binderType (_ :> ty) = ty
921-
binderVar (b:>t) = AtomVar (binderName b) (sink t)
922-
923-
toBinderNest :: BindsOneAtomName r b => Nest b n l -> Nest (Binder r) n l
924-
toBinderNest Empty = Empty
925-
toBinderNest (Nest b bs) = Nest (asNameBinder b :> binderType b) (toBinderNest bs)
926-
927915
-- === ToBinding ===
928916

929917
atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n
@@ -957,14 +945,6 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where
957945
toBinding (LeftE e) = toBinding e
958946
toBinding (RightE e) = toBinding e
959947

960-
-- === HasArgType ===
961-
962-
class HasArgType (e::E) (r::IR) | e -> r where
963-
argType :: e n -> Type r n
964-
965-
instance HasArgType (TabPiType r) r where
966-
argType (TabPiType _ (_:>ty) _) = ty
967-
968948
-- === Pattern synonyms ===
969949

970950
-- XXX: only use this pattern when you're actually expecting a type. If it's
@@ -2055,6 +2035,9 @@ instance IRRep r => AlphaEqE (TabPiType r) where
20552035
instance IRRep r => AlphaHashableE (TabPiType r) where
20562036
hashWithSaltE env salt (TabPiType _ b t) = hashWithSaltE env salt $ Abs b t
20572037

2038+
instance HasNameHint (TabPiType r n) where
2039+
getNameHint (TabPiType _ b _) = getNameHint b
2040+
20582041
instance IRRep r => SinkableE (TabPiType r)
20592042
instance IRRep r => HoistableE (TabPiType r)
20602043
instance IRRep r => RenameE (TabPiType r)

0 commit comments

Comments
 (0)