Skip to content

Commit 68c89d9

Browse files
committed
Add reconstructions for exported Atoms
Tables have to be reconstructed from their underlying pointers. There's no new functionality in this patch, but it should enable us to add support for tables back.
1 parent 43e7751 commit 68c89d9

File tree

3 files changed

+72
-12
lines changed

3 files changed

+72
-12
lines changed

src/lib/Export.hs

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ prepareFunctionForExport f = do
3939
sig <- case runFallibleM $ runEnvReaderT emptyOutMap $ naryPiToExportSig closedNaryPi of
4040
Success sig -> return sig
4141
Failure err -> throwErrs err
42+
let argRecon = case sig of
43+
ExportedSignature argSig _ _ -> runEnvReaderM emptyOutMap $ exportArgRecon argSig
4244
fSimp <- simplifyTopFunction naryPi f
43-
fImp <- toImpExportedFunction fSimp
45+
fImp <- toImpExportedFunction fSimp (sinkFromTop argRecon)
4446
return (fImp, sig)
4547
where
4648
naryPiToExportSig :: (EnvReader m, EnvExtender m, Fallible1 m)
@@ -84,6 +86,25 @@ prepareFunctionForExport f = do
8486
ety <- toExportType ty
8587
cont $ Nest (ExportResult (b:>ety)) Empty
8688

