@@ -8,7 +8,7 @@ module QueryType (
88 instantiateDataDef , instantiateDepPairTy , instantiatePi , instantiateTabPi ,
99 litType , lamExprTy ,
1010 numNaryPiArgs , naryLamExprType ,
11- oneEffect , projectLength , sourceNameType , typeAsBinderNest , typeBinOp ,
11+ oneEffect , projectLength , sourceNameType , typeAsBinderNest , typeBinOp , typeUnOp ,
1212 isSingletonType , singletonTypeVal ,
1313 ) where
1414
@@ -219,6 +219,9 @@ typeBinOp binop xTy = case binop of
219219 BXor -> xTy
220220 BShL -> xTy; BShR -> xTy
221221
222+ typeUnOp :: UnOp -> BaseType -> BaseType
223+ typeUnOp = const id -- All unary ops preserve the type of the input
224+
222225-- === computing effects ===
223226
224227computeAbsEffects :: (EnvExtender m , SubstE Name e )
@@ -476,8 +479,7 @@ getTypePrimOp op = case op of
476479 ScalarBinOp binop x _ -> do
477480 xTy <- getTypeBaseType x
478481 return $ TC $ BaseType $ typeBinOp binop xTy
479- -- All unary ops preserve the type of the input
480- ScalarUnOp _ x -> getTypeE x
482+ ScalarUnOp unop x -> TC . BaseType . typeUnOp unop <$> getTypeBaseType x
481483 Select _ x _ -> getTypeE x
482484 UnsafeFromOrdinal ty _ -> substM ty
483485 ToOrdinal _ -> return IdxRepTy
0 commit comments