1111module Inference
1212 ( inferTopUDecl , checkTopUType , inferTopUExpr
1313 , trySynthTerm , generalizeDict , asTopBlock
14- , synthTopE , UDeclInferenceResult (.. )) where
14+ , synthTopE , UDeclInferenceResult (.. ), asFFIFunType ) where
1515
1616import Prelude hiding ((.) , id )
1717import Control.Category
@@ -45,6 +45,7 @@ import SourceInfo
4545import Subst
4646import QueryType
4747import Types.Core
48+ import Types.Imp
4849import Types.Primitives
4950import Types.Source
5051import Util hiding (group )
@@ -3184,6 +3185,43 @@ withFabricatedEmitsInf cont = fromWrapWithEmitsInf
31843185newtype WrapWithEmitsInf n r =
31853186 WrapWithEmitsInf { fromWrapWithEmitsInf :: EmitsInf n => r }
31863187
3188+ -- === IFunType ===
3189+
3190+ asFFIFunType :: EnvReader m => CType n -> m n (Maybe (IFunType , CorePiType n ))
3191+ asFFIFunType ty = return do
3192+ Pi piTy <- return ty
3193+ impTy <- checkFFIFunTypeM piTy
3194+ return (impTy, piTy)
3195+
3196+ checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType
3197+ checkFFIFunTypeM (CorePiType appExpl (_: expls) (Nest b bs) effTy) = do
3198+ argTy <- checkScalar $ binderType b
3199+ case bs of
3200+ Empty -> do
3201+ resultTys <- checkScalarOrPairType (etTy effTy)
3202+ let cc = case length resultTys of
3203+ 0 -> error " Not implemented"
3204+ 1 -> FFICC
3205+ _ -> FFIMultiResultCC
3206+ return $ IFunType cc [argTy] resultTys
3207+ Nest b' rest -> do
3208+ let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy
3209+ IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest
3210+ return $ IFunType cc (argTy: argTys) resultTys
3211+ checkFFIFunTypeM _ = error " expected at least one argument"
3212+
3213+ checkScalar :: (IRRep r , Fallible m ) => Type r n -> m BaseType
3214+ checkScalar (BaseTy ty) = return ty
3215+ checkScalar ty = throw TypeErr $ pprint ty
3216+
3217+ checkScalarOrPairType :: (IRRep r , Fallible m ) => Type r n -> m [BaseType ]
3218+ checkScalarOrPairType (PairTy a b) = do
3219+ tys1 <- checkScalarOrPairType a
3220+ tys2 <- checkScalarOrPairType b
3221+ return $ tys1 ++ tys2
3222+ checkScalarOrPairType (BaseTy ty) = return [ty]
3223+ checkScalarOrPairType ty = throw TypeErr $ pprint ty
3224+
31873225-- === instances ===
31883226
31893227instance PrettyE e => Pretty (UDeclInferenceResult e l ) where
@@ -3197,9 +3235,11 @@ instance SinkableE e => SinkableE (UDeclInferenceResult e) where
31973235
31983236instance (RenameE e , CheckableE CoreIR e ) => CheckableE CoreIR (UDeclInferenceResult e ) where
31993237 checkE = \ case
3200- UDeclResultDone _ -> return ()
3201- UDeclResultBindName _ block _ -> checkE block
3202- UDeclResultBindPattern _ block _ -> checkE block
3238+ UDeclResultDone e -> UDeclResultDone <$> checkE e
3239+ UDeclResultBindName ann block ab ->
3240+ UDeclResultBindName ann <$> checkE block <*> renameM ab -- TODO: check result
3241+ UDeclResultBindPattern hint block recon ->
3242+ UDeclResultBindPattern hint <$> checkE block <*> renameM recon -- TODO: check recon
32033243
32043244instance HasType CoreIR InfEmission where
32053245 getType = \ case
0 commit comments