89+
exportArgRecon :: (EnvReader m, EnvExtender m) => Nest ExportArg n l -> m n (Abs (Nest IBinder) (ListE Atom) n)
90+
exportArgRecon topArgs = go (ListE []) topArgs
91+
where
92+
go :: (EnvReader m, EnvExtender m)
93+
=> ListE Atom n -> Nest ExportArg n l -> m n (Abs (Nest IBinder) (ListE Atom) n)
94+
go argAtoms = \case
95+
Empty -> return $ Abs Empty argAtoms
96+
Nest (ExportArg _ b) bs ->
97+
refreshAbs (Abs b (EmptyAbs bs)) \(v':>ety) (Abs bs' UnitE) -> do
98+
let (ity, atom) = typeToAtom ety v'
99+
Abs ibs' allAtoms' <- go (ListE $ fromListE (sink argAtoms) ++ [atom]) bs'
100+
return $ Abs (Nest (IBinder v' ity) ibs') allAtoms'
101+
102+
typeToAtom :: ExportType n -> AtomNameBinder n l -> (IType, Atom l)
103+
typeToAtom ety v = case ety of
104+
ScalarType sbt -> (Scalar sbt , Var $ binderName v)
105+
RectContArrayPtr sbt shape -> (PtrType (Heap CPU, Scalar sbt), tableAtom shape )
106+
where tableAtom = undefined
107+
87108
toExportType :: Fallible m => Type n -> m (ExportType n)
88109
toExportType = \case
89110
BaseTy (Scalar sbt) -> return $ ScalarType sbt
@@ -104,6 +125,38 @@ data ExportedSignature n = forall l l'.
104125
, exportedCCallSig :: [AtomName l']
105126
}
106127

128+
instance GenericE ExportType where
129+
type RepE ExportType = EitherE (LiftE ScalarBaseType)
130+
(LiftE ScalarBaseType `PairE` ListE (EitherE AtomName (LiftE Int)))
131+
fromE = \case
132+
ScalarType sbt -> LeftE $ LiftE sbt
133+
RectContArrayPtr sbt shape -> RightE $ LiftE sbt `PairE` shapeToE shape
134+
toE = \case
135+
LeftE (LiftE sbt) -> ScalarType sbt
136+
RightE (LiftE sbt `PairE` shape) -> RectContArrayPtr sbt (shapeFromE shape)
137+
instance SubstE Name ExportType
138+
instance SinkableE ExportType
139+
140+
shapeToE :: [Either (AtomName n) Int] -> ListE (EitherE AtomName (LiftE Int)) n
141+
shapeToE shape = ListE (dimToE <$> shape)
142+
where dimToE = \case Left n -> LeftE n; Right n -> RightE (LiftE n)
143+
144+
shapeFromE :: ListE (EitherE AtomName (LiftE Int)) n -> [Either (AtomName n) Int]
145+
shapeFromE (ListE shape) = (dimFromE <$> shape)
146+
where dimFromE = \case LeftE n -> Left n; RightE (LiftE n) -> Right n
147+
148+
instance ToBinding ExportType AtomNameC where
149+
toBinding = \case
150+
ScalarType sbt -> toBinding $ BaseTy $ Scalar sbt
151+
RectContArrayPtr sbt shape -> toBinding $ buildArr $ shapeToE shape
152+
where
153+
buildArr :: ListE (EitherE AtomName (LiftE Int)) n -> Type n
154+
buildArr (ListE sl) = case sl of
155+
[] -> BaseTy $ Scalar sbt
156+
(h:t) -> case toConstAbsPure (ListE t) of
157+
Abs b t' -> TabTy (PiBinder b (Fin s) TabArrow) $ buildArr t'
158+
where s = case h of LeftE v -> Var v; RightE (LiftE n) -> IdxRepVal $ fromIntegral n
159+
107160
deriving via (BinderP AtomNameC ExportType) instance GenericB ExportResult
108161
deriving via (BinderP AtomNameC ExportType) instance ProvesExt ExportResult
109162
deriving via (BinderP AtomNameC ExportType) instance BindsNames ExportResult
@@ -116,8 +169,10 @@ instance GenericB ExportArg where
116169
type RepB ExportArg = PairB (LiftB (LiftE ArgVisibility)) (BinderP AtomNameC ExportType)
117170
fromB (ExportArg vis b) = PairB (LiftB (LiftE vis)) b
118171
toB (PairB (LiftB (LiftE vis)) b) = ExportArg vis b
119-
instance ProvesExt ExportArg
120-
instance BindsNames ExportArg
172+
instance ProvesExt ExportArg
173+
instance BindsNames ExportArg
174+
instance SinkableB ExportArg
175+
instance SubstB Name ExportArg
121176
instance BindsAtMostOneName ExportArg AtomNameC where
122177
(ExportArg _ b) @> v = b @> v
123178
instance BindsOneName ExportArg AtomNameC where

src/lib/Imp.hs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ toImpStandaloneFunction' lam@(NaryLamExpr bs Pure body) = do
6767
return []
6868
toImpStandaloneFunction' (NaryLamExpr _ _ _) = error "effectful functions not implemented"
6969

70-
toImpExportedFunction :: EnvReader m => NaryLamExpr n -> m n (ImpFunction n)
71-
toImpExportedFunction lam@(NaryLamExpr (NonEmptyNest fb tb) effs body) = liftImpM do
70+
toImpExportedFunction :: EnvReader m
71+
=> NaryLamExpr n
72+
-> (Abs (Nest IBinder) (ListE Atom) n)
73+
-> m n (ImpFunction n)
74+
toImpExportedFunction lam@(NaryLamExpr (NonEmptyNest fb tb) effs body) argRecon@(Abs baseArgBs _) = liftImpM do
7275
case effs of
7376
Pure -> return ()
7477
_ -> throw TypeErr "Can only export pure functions"
@@ -81,18 +84,15 @@ toImpExportedFunction lam@(NaryLamExpr (NonEmptyNest fb tb) effs body) = liftImp
8184
AbsPtrs (Abs ptrBs' resDest') ptrInfo <- makeDest (LLVM, CPU, Unmanaged) resTy'
8285
let ptrFormals = ptrInfo <&> \(DestPtrInfo bt _) -> ("res"::NameHint, PtrType bt)
8386
return (Abs tbs' (Abs ptrBs' resDest'), ptrFormals)
84-
let argFormals = nestToList formalForBinder bs
87+
let argFormals = nestToList ((NoHint,) . iBinderType) baseArgBs
8588
dropSubst $ buildImpFunction CEntryFun (argFormals ++ ptrFormals) \argsAndPtrs -> do
8689
let (args, ptrs) = splitAt (length argFormals) argsAndPtrs
8790
resDestAbsPtrs <- applyNaryAbs (sink resDestAbsArgsPtrs) args
8891
resDest <- applyNaryAbs resDestAbsPtrs ptrs
89-
extendSubst (bs @@> map SubstVal (Var <$> args)) do
92+
argAtoms <- fromListE <$> applyNaryAbs (sink argRecon) args
93+
extendSubst (bs @@> map SubstVal argAtoms) do
9094
void $ translateBlock (Just $ sink resDest) body
9195
return []
92-
where
93-
formalForBinder b = case binderType b of
94-
BaseTy bt -> (NoHint, bt)
95-
_ -> error "Expected all binders to be of a BaseType"
9696

9797
loadArgDests :: (Emits n, ImpBuilder m) => NaryLamDest n -> m n ([Atom n], Dest n)
9898
loadArgDests (Abs Empty resultDest) = return ([], resultDest)

src/lib/Name.hs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ module Name (
3737
HashMapE (..), HashableE,
3838
MaybeE, fromMaybeE, toMaybeE, pattern JustE, pattern NothingE, MaybeB,
3939
pattern JustB, pattern NothingB,
40-
toConstAbs, PrettyE, PrettyB, ShowE, ShowB,
40+
toConstAbs, toConstAbsPure, PrettyE, PrettyB, ShowE, ShowB,
4141
runScopeReaderT, runScopeReaderM, runSubstReaderT, idNameSubst, liftSubstReaderT,
4242
liftScopeReaderT, liftScopeReaderM,
4343
ScopeReaderT (..), SubstReaderT (..),
@@ -420,6 +420,11 @@ toConstAbs body = do
420420
withFresh "ignore" scope \b -> do
421421
sinkM $ Abs b $ sink body'
422422

423+
toConstAbsPure :: (HoistableE e, SinkableE e, Color c)
424+
=> e n -> (Abs (NameBinder c) e n)
425+
toConstAbsPure e = Abs (UnsafeMakeBinder n) (unsafeCoerceE e)
426+
where n = freshRawName NoHint $ freeVarsE e
427+
423428
-- === type classes for traversing names ===
424429

425430
class FromName v => SubstE (v::V) (e::E) where

0 commit comments

Comments
 (0)