Skip to content

Commit 04a3d32

Browse files
committed
Reimplement basic export functionality
It's still not quite as powerful as it used to be (e.g. it doesn't support table arguments/results) and is not used anywhere yet. But it's a good point to checkpoint the progress.
1 parent 65ee2e0 commit 04a3d32

File tree

4 files changed

+141
-193
lines changed

4 files changed

+141
-193
lines changed

dex.cabal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ library
5656
LLVM.JIT, LLVM.Shims, JIT, LLVMExec,
5757
Err, LabeledItems, SourceRename, Name, Parser, MTL1,
5858
Type, Builder, Inference, CheapReduction, GenericTraversal,
59-
Simplify, Imp, Algebra, Linearize, Transpose,
59+
Simplify, Imp, Algebra, Linearize, Transpose, Export,
6060
LLVM.HEAD.JIT
6161
if flag(live)
6262
exposed-modules: Actor, RenderHtml, LiveOutput

src/lib/Export.hs

Lines changed: 110 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -6,197 +6,130 @@
66

77
{-# LANGUAGE FlexibleContexts #-}
88
{-# LANGUAGE RecordWildCards #-}
9+
{-# LANGUAGE TypeFamilies #-}
10+
{-# LANGUAGE StandaloneDeriving #-}
11+
{-# LANGUAGE DerivingVia #-}
12+
{-# LANGUAGE UndecidableInstances #-}
913

1014
module Export (
1115
exportFunctions, prepareFunctionForExport, exportedSignatureDesc,
12-
ExportedSignature (..), ExportArrayType (..), ExportArg (..), ExportResult (..),
16+
ExportedSignature (..), ExportType (..), ExportArg (..), ExportResult (..),
1317
) where
1418

15-
import Control.Monad.State.Strict
16-
import Control.Monad.Writer hiding (pass)
17-
import qualified Data.Map.Strict as M
18-
import qualified Data.Text as T
19-
import Data.String
20-
import Data.Foldable
21-
import Data.List (nub, intercalate)
19+
import Data.List (intercalate)
2220

23-
import Algebra
21+
import Name
22+
import Err
2423
import Syntax
25-
import Builder
26-
import Cat
2724
import Type
2825
import Simplify
2926
import Imp
30-
import JIT
31-
import Logging
32-
import LLVMExec
33-
import PPrint
34-
35-
exportFunctions :: FilePath -> [(String, Atom)] -> Bindings -> IO ()
36-
exportFunctions = undefined
37-
-- exportFunctions objPath funcs env = do
38-
-- let names = fmap fst funcs
39-
-- unless (length (nub names) == length names) $
40-
-- throw CompilerErr "Duplicate export names"
41-
-- modules <- forM funcs $ \(name, funcAtom) -> do
42-
-- let (impModule, _) = prepareFunctionForExport env name funcAtom
43-
-- (,[name]) <$> execLogger Nothing (flip impToLLVM impModule)
44-
-- exportObjectFile modules \tm m -> Mod.writeObjectToFile tm (Mod.File objPath) m
45-
46-
type CArgList = [IBinder] -- ^ List of arguments to the C call
47-
data CArgSubst = CArgEnv { -- | Maps scalar atom binders to their CArgs. All atoms are Vars.
48-
cargScalarScope :: SubstSubst
49-
-- | Tracks the CArg names used so far (globally scoped, unlike Builder)
50-
, cargScope :: Subst () }
51-
type CArgM = WriterT CArgList (CatT CArgSubst Builder)
52-
53-
instance Semigroup CArgSubst where
54-
(CArgSubst a1 a2) <> (CArgEnv b1 b2) = CArgEnv (a1 <> b1) (a2 <> b2)
55-
56-
instance Monoid CArgSubst where
57-
mempty = CArgSubst mempty mempty
58-
59-
runCArg :: CArgSubst -> CArgM a -> Builder (a, [IBinder], CArgEnv)
60-
runCArg initSubst m = repack <$> runCatT (runWriterT m) initEnv
61-
where repack ((ans, cargs), env) = (ans, cargs, env)
62-
63-
prepareFunctionForExport :: Bindings -> String -> Atom -> (ImpModule, ExportedSignature)
64-
prepareFunctionForExport env nameStr func = do
65-
-- Create a module that simulates an application of arguments to the function
66-
-- TODO: Assert that the type of func is closed?
67-
let ((dest, cargs, apiDesc, resultName, resultType), (_, decls)) = runBuilder (freeVars func) mempty $ do
68-
(args, cargArgs, cargSubst) <- runCArg mempty $ createArgs $ getType func
69-
let (atomArgs, exportedArgSig) = unzip args
70-
resultAtom <- naryApp func atomArgs
71-
~(Var (outputName:>outputType)) <- emit $ Atom resultAtom
72-
((resultDest, exportedResSig), cdestArgs, _) <- runCArg cargSubst $ createDest mempty $ getType resultAtom
73-
let cargs' = cargArgs <> cdestArgs
74-
let exportedCCallSig = fmap (\(Bind (v:>_)) -> v) cargs'
75-
return (resultDest, cargs', ExportedSignature{..}, outputName, outputType)
76-
77-
let coreModule = Module Core decls $ EvaluatedModule mempty mempty $
78-
SourceMap $ M.singleton outputSourceName $ SrcAtomName resultName
79-
let defunctionalized = simplifyModule env coreModule
80-
let Module _ optDecls (EvaluatedModule optBindings _ (SourceMap sourceMap)) =
81-
optimizeModule defunctionalized
82-
let ~(Just (SrcAtomName outputName)) = M.lookup outputSourceName sourceMap
83-
-- XXX: this is a terrible hack. We could require any number of hops through
84-
-- the evaluated bindings. TODO: reconstruct the result properly.
85-
let outputExpr = case envLookup optBindings outputName of
86-
Just ~(AtomBinderInfo _ (LetBound PlainLet expr))-> expr
87-
Nothing -> Atom $ Var $ outputName :> resultType
88-
let block = Block optDecls outputExpr
89-
let name = Name TopFunctionName (fromString nameStr) 0
90-
let (_, impModule, _) = toImpModule env LLVM CEntryFun name cargs (Just dest) block
91-
(impModule, apiDesc)
27+
28+
exportFunctions :: FilePath -> [(String, Atom n)] -> Env n -> IO ()
29+
exportFunctions = error "Not implemented"
30+
31+
prepareFunctionForExport :: (EnvReader m, Fallible1 m) => Atom n -> m n (ImpFunction n, ExportedSignature VoidS)
32+
prepareFunctionForExport f = do
33+
naryPi <- getType f >>= asFirstOrderFunction >>= \case
34+
Nothing -> throw TypeErr "Only first-order functions can be exported"
35+
Just npi -> return npi
36+
closedNaryPi <- case hoistToTop naryPi of
37+
HoistFailure _ -> throw TypeErr "Types of exported functions have to be closed terms"
38+
HoistSuccess npi -> return npi
39+
sig <- case runFallibleM $ runEnvReaderT emptyOutMap $ naryPiToExportSig closedNaryPi of
40+
Success sig -> return sig
41+
Failure err -> throwErrs err
42+
fSimp <- simplifyTopFunction naryPi f
43+
fImp <- toImpExportedFunction fSimp
44+
return (fImp, sig)
9245
where
93-
outputSourceName = "_ans_"
94-
95-
createArgs :: Type -> CArgM [(Atom, ExportArg)]
96-
createArgs ty = case ty of
97-
PiTy b arrow result | arrow /= TabArrow -> do
98-
argSubst <- looks cargScalarScope
99-
let visibility = case arrow of
100-
PlainArrow Pure -> ExplicitArg
101-
PlainArrow _ -> error $ "Effectful functions cannot be exported"
102-
ImplicitArrow -> ImplicitArg
103-
_ -> error $ "Unexpected type for an exported function: " ++ pprint ty
104-
(:) <$> createArg visibility (subst (argSubst, mempty) b) <*> createArgs result
105-
_ -> return []
106-
107-
createArg :: ArgVisibility -> Binder -> CArgM (Atom, ExportArg)
108-
createArg vis b = case ty of
109-
BaseTy bt@(Scalar sbt) -> do
110-
~v@(Var (name:>_)) <- newCVar bt
111-
extend $ mempty { cargScalarScope = b @> SubstVal (Var $ name :> BaseTy bt) }
112-
return (v, ExportScalarArg vis name sbt)
113-
TabTy _ _ -> createTabArg vis mempty ty
114-
_ -> error $ "Unsupported arg type: " ++ pprint ty
115-
where ty = binderType b
116-
117-
createTabArg :: ArgVisibility -> IndexStructure -> Type -> CArgM (Atom, ExportArg)
118-
createTabArg vis idx ty = case ty of
119-
BaseTy bt@(Scalar sbt) -> do
120-
~v@(Var (name:>_)) <- newCVar (ptrTy bt)
121-
destAtom <- unsafePtrLoad =<< applyIdxs v idx
122-
funcArgScope <- looks cargScope
123-
let exportArg = ExportArrayArg vis name $ case getRectShape funcArgScope idx of
124-
Just rectShape -> RectContArrayPtr sbt rectShape
125-
Nothing -> GeneralArrayPtr sbt
126-
return (destAtom, exportArg)
127-
TabTy b elemTy -> do
128-
buildLamAux b (const $ return TabArrow) $ \(Var i) -> do
129-
elemTy' <- substBuilder (b@>SubstVal (Var i)) elemTy
130-
createTabArg vis (idx <> Nest (Bind i) Empty) elemTy'
131-
_ -> unsupported
132-
where unsupported = error $ "Unsupported table type suffix: " ++ pprint ty
133-
134-
createDest :: IndexStructure -> Type -> CArgM (Atom, ExportResult)
135-
createDest idx ty = case ty of
136-
BaseTy bt@(Scalar sbt) -> do
137-
~v@(Var (name:>_)) <- newCVar (ptrTy bt)
138-
dest <- Con . BaseTypeRef <$> applyIdxs v idx
139-
funcArgScope <- looks cargScope
140-
let exportResult = case idx of
141-
Empty -> ExportScalarResultPtr name sbt
142-
_ -> ExportArrayResult name $ case getRectShape funcArgScope idx of
143-
Just rectShape -> RectContArrayPtr sbt rectShape
144-
Nothing -> GeneralArrayPtr sbt
145-
return (dest, exportResult)
146-
TabTy b elemTy -> do
147-
(destTab, exportResult) <- buildLamAux b (const $ return TabArrow) $ \(Var i) -> do
148-
elemTy' <- substBuilder (b@>SubstVal (Var i)) elemTy
149-
createDest (idx <> Nest (Bind i) Empty) elemTy'
150-
return (Con $ TabRef destTab, exportResult)
151-
PairTy a b | idx == Empty -> do
152-
(atom_a, res_a) <- createDest idx a
153-
(atom_b, res_b) <- createDest idx b
154-
return (Con $ ConRef $ ProdCon [atom_a, atom_b], ExportPairResult res_a res_b)
155-
_ -> unsupported
156-
where unsupported = error $ "Unsupported result type: " ++ pprint ty
157-
158-
-- TODO: I guess that the address space depends on the backend?
159-
-- TODO: Have an ExternalPtr tag?
160-
ptrTy ty = PtrType (Heap CPU, ty)
161-
162-
getRectShape :: Subst () -> IndexStructure -> Maybe [Either Name Int]
163-
getRectShape scope idx = traverse (dimShape . binderType) $ toList idx
46+
naryPiToExportSig :: (EnvReader m, EnvExtender m, Fallible1 m)
47+
=> NaryPiType n -> m n (ExportedSignature n)
48+
naryPiToExportSig (NaryPiType (NonEmptyNest tb tbs) effs resultTy) = do
49+
case effs of
50+
Pure -> return ()
51+
_ -> throw TypeErr "Only pure functions can be exported"
52+
goArgs Empty [] (Nest tb tbs) resultTy
16453
where
165-
dimShape dimTy = case dimTy of
166-
Fin (IdxRepVal n) -> Just $ Right $ fromIntegral n
167-
Fin (Var v) | v `isin` scope -> Just $ Left $ varName v
168-
_ -> Nothing
169-
170-
newCVar :: BaseType -> CArgM Atom
171-
newCVar bt = do
172-
name <- genFresh (Name CArgName "arg" 0) <$> looks cargScope
173-
extend $ mempty { cargScope = name @> () }
174-
tell [Bind $ name :> bt]
175-
return $ Var $ name :> BaseTy bt
54+
goArgs :: (EnvReader m, EnvExtender m, Fallible1 m)
55+
=> Nest ExportArg n l' -> [AtomName l'] -> Nest PiBinder l' l -> Type l -> m l' (ExportedSignature n)
56+
goArgs argSig argVs piBs piRes = case piBs of
57+
Empty -> goResult piRes \resSig ->
58+
return $ ExportedSignature argSig resSig $
59+
(fromListE $ sink $ ListE argVs) ++ nestToList (sink . binderName) resSig
60+
Nest b bs -> do
61+
refreshAbs (Abs b (Abs bs piRes)) \(PiBinder v ty arrow) (Abs bs' piRes') -> do
62+
let invalidArrow = throw TypeErr
63+
"Exported functions can only have regular and implicit arrow types"
64+
vis <- case arrow of
65+
PlainArrow -> return ExplicitArg
66+
ImplicitArrow -> return ImplicitArg
67+
ClassArrow -> invalidArrow
68+
TabArrow -> invalidArrow
69+
LinArrow -> invalidArrow
70+
ety <- toExportType ty
71+
goArgs (argSig `joinNest` Nest (ExportArg vis (v:>ety)) Empty)
72+
((fromListE $ sink $ ListE argVs) ++ [binderName v]) bs' piRes'
73+
74+
goResult :: (EnvReader m, EnvExtender m, Fallible1 m)
75+
=> Type l
76+
-> (forall q. DExt l q => Nest ExportResult l q -> m q a)
77+
-> m l a
78+
goResult ty cont = case ty of
79+
ProdTy [lty, rty] ->
80+
goResult lty \lres ->
81+
goResult (sink rty) \rres ->
82+
cont $ joinNest lres rres
83+
_ -> withFreshBinder NoHint ty \b -> do
84+
ety <- toExportType ty
85+
cont $ Nest (ExportResult (b:>ety)) Empty
86+
87+
toExportType :: Fallible m => Type n -> m (ExportType n)
88+
toExportType = \case
89+
BaseTy (Scalar sbt) -> return $ ScalarType sbt
90+
-- TODO: Arrays!
91+
ty -> throw TypeErr $ "Unsupported type of argument in exported function: " ++ pprint ty
17692

17793
-- === Exported function signature ===
17894

179-
data ExportArrayType = GeneralArrayPtr ScalarBaseType
180-
| RectContArrayPtr ScalarBaseType [Either Name Int]
18195
data ArgVisibility = ImplicitArg | ExplicitArg
182-
data ExportArg = ExportArrayArg ArgVisibility Name ExportArrayType
183-
| ExportScalarArg ArgVisibility Name ScalarBaseType
184-
data ExportResult = ExportArrayResult Name ExportArrayType
185-
| ExportScalarResultPtr Name ScalarBaseType
186-
| ExportPairResult ExportResult ExportResult
187-
data ExportedSignature =
188-
ExportedSignature { exportedArgSig :: [ExportArg]
189-
, exportedResSig :: ExportResult
190-
, exportedCCallSig :: [Name]
96+
data ExportType n = RectContArrayPtr ScalarBaseType [Either (AtomName n) Int]
97+
| ScalarType ScalarBaseType
98+
99+
data ExportArg n l = ExportArg ArgVisibility (BinderP AtomNameC ExportType n l)
100+
newtype ExportResult n l = ExportResult (BinderP AtomNameC ExportType n l)
101+
data ExportedSignature n = forall l l'.
102+
ExportedSignature { exportedArgSig :: Nest ExportArg n l
103+
, exportedResSig :: Nest ExportResult l l'
104+
, exportedCCallSig :: [AtomName l']
191105
}
192106

107+
deriving via (BinderP AtomNameC ExportType) instance GenericB ExportResult
108+
deriving via (BinderP AtomNameC ExportType) instance ProvesExt ExportResult
109+
deriving via (BinderP AtomNameC ExportType) instance BindsNames ExportResult
110+
instance BindsAtMostOneName ExportResult AtomNameC where
111+
(ExportResult b) @> v = b @> v
112+
instance BindsOneName ExportResult AtomNameC where
113+
binderName (ExportResult b) = binderName b
114+
115+
instance GenericB ExportArg where
116+
type RepB ExportArg = PairB (LiftB (LiftE ArgVisibility)) (BinderP AtomNameC ExportType)
117+
fromB (ExportArg vis b) = PairB (LiftB (LiftE vis)) b
118+
toB (PairB (LiftB (LiftE vis)) b) = ExportArg vis b
119+
instance ProvesExt ExportArg
120+
instance BindsNames ExportArg
121+
instance BindsAtMostOneName ExportArg AtomNameC where
122+
(ExportArg _ b) @> v = b @> v
123+
instance BindsOneName ExportArg AtomNameC where
124+
binderName (ExportArg _ b) = binderName b
125+
193126
-- Serialization
194127

195-
exportedSignatureDesc :: ExportedSignature -> (String, String, String)
128+
exportedSignatureDesc :: ExportedSignature n -> (String, String, String)
196129
exportedSignatureDesc ExportedSignature{..} =
197-
( intercalate "," $ fmap show exportedArgSig
198-
, show exportedResSig
199-
, intercalate "," $ fmap showCArgName exportedCCallSig
130+
( intercalate "," $ nestToList show exportedArgSig
131+
, intercalate "," $ nestToList show exportedResSig
132+
, intercalate "," $ fmap pprint exportedCCallSig
200133
)
201134

202135
showExportSBT :: ScalarBaseType -> String
@@ -209,34 +142,21 @@ showExportSBT sbt = case sbt of
209142
Float64Type -> "f64"
210143
Float32Type -> "f32"
211144

212-
showCArgName :: Name -> String
213-
showCArgName ~name@(Name namespace tag idx) = case namespace of
214-
CArgName -> T.unpack tag <> show idx
215-
_ -> error $ "Expected a CArgName namespace: " ++ show name
216-
217-
instance Show ExportArrayType where
145+
instance Show (ExportType n) where
218146
show arr = case arr of
219-
GeneralArrayPtr sbt -> showExportSBT sbt <> "[?]"
220147
RectContArrayPtr sbt shape -> showExportSBT sbt <> showShape shape
148+
ScalarType sbt -> showExportSBT sbt
221149
where
222150
showShape shape = "[" <> (intercalate "," $ fmap showDim shape) <> "]"
223151
showDim size = case size of
224-
Left name -> showCArgName name
152+
Left name -> pprint name
225153
Right lit -> show lit
226154

227-
instance Show ExportArg where
228-
show arg = case arg of
229-
ExportArrayArg vis name ty -> showVis vis <> showCArgName name <> ":" <> show ty
230-
ExportScalarArg vis name sbt -> showVis vis <> showCArgName name <> ":" <> showExportSBT sbt
155+
instance Show (ExportArg n l) where
156+
show (ExportArg vis (name:>ty)) = showVis vis <> pprint name <> ":" <> show ty
231157
where
232158
showVis ImplicitArg = "?"
233159
showVis ExplicitArg = ""
234160

235-
instance Show ExportResult where
236-
show res = case res of
237-
ExportArrayResult name ty -> showCArgName name <> ":" <> show ty
238-
ExportScalarResultPtr name sbt -> showCArgName name <> ":" <> showExportSBT sbt
239-
-- Nested pairs / tuples are compiled down to a sequence of separate output
240-
-- arguments, so a pair result is serialized to look like two separate
241-
-- results.
242-
ExportPairResult left right -> show left <> "," <> show right
161+
instance Show (ExportResult n l) where
162+
show (ExportResult (name:>ty)) = pprint name <> ":" <> show ty

src/lib/Imp.hs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
{-# OPTIONS_GHC -Wno-orphans #-}
1313

1414
module Imp
15-
( toImpFunction, ImpFunctionWithRecon (..), toImpStandaloneFunction
15+
( toImpFunction, ImpFunctionWithRecon (..)
16+
, toImpStandaloneFunction, toImpExportedFunction
1617
, PtrBinder, impFunType, getIType) where
1718

1819
import Data.Functor
@@ -66,6 +67,33 @@ toImpStandaloneFunction' lam@(NaryLamExpr bs Pure body) = do
6667
return []
6768
toImpStandaloneFunction' (NaryLamExpr _ _ _) = error "effectful functions not implemented"
6869

70+
toImpExportedFunction :: EnvReader m => NaryLamExpr n -> m n (ImpFunction n)
71+
toImpExportedFunction lam@(NaryLamExpr (NonEmptyNest fb tb) effs body) = liftImpM do
72+
case effs of
73+
Pure -> return ()
74+
_ -> throw TypeErr "Can only export pure functions"
75+
let bs = Nest fb tb
76+
NaryPiType tbs _ resTy <- naryLamExprType lam
77+
(resDestAbsArgsPtrs, ptrFormals) <- refreshAbs (Abs tbs resTy) \tbs' resTy' -> do
78+
-- WARNING! This ties the makeDest implementation to the C API expected in export.
79+
-- In particular, every array has to be backend by a single pointer and pairs
80+
-- should be traversed left-to-right.
81+
AbsPtrs (Abs ptrBs' resDest') ptrInfo <- makeDest (LLVM, CPU, Unmanaged) resTy'
82+
let ptrFormals = ptrInfo <&> \(DestPtrInfo bt _) -> ("res"::NameHint, PtrType bt)
83+
return (Abs tbs' (Abs ptrBs' resDest'), ptrFormals)
84+
let argFormals = nestToList formalForBinder bs
85+
dropSubst $ buildImpFunction CEntryFun (argFormals ++ ptrFormals) \argsAndPtrs -> do
86+
let (args, ptrs) = splitAt (length argFormals) argsAndPtrs
87+
resDestAbsPtrs <- applyNaryAbs (sink resDestAbsArgsPtrs) args
88+
resDest <- applyNaryAbs resDestAbsPtrs ptrs
89+
extendSubst (bs @@> map SubstVal (Var <$> args)) do
90+
void $ translateBlock (Just $ sink resDest) body
91+
return []
92+
where
93+
formalForBinder b = case binderType b of
94+
BaseTy bt -> (NoHint, bt)
95+
_ -> error "Expected all binders to be of a BaseType"
96+
6997
loadArgDests :: (Emits n, ImpBuilder m) => NaryLamDest n -> m n ([Atom n], Dest n)
7098
loadArgDests (Abs Empty resultDest) = return ([], resultDest)
7199
loadArgDests (Abs (Nest (b:>argDest) bs) resultDest) = do

0 commit comments

Comments
 (0)