@@ -39,8 +39,10 @@ prepareFunctionForExport f = do
3939 sig <- case runFallibleM $ runEnvReaderT emptyOutMap $ naryPiToExportSig closedNaryPi of
4040 Success sig -> return sig
4141 Failure err -> throwErrs err
42+ let argRecon = case sig of
43+ ExportedSignature argSig _ _ -> runEnvReaderM emptyOutMap $ exportArgRecon argSig
4244 fSimp <- simplifyTopFunction naryPi f
43- fImp <- toImpExportedFunction fSimp
45+ fImp <- toImpExportedFunction fSimp (sinkFromTop argRecon)
4446 return (fImp, sig)
4547 where
4648 naryPiToExportSig :: (EnvReader m , EnvExtender m , Fallible1 m )
@@ -84,6 +86,25 @@ prepareFunctionForExport f = do
8486 ety <- toExportType ty
8587 cont $ Nest (ExportResult (b:> ety)) Empty
8688
89+ exportArgRecon :: (EnvReader m , EnvExtender m ) => Nest ExportArg n l -> m n (Abs (Nest IBinder ) (ListE Atom ) n )
90+ exportArgRecon topArgs = go (ListE [] ) topArgs
91+ where
92+ go :: (EnvReader m , EnvExtender m )
93+ => ListE Atom n -> Nest ExportArg n l -> m n (Abs (Nest IBinder ) (ListE Atom ) n )
94+ go argAtoms = \ case
95+ Empty -> return $ Abs Empty argAtoms
96+ Nest (ExportArg _ b) bs ->
97+ refreshAbs (Abs b (EmptyAbs bs)) \ (v':> ety) (Abs bs' UnitE ) -> do
98+ let (ity, atom) = typeToAtom ety v'
99+ Abs ibs' allAtoms' <- go (ListE $ fromListE (sink argAtoms) ++ [atom]) bs'
100+ return $ Abs (Nest (IBinder v' ity) ibs') allAtoms'
101+
102+ typeToAtom :: ExportType n -> AtomNameBinder n l -> (IType , Atom l )
103+ typeToAtom ety v = case ety of
104+ ScalarType sbt -> (Scalar sbt , Var $ binderName v)
105+ RectContArrayPtr sbt shape -> (PtrType (Heap CPU , Scalar sbt), tableAtom shape )
106+ where tableAtom = undefined
107+
87108 toExportType :: Fallible m => Type n -> m (ExportType n )
88109 toExportType = \ case
89110 BaseTy (Scalar sbt) -> return $ ScalarType sbt
@@ -104,6 +125,38 @@ data ExportedSignature n = forall l l'.
104125 , exportedCCallSig :: [AtomName l' ]
105126 }
106127
128+ instance GenericE ExportType where
129+ type RepE ExportType = EitherE (LiftE ScalarBaseType )
130+ (LiftE ScalarBaseType `PairE ` ListE (EitherE AtomName (LiftE Int )))
131+ fromE = \ case
132+ ScalarType sbt -> LeftE $ LiftE sbt
133+ RectContArrayPtr sbt shape -> RightE $ LiftE sbt `PairE ` shapeToE shape
134+ toE = \ case
135+ LeftE (LiftE sbt) -> ScalarType sbt
136+ RightE (LiftE sbt `PairE ` shape) -> RectContArrayPtr sbt (shapeFromE shape)
137+ instance SubstE Name ExportType
138+ instance SinkableE ExportType
139+
140+ shapeToE :: [Either (AtomName n ) Int ] -> ListE (EitherE AtomName (LiftE Int )) n
141+ shapeToE shape = ListE (dimToE <$> shape)
142+ where dimToE = \ case Left n -> LeftE n; Right n -> RightE (LiftE n)
143+
144+ shapeFromE :: ListE (EitherE AtomName (LiftE Int )) n -> [Either (AtomName n ) Int ]
145+ shapeFromE (ListE shape) = (dimFromE <$> shape)
146+ where dimFromE = \ case LeftE n -> Left n; RightE (LiftE n) -> Right n
147+
148+ instance ToBinding ExportType AtomNameC where
149+ toBinding = \ case
150+ ScalarType sbt -> toBinding $ BaseTy $ Scalar sbt
151+ RectContArrayPtr sbt shape -> toBinding $ buildArr $ shapeToE shape
152+ where
153+ buildArr :: ListE (EitherE AtomName (LiftE Int )) n -> Type n
154+ buildArr (ListE sl) = case sl of
155+ [] -> BaseTy $ Scalar sbt
156+ (h: t) -> case toConstAbsPure (ListE t) of
157+ Abs b t' -> TabTy (PiBinder b (Fin s) TabArrow ) $ buildArr t'
158+ where s = case h of LeftE v -> Var v; RightE (LiftE n) -> IdxRepVal $ fromIntegral n
159+
107160deriving via (BinderP AtomNameC ExportType ) instance GenericB ExportResult
108161deriving via (BinderP AtomNameC ExportType ) instance ProvesExt ExportResult
109162deriving via (BinderP AtomNameC ExportType ) instance BindsNames ExportResult
@@ -116,8 +169,10 @@ instance GenericB ExportArg where
116169 type RepB ExportArg = PairB (LiftB (LiftE ArgVisibility )) (BinderP AtomNameC ExportType )
117170 fromB (ExportArg vis b) = PairB (LiftB (LiftE vis)) b
118171 toB (PairB (LiftB (LiftE vis)) b) = ExportArg vis b
119- instance ProvesExt ExportArg
120- instance BindsNames ExportArg
172+ instance ProvesExt ExportArg
173+ instance BindsNames ExportArg
174+ instance SinkableB ExportArg
175+ instance SubstB Name ExportArg
121176instance BindsAtMostOneName ExportArg AtomNameC where
122177 (ExportArg _ b) @> v = b @> v
123178instance BindsOneName ExportArg AtomNameC where
0 commit comments