Skip to content

Commit 11367a9

Browse files
committed
Touchups in response to code review of PR 894.
1 parent a7964d6 commit 11367a9

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

src/lib/Imp.hs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,8 +1591,7 @@ impInstrTypes instr = case instr of
15911591
impOpType :: IPrimOp n -> IType
15921592
impOpType pop = case pop of
15931593
ScalarBinOp op x _ -> typeBinOp op (getIType x)
1594-
-- All unary ops preserve the type of their input
1595-
ScalarUnOp _ x -> getIType x
1594+
ScalarUnOp op x -> typeUnOp op (getIType x)
15961595
VectorBinOp op x _ -> typeBinOp op (getIType x)
15971596
Select _ x _ -> getIType x
15981597
VectorPack xs -> Vector ty where Scalar ty = getIType $ head xs

src/lib/Linearize.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,16 @@ liftTangentM args m = liftSubstReaderT $ lift11 $ runReaderT1 args m
143143

144144
isTrivialForAD :: Expr o -> PrimalM i o Bool
145145
isTrivialForAD expr = do
146-
trivialTy <- any isSingletonType . maybeTangentType <$> getType expr
146+
trivialTy <- presentAnd isSingletonType . maybeTangentType <$> getType expr
147147
hasActiveEffs <- getEffects expr >>= \case
148148
Pure -> return False
149149
-- TODO: Be more precise here, such as checking
150150
-- whether the effects are themselves active.
151151
_ -> return True
152152
hasActiveVars <- isActive expr
153153
return $ not hasActiveEffs && (trivialTy || not hasActiveVars)
154+
where presentAnd :: (a -> Bool) -> Maybe a -> Bool
155+
presentAnd = any
154156

155157
isActive :: HoistableE e => e o -> PrimalM i o Bool
156158
isActive e = do

src/lib/QueryType.hs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module QueryType (
88
instantiateDataDef, instantiateDepPairTy, instantiatePi, instantiateTabPi,
99
litType, lamExprTy,
1010
numNaryPiArgs, naryLamExprType,
11-
oneEffect, projectLength, sourceNameType, typeAsBinderNest, typeBinOp,
11+
oneEffect, projectLength, sourceNameType, typeAsBinderNest, typeBinOp, typeUnOp,
1212
isSingletonType, singletonTypeVal,
1313
) where
1414

@@ -219,6 +219,9 @@ typeBinOp binop xTy = case binop of
219219
BXor -> xTy
220220
BShL -> xTy; BShR -> xTy
221221

222+
typeUnOp :: UnOp -> BaseType -> BaseType
223+
typeUnOp = const id -- All unary ops preserve the type of the input
224+
222225
-- === computing effects ===
223226

224227
computeAbsEffects :: (EnvExtender m, SubstE Name e)
@@ -476,8 +479,7 @@ getTypePrimOp op = case op of
476479
ScalarBinOp binop x _ -> do
477480
xTy <- getTypeBaseType x
478481
return $ TC $ BaseType $ typeBinOp binop xTy
479-
-- All unary ops preserve the type of the input
480-
ScalarUnOp _ x -> getTypeE x
482+
ScalarUnOp unop x -> TC . BaseType . typeUnOp unop <$> getTypeBaseType x
481483
Select _ x _ -> getTypeE x
482484
UnsafeFromOrdinal ty _ -> substM ty
483485
ToOrdinal _ -> return IdxRepTy

0 commit comments

Comments
 (0)