66
77{-# LANGUAGE FlexibleContexts #-}
88{-# LANGUAGE RecordWildCards #-}
9+ {-# LANGUAGE TypeFamilies #-}
10+ {-# LANGUAGE StandaloneDeriving #-}
11+ {-# LANGUAGE DerivingVia #-}
12+ {-# LANGUAGE UndecidableInstances #-}
913
1014module Export (
1115 exportFunctions , prepareFunctionForExport , exportedSignatureDesc ,
12- ExportedSignature (.. ), ExportArrayType (.. ), ExportArg (.. ), ExportResult (.. ),
16+ ExportedSignature (.. ), ExportType (.. ), ExportArg (.. ), ExportResult (.. ),
1317 ) where
1418
15- import Control.Monad.State.Strict
16- import Control.Monad.Writer hiding (pass )
17- import qualified Data.Map.Strict as M
18- import qualified Data.Text as T
19- import Data.String
20- import Data.Foldable
21- import Data.List (nub , intercalate )
19+ import Data.List (intercalate )
2220
23- import Algebra
21+ import Name
22+ import Err
2423import Syntax
25- import Builder
26- import Cat
2724import Type
2825import Simplify
2926import Imp
30- import JIT
31- import Logging
32- import LLVMExec
33- import PPrint
34-
35- exportFunctions :: FilePath -> [(String , Atom )] -> Bindings -> IO ()
36- exportFunctions = undefined
37- -- exportFunctions objPath funcs env = do
38- -- let names = fmap fst funcs
39- -- unless (length (nub names) == length names) $
40- -- throw CompilerErr "Duplicate export names"
41- -- modules <- forM funcs $ \(name, funcAtom) -> do
42- -- let (impModule, _) = prepareFunctionForExport env name funcAtom
43- -- (,[name]) <$> execLogger Nothing (flip impToLLVM impModule)
44- -- exportObjectFile modules \tm m -> Mod.writeObjectToFile tm (Mod.File objPath) m
45-
46- type CArgList = [IBinder ] -- ^ List of arguments to the C call
47- data CArgSubst = CArgEnv { -- | Maps scalar atom binders to their CArgs. All atoms are Vars.
48- cargScalarScope :: SubstSubst
49- -- | Tracks the CArg names used so far (globally scoped, unlike Builder)
50- , cargScope :: Subst () }
51- type CArgM = WriterT CArgList (CatT CArgSubst Builder )
52-
53- instance Semigroup CArgSubst where
54- (CArgSubst a1 a2) <> (CArgEnv b1 b2) = CArgEnv (a1 <> b1) (a2 <> b2)
55-
56- instance Monoid CArgSubst where
57- mempty = CArgSubst mempty mempty
58-
59- runCArg :: CArgSubst -> CArgM a -> Builder (a , [IBinder ], CArgEnv )
60- runCArg initSubst m = repack <$> runCatT (runWriterT m) initEnv
61- where repack ((ans, cargs), env) = (ans, cargs, env)
62-
63- prepareFunctionForExport :: Bindings -> String -> Atom -> (ImpModule , ExportedSignature )
64- prepareFunctionForExport env nameStr func = do
65- -- Create a module that simulates an application of arguments to the function
66- -- TODO: Assert that the type of func is closed?
67- let ((dest, cargs, apiDesc, resultName, resultType), (_, decls)) = runBuilder (freeVars func) mempty $ do
68- (args, cargArgs, cargSubst) <- runCArg mempty $ createArgs $ getType func
69- let (atomArgs, exportedArgSig) = unzip args
70- resultAtom <- naryApp func atomArgs
71- ~ (Var (outputName:> outputType)) <- emit $ Atom resultAtom
72- ((resultDest, exportedResSig), cdestArgs, _) <- runCArg cargSubst $ createDest mempty $ getType resultAtom
73- let cargs' = cargArgs <> cdestArgs
74- let exportedCCallSig = fmap (\ (Bind (v:> _)) -> v) cargs'
75- return (resultDest, cargs', ExportedSignature {.. }, outputName, outputType)
76-
77- let coreModule = Module Core decls $ EvaluatedModule mempty mempty $
78- SourceMap $ M. singleton outputSourceName $ SrcAtomName resultName
79- let defunctionalized = simplifyModule env coreModule
80- let Module _ optDecls (EvaluatedModule optBindings _ (SourceMap sourceMap)) =
81- optimizeModule defunctionalized
82- let ~ (Just (SrcAtomName outputName)) = M. lookup outputSourceName sourceMap
83- -- XXX: this is a terrible hack. We could require any number of hops through
84- -- the evaluated bindings. TODO: reconstruct the result properly.
85- let outputExpr = case envLookup optBindings outputName of
86- Just ~ (AtomBinderInfo _ (LetBound PlainLet expr))-> expr
87- Nothing -> Atom $ Var $ outputName :> resultType
88- let block = Block optDecls outputExpr
89- let name = Name TopFunctionName (fromString nameStr) 0
90- let (_, impModule, _) = toImpModule env LLVM CEntryFun name cargs (Just dest) block
91- (impModule, apiDesc)
27+
28+ exportFunctions :: FilePath -> [(String , Atom n )] -> Env n -> IO ()
29+ exportFunctions = error " Not implemented"
30+
31+ prepareFunctionForExport :: (EnvReader m , Fallible1 m ) => Atom n -> m n (ImpFunction n , ExportedSignature VoidS )
32+ prepareFunctionForExport f = do
33+ naryPi <- getType f >>= asFirstOrderFunction >>= \ case
34+ Nothing -> throw TypeErr " Only first-order functions can be exported"
35+ Just npi -> return npi
36+ closedNaryPi <- case hoistToTop naryPi of
37+ HoistFailure _ -> throw TypeErr " Types of exported functions have to be closed terms"
38+ HoistSuccess npi -> return npi
39+ sig <- case runFallibleM $ runEnvReaderT emptyOutMap $ naryPiToExportSig closedNaryPi of
40+ Success sig -> return sig
41+ Failure err -> throwErrs err
42+ fSimp <- simplifyTopFunction naryPi f
43+ fImp <- toImpExportedFunction fSimp
44+ return (fImp, sig)
9245 where
93- outputSourceName = " _ans_"
94-
95- createArgs :: Type -> CArgM [(Atom , ExportArg )]
96- createArgs ty = case ty of
97- PiTy b arrow result | arrow /= TabArrow -> do
98- argSubst <- looks cargScalarScope
99- let visibility = case arrow of
100- PlainArrow Pure -> ExplicitArg
101- PlainArrow _ -> error $ " Effectful functions cannot be exported"
102- ImplicitArrow -> ImplicitArg
103- _ -> error $ " Unexpected type for an exported function: " ++ pprint ty
104- (:) <$> createArg visibility (subst (argSubst, mempty ) b) <*> createArgs result
105- _ -> return []
106-
107- createArg :: ArgVisibility -> Binder -> CArgM (Atom , ExportArg )
108- createArg vis b = case ty of
109- BaseTy bt@ (Scalar sbt) -> do
110- ~ v@ (Var (name:> _)) <- newCVar bt
111- extend $ mempty { cargScalarScope = b @> SubstVal (Var $ name :> BaseTy bt) }
112- return (v, ExportScalarArg vis name sbt)
113- TabTy _ _ -> createTabArg vis mempty ty
114- _ -> error $ " Unsupported arg type: " ++ pprint ty
115- where ty = binderType b
116-
117- createTabArg :: ArgVisibility -> IndexStructure -> Type -> CArgM (Atom , ExportArg )
118- createTabArg vis idx ty = case ty of
119- BaseTy bt@ (Scalar sbt) -> do
120- ~ v@ (Var (name:> _)) <- newCVar (ptrTy bt)
121- destAtom <- unsafePtrLoad =<< applyIdxs v idx
122- funcArgScope <- looks cargScope
123- let exportArg = ExportArrayArg vis name $ case getRectShape funcArgScope idx of
124- Just rectShape -> RectContArrayPtr sbt rectShape
125- Nothing -> GeneralArrayPtr sbt
126- return (destAtom, exportArg)
127- TabTy b elemTy -> do
128- buildLamAux b (const $ return TabArrow ) $ \ (Var i) -> do
129- elemTy' <- substBuilder (b@> SubstVal (Var i)) elemTy
130- createTabArg vis (idx <> Nest (Bind i) Empty ) elemTy'
131- _ -> unsupported
132- where unsupported = error $ " Unsupported table type suffix: " ++ pprint ty
133-
134- createDest :: IndexStructure -> Type -> CArgM (Atom , ExportResult )
135- createDest idx ty = case ty of
136- BaseTy bt@ (Scalar sbt) -> do
137- ~ v@ (Var (name:> _)) <- newCVar (ptrTy bt)
138- dest <- Con . BaseTypeRef <$> applyIdxs v idx
139- funcArgScope <- looks cargScope
140- let exportResult = case idx of
141- Empty -> ExportScalarResultPtr name sbt
142- _ -> ExportArrayResult name $ case getRectShape funcArgScope idx of
143- Just rectShape -> RectContArrayPtr sbt rectShape
144- Nothing -> GeneralArrayPtr sbt
145- return (dest, exportResult)
146- TabTy b elemTy -> do
147- (destTab, exportResult) <- buildLamAux b (const $ return TabArrow ) $ \ (Var i) -> do
148- elemTy' <- substBuilder (b@> SubstVal (Var i)) elemTy
149- createDest (idx <> Nest (Bind i) Empty ) elemTy'
150- return (Con $ TabRef destTab, exportResult)
151- PairTy a b | idx == Empty -> do
152- (atom_a, res_a) <- createDest idx a
153- (atom_b, res_b) <- createDest idx b
154- return (Con $ ConRef $ ProdCon [atom_a, atom_b], ExportPairResult res_a res_b)
155- _ -> unsupported
156- where unsupported = error $ " Unsupported result type: " ++ pprint ty
157-
158- -- TODO: I guess that the address space depends on the backend?
159- -- TODO: Have an ExternalPtr tag?
160- ptrTy ty = PtrType (Heap CPU , ty)
161-
162- getRectShape :: Subst () -> IndexStructure -> Maybe [Either Name Int ]
163- getRectShape scope idx = traverse (dimShape . binderType) $ toList idx
46+ naryPiToExportSig :: (EnvReader m , EnvExtender m , Fallible1 m )
47+ => NaryPiType n -> m n (ExportedSignature n )
48+ naryPiToExportSig (NaryPiType (NonEmptyNest tb tbs) effs resultTy) = do
49+ case effs of
50+ Pure -> return ()
51+ _ -> throw TypeErr " Only pure functions can be exported"
52+ goArgs Empty [] (Nest tb tbs) resultTy
16453 where
165- dimShape dimTy = case dimTy of
166- Fin (IdxRepVal n) -> Just $ Right $ fromIntegral n
167- Fin (Var v) | v `isin` scope -> Just $ Left $ varName v
168- _ -> Nothing
169-
170- newCVar :: BaseType -> CArgM Atom
171- newCVar bt = do
172- name <- genFresh (Name CArgName " arg" 0 ) <$> looks cargScope
173- extend $ mempty { cargScope = name @> () }
174- tell [Bind $ name :> bt]
175- return $ Var $ name :> BaseTy bt
54+ goArgs :: (EnvReader m , EnvExtender m , Fallible1 m )
55+ => Nest ExportArg n l' -> [AtomName l' ] -> Nest PiBinder l' l -> Type l -> m l' (ExportedSignature n )
56+ goArgs argSig argVs piBs piRes = case piBs of
57+ Empty -> goResult piRes \ resSig ->
58+ return $ ExportedSignature argSig resSig $
59+ (fromListE $ sink $ ListE argVs) ++ nestToList (sink . binderName) resSig
60+ Nest b bs -> do
61+ refreshAbs (Abs b (Abs bs piRes)) \ (PiBinder v ty arrow) (Abs bs' piRes') -> do
62+ let invalidArrow = throw TypeErr
63+ " Exported functions can only have regular and implicit arrow types"
64+ vis <- case arrow of
65+ PlainArrow -> return ExplicitArg
66+ ImplicitArrow -> return ImplicitArg
67+ ClassArrow -> invalidArrow
68+ TabArrow -> invalidArrow
69+ LinArrow -> invalidArrow
70+ ety <- toExportType ty
71+ goArgs (argSig `joinNest` Nest (ExportArg vis (v:> ety)) Empty )
72+ ((fromListE $ sink $ ListE argVs) ++ [binderName v]) bs' piRes'
73+
74+ goResult :: (EnvReader m , EnvExtender m , Fallible1 m )
75+ => Type l
76+ -> (forall q . DExt l q => Nest ExportResult l q -> m q a )
77+ -> m l a
78+ goResult ty cont = case ty of
79+ ProdTy [lty, rty] ->
80+ goResult lty \ lres ->
81+ goResult (sink rty) \ rres ->
82+ cont $ joinNest lres rres
83+ _ -> withFreshBinder NoHint ty \ b -> do
84+ ety <- toExportType ty
85+ cont $ Nest (ExportResult (b:> ety)) Empty
86+
87+ toExportType :: Fallible m => Type n -> m (ExportType n )
88+ toExportType = \ case
89+ BaseTy (Scalar sbt) -> return $ ScalarType sbt
90+ -- TODO: Arrays!
91+ ty -> throw TypeErr $ " Unsupported type of argument in exported function: " ++ pprint ty
17692
17793-- === Exported function signature ===
17894
179- data ExportArrayType = GeneralArrayPtr ScalarBaseType
180- | RectContArrayPtr ScalarBaseType [Either Name Int ]
18195data ArgVisibility = ImplicitArg | ExplicitArg
182- data ExportArg = ExportArrayArg ArgVisibility Name ExportArrayType
183- | ExportScalarArg ArgVisibility Name ScalarBaseType
184- data ExportResult = ExportArrayResult Name ExportArrayType
185- | ExportScalarResultPtr Name ScalarBaseType
186- | ExportPairResult ExportResult ExportResult
187- data ExportedSignature =
188- ExportedSignature { exportedArgSig :: [ ExportArg ]
189- , exportedResSig :: ExportResult
190- , exportedCCallSig :: [Name ]
96+ data ExportType n = RectContArrayPtr ScalarBaseType [ Either ( AtomName n ) Int ]
97+ | ScalarType ScalarBaseType
98+
99+ data ExportArg n l = ExportArg ArgVisibility ( BinderP AtomNameC ExportType n l )
100+ newtype ExportResult n l = ExportResult ( BinderP AtomNameC ExportType n l )
101+ data ExportedSignature n = forall l l' .
102+ ExportedSignature { exportedArgSig :: Nest ExportArg n l
103+ , exportedResSig :: Nest ExportResult l l'
104+ , exportedCCallSig :: [AtomName l' ]
191105 }
192106
107+ deriving via (BinderP AtomNameC ExportType ) instance GenericB ExportResult
108+ deriving via (BinderP AtomNameC ExportType ) instance ProvesExt ExportResult
109+ deriving via (BinderP AtomNameC ExportType ) instance BindsNames ExportResult
110+ instance BindsAtMostOneName ExportResult AtomNameC where
111+ (ExportResult b) @> v = b @> v
112+ instance BindsOneName ExportResult AtomNameC where
113+ binderName (ExportResult b) = binderName b
114+
115+ instance GenericB ExportArg where
116+ type RepB ExportArg = PairB (LiftB (LiftE ArgVisibility )) (BinderP AtomNameC ExportType )
117+ fromB (ExportArg vis b) = PairB (LiftB (LiftE vis)) b
118+ toB (PairB (LiftB (LiftE vis)) b) = ExportArg vis b
119+ instance ProvesExt ExportArg
120+ instance BindsNames ExportArg
121+ instance BindsAtMostOneName ExportArg AtomNameC where
122+ (ExportArg _ b) @> v = b @> v
123+ instance BindsOneName ExportArg AtomNameC where
124+ binderName (ExportArg _ b) = binderName b
125+
193126-- Serialization
194127
195- exportedSignatureDesc :: ExportedSignature -> (String , String , String )
128+ exportedSignatureDesc :: ExportedSignature n -> (String , String , String )
196129exportedSignatureDesc ExportedSignature {.. } =
197- ( intercalate " ," $ fmap show exportedArgSig
198- , show exportedResSig
199- , intercalate " ," $ fmap showCArgName exportedCCallSig
130+ ( intercalate " ," $ nestToList show exportedArgSig
131+ , intercalate " , " $ nestToList show exportedResSig
132+ , intercalate " ," $ fmap pprint exportedCCallSig
200133 )
201134
202135showExportSBT :: ScalarBaseType -> String
@@ -209,34 +142,21 @@ showExportSBT sbt = case sbt of
209142 Float64Type -> " f64"
210143 Float32Type -> " f32"
211144
212- showCArgName :: Name -> String
213- showCArgName ~ name@ (Name namespace tag idx) = case namespace of
214- CArgName -> T. unpack tag <> show idx
215- _ -> error $ " Expected a CArgName namespace: " ++ show name
216-
217- instance Show ExportArrayType where
145+ instance Show (ExportType n ) where
218146 show arr = case arr of
219- GeneralArrayPtr sbt -> showExportSBT sbt <> " [?]"
220147 RectContArrayPtr sbt shape -> showExportSBT sbt <> showShape shape
148+ ScalarType sbt -> showExportSBT sbt
221149 where
222150 showShape shape = " [" <> (intercalate " ," $ fmap showDim shape) <> " ]"
223151 showDim size = case size of
224- Left name -> showCArgName name
152+ Left name -> pprint name
225153 Right lit -> show lit
226154
227- instance Show ExportArg where
228- show arg = case arg of
229- ExportArrayArg vis name ty -> showVis vis <> showCArgName name <> " :" <> show ty
230- ExportScalarArg vis name sbt -> showVis vis <> showCArgName name <> " :" <> showExportSBT sbt
155+ instance Show (ExportArg n l ) where
156+ show (ExportArg vis (name:> ty)) = showVis vis <> pprint name <> " :" <> show ty
231157 where
232158 showVis ImplicitArg = " ?"
233159 showVis ExplicitArg = " "
234160
235- instance Show ExportResult where
236- show res = case res of
237- ExportArrayResult name ty -> showCArgName name <> " :" <> show ty
238- ExportScalarResultPtr name sbt -> showCArgName name <> " :" <> showExportSBT sbt
239- -- Nested pairs / tuples are compiled down to a sequence of separate output
240- -- arguments, so a pair result is serialized to look like two separate
241- -- results.
242- ExportPairResult left right -> show left <> " ," <> show right
161+ instance Show (ExportResult n l ) where
162+ show (ExportResult (name:> ty)) = pprint name <> " :" <> show ty
0 commit comments