Skip to content

Commit 4066399

Browse files
committed
Begin reviving the Python bridge
The Python bindings have not been updated since the safer name rewrite so they were completely broken. This is the first change that tries to restore them to their former glory.
1 parent 37c4d7d commit 4066399

File tree

13 files changed

+134
-142
lines changed

13 files changed

+134
-142
lines changed

.github/workflows/python-ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
push:
55
branches: [ main ]
66
pull_request:
7-
branches: [ main ]
7+
branches: [ main, safe-names-dev ]
88

99
jobs:
1010
build:

makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ install: dexrt-llvm
9696
$(STACK) install $(STACK_BIN_PATH) --flag dex:optimized $(STACK_FLAGS)
9797

9898
build-prof: dexrt-llvm
99-
$(STACK) build $(STACK_FLAGS) $(PROF) --flag dex:-foreign --flag dex:debug
99+
$(STACK) build $(STACK_FLAGS) $(PROF) --flag dex:debug
100100

101101
# For some reason stack fails to detect modifications to foreign library files
102102
build-ffis: dexrt-llvm
103-
$(STACK) build $(STACK_FLAGS) --force-dirty
103+
$(STACK) build $(STACK_FLAGS) --force-dirty --flag dex:foreign
104104
$(eval STACK_INSTALL_DIR=$(shell $(STACK) path --local-install-root))
105105
cp $(STACK_INSTALL_DIR)/lib/libDex.so python/dex/
106106
cp $(STACK_INSTALL_DIR)/lib/libDex.so julia/deps/

python/dex/__init__.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,21 @@ class Module:
2121

2222
def __init__(self, source):
2323
self._as_parameter_ = api.eval(prelude, api.as_cstr(source))
24-
if not self._as_parameter_:
25-
api.raise_from_dex()
24+
if not self._as_parameter_: api.raise_from_dex()
25+
26+
@classmethod
27+
def _from_ptr(cls, ptr):
28+
if not ptr: api.raise_from_dex()
29+
self = super().__new__(cls)
30+
self._as_parameter_ = ptr
31+
return self
2632

2733
def __del__(self):
28-
if api.nofree:
29-
return
34+
if api.nofree: return
3035
api.destroyContext(self)
3136

3237
def __getattr__(self, name):
3338
result = api.lookup(self, api.as_cstr(name))
34-
if not result:
35-
api.raise_from_dex()
3639
return Atom._from_ptr(result, self)
3740

3841

@@ -46,13 +49,11 @@ def __init__(self):
4649
prelude = Prelude()
4750

4851

49-
def eval(expr: str, module=prelude, _env=None):
50-
if _env is None:
51-
_env = module
52-
result = api.evalExpr(_env, api.as_cstr(expr))
53-
if not result:
54-
api.raise_from_dex()
55-
return Atom._from_ptr(result, module)
52+
def eval(expr: str):
53+
# TODO: Query a free source name
54+
_final_env = Module._from_ptr(api.eval(prelude, api.as_cstr("python_result = " + expr)))
55+
result = api.lookup(_final_env, api.as_cstr("python_result"))
56+
return Atom._from_ptr(result, _final_env)
5657

5758

5859
class Atom:
@@ -77,6 +78,7 @@ def __init__(self, value):
7778

7879
@classmethod
7980
def _from_ptr(cls, ptr, module):
81+
if not ptr: api.raise_from_dex()
8082
self = super().__new__(cls)
8183
self._as_parameter_ = ptr
8284
self.module = module
@@ -88,7 +90,7 @@ def __del__(self):
8890

8991
def __repr__(self):
9092
# TODO: Free!
91-
return api.from_cstr(api.print(self))
93+
return api.from_cstr(api.print(self.module, self))
9294

9395
def __int__(self):
9496
return int(self._as_scalar())
@@ -107,6 +109,7 @@ def _as_scalar(self):
107109
return value.value
108110

109111
def __call__(self, *args):
112+
raise NotImplementedError()
110113
# TODO: Make those calls more hygenic
111114
env = self.module
112115
for i, atom in enumerate(it.chain((self,), args)):
@@ -118,6 +121,7 @@ def __call__(self, *args):
118121
return eval(" ".join(f"python_arg{i}" for i in range(len(args) + 1)), module=self.module, _env=env)
119122

120123
def compile(self):
124+
raise NotImplementedError()
121125
func_ptr = api.compile(api.jit, self.module, self)
122126
if not func_ptr:
123127
api.raise_from_dex()

python/dex/api.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,12 @@ def dex_func(name, *signature):
7171
createContext = dex_func('dexCreateContext', HsContextPtr)
7272
destroyContext = dex_func('dexDestroyContext', HsContextPtr, None)
7373

74-
eval = dex_func('dexEval', HsContextPtr, ctypes.c_char_p, HsContextPtr)
75-
insert = dex_func('dexInsert', HsContextPtr, ctypes.c_char_p, HsAtomPtr, HsContextPtr)
76-
evalExpr = dex_func('dexEvalExpr', HsContextPtr, ctypes.c_char_p, HsAtomPtr)
77-
lookup = dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr)
78-
79-
print = dex_func('dexPrint', HsAtomPtr, ctypes.c_char_p)
80-
toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int)
81-
fromCAtom = dex_func('dexFromCAtom', CAtomPtr, HsAtomPtr)
74+
eval = dex_func('dexEval', HsContextPtr, ctypes.c_char_p, HsContextPtr)
75+
lookup = dex_func('dexLookup', HsContextPtr, ctypes.c_char_p, HsAtomPtr)
76+
77+
print = dex_func('dexPrint', HsContextPtr, HsAtomPtr, ctypes.c_char_p)
78+
toCAtom = dex_func('dexToCAtom', HsAtomPtr, CAtomPtr, ctypes.c_int)
79+
fromCAtom = dex_func('dexFromCAtom', CAtomPtr, HsAtomPtr)
8280

8381
createJIT = dex_func('dexCreateJIT', HsJITPtr)
8482
destroyJIT = dex_func('dexDestroyJIT', HsJITPtr, None)

python/tests/api_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_module_attrs(self):
2727
assert str(m.x) == "2.5"
2828
assert str(m.y) == "[2, 3, 4]"
2929

30+
@unittest.skip
3031
def test_function_call(self):
3132
m = dex.Module(dedent("""
3233
def addOne (x: Float) : Float = x + 1.0

python/tests/jax_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from dex.interop.jax import primitive
1515

1616

17+
@unittest.skip
1718
class JAXTest(unittest.TestCase):
1819

1920
def test_impl_scalar(self):

python/tests/jit_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def test(self):
2929
return check_atom(dex.eval(dex_source), reference, args_iter)
3030
return test
3131

32+
@unittest.skip
3233
class JITTest(unittest.TestCase):
3334
test_sigmoid = expr_test(r"\x:Float. 1.0 / (1.0 + exp(-x))",
3435
lambda x: np.float32(1.0) / (np.float32(1.0) + np.exp(-x)),
@@ -75,7 +76,7 @@ def myTranspose (n: Int) ?-> (m: Int) ?->
7576
def test_tuple_return(self):
7677
dex_func = dex.eval(r"\x: ((Fin 10) => Float). (x, 2. .* x, 3. .* x)")
7778
reference = lambda x: (x, 2 * x, 3 * x)
78-
79+
7980
x = np.arange(10, dtype=np.float32)
8081

8182
dex_output = dex_func.compile()(x)

src/Dex/Foreign/API.hs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ module Dex.Foreign.API where
99
import Foreign.Ptr
1010
import Foreign.C
1111

12-
import Syntax
13-
1412
import Dex.Foreign.Context
1513
import Dex.Foreign.Serialize
1614
import Dex.Foreign.JIT
@@ -25,20 +23,19 @@ import Dex.Foreign.JIT
2523
-- Context
2624
foreign export ccall "dexCreateContext" dexCreateContext :: IO (Ptr Context)
2725
foreign export ccall "dexDestroyContext" dexDestroyContext :: Ptr Context -> IO ()
28-
foreign export ccall "dexInsert" dexInsert :: Ptr Context -> CString -> Ptr Atom -> IO (Ptr Context)
26+
foreign export ccall "dexInsert" dexInsert :: Ptr Context -> CString -> Ptr AtomEx -> IO (Ptr Context)
2927
foreign export ccall "dexEval" dexEval :: Ptr Context -> CString -> IO (Ptr Context)
30-
foreign export ccall "dexEvalExpr" dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom)
31-
foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr Atom)
28+
foreign export ccall "dexLookup" dexLookup :: Ptr Context -> CString -> IO (Ptr AtomEx)
3229

3330
-- Serialization
34-
foreign export ccall "dexPrint" dexPrint :: Ptr Atom -> IO CString
35-
foreign export ccall "dexToCAtom" dexToCAtom :: Ptr Atom -> Ptr CAtom -> IO CInt
36-
foreign export ccall "dexFromCAtom" dexFromCAtom :: Ptr CAtom -> IO (Ptr Atom)
31+
foreign export ccall "dexPrint" dexPrint :: Ptr Context -> Ptr AtomEx -> IO CString
32+
foreign export ccall "dexToCAtom" dexToCAtom :: Ptr AtomEx -> Ptr CAtom -> IO CInt
33+
foreign export ccall "dexFromCAtom" dexFromCAtom :: Ptr CAtom -> IO (Ptr AtomEx)
3734

3835
-- JIT
3936
foreign export ccall "dexCreateJIT" dexCreateJIT :: IO (Ptr JIT)
4037
foreign export ccall "dexDestroyJIT" dexDestroyJIT :: Ptr JIT -> IO ()
41-
foreign export ccall "dexCompile" dexCompile :: Ptr JIT -> Ptr Context -> Ptr Atom -> IO (Ptr NativeFunction)
38+
foreign export ccall "dexCompile" dexCompile :: Ptr JIT -> Ptr Context -> Ptr AtomEx -> IO (Ptr NativeFunction)
4239
foreign export ccall "dexUnload" dexUnload :: Ptr JIT -> Ptr NativeFunction -> IO ()
4340
foreign export ccall "dexGetFunctionSignature" dexGetFunctionSignature :: Ptr JIT -> Ptr NativeFunction -> IO (Ptr ExportedSignature)
4441
foreign export ccall "dexFreeFunctionSignature" dexFreeFunctionSignature :: Ptr ExportedSignature -> IO ()

src/Dex/Foreign/Context.hs

Lines changed: 36 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@
44
-- license that can be found in the LICENSE file or at
55
-- https://developers.google.com/open-source/licenses/bsd
66

7+
{-# LANGUAGE GADTs #-}
8+
79
module Dex.Foreign.Context (
8-
Context (..),
10+
Context (..), AtomEx (..),
911
setError,
1012
dexCreateContext, dexDestroyContext,
1113
dexInsert, dexLookup,
12-
dexEval, dexEvalExpr,
14+
dexEval,
1315
) where
1416

1517
import Foreign.Ptr
1618
import Foreign.StablePtr
1719
import Foreign.C.String
1820

21+
import Control.Monad.IO.Class
1922
import Data.String
2023
import Data.Int
2124
import Data.Functor
@@ -26,18 +29,18 @@ import Resources
2629
import Syntax hiding (sizeOf)
2730
import Type
2831
import TopLevel
29-
import Env hiding (Tag)
32+
import Name
3033
import PPrint
3134
import Err
35+
import Parser
36+
import Builder
3237

3338
import Dex.Foreign.Util
3439

35-
import SaferNames.Bridge
36-
import qualified SaferNames.Syntax as S
37-
import qualified SaferNames.Parser as S
38-
3940

4041
data Context = Context EvalConfig TopStateEx
42+
data AtomEx where
43+
AtomEx :: Atom n -> AtomEx
4144

4245
foreign import ccall "_internal_dexSetError" internalSetErrorPtr :: CString -> Int64 -> IO ()
4346
setError :: String -> IO ()
@@ -46,70 +49,45 @@ setError msg = withCStringLen msg $ \(ptr, len) ->
4649

4750
dexCreateContext :: IO (Ptr Context)
4851
dexCreateContext = do
49-
let evalConfig = EvalConfig LLVM Nothing Nothing
50-
maybePreludeEnv <- evalPrelude evalConfig preludeSource
51-
case maybePreludeEnv of
52-
Success preludeEnv -> toStablePtr $ Context evalConfig preludeEnv
53-
Failure err -> nullPtr <$ setError ("Failed to initialize standard library: " ++ pprint err)
54-
where
55-
evalPrelude :: EvalConfig -> String -> IO (Except TopStateEx)
56-
evalPrelude opts sourceText = do
57-
(results, env) <- runInterblockM opts initTopState $
58-
map snd <$> evalSourceText sourceText
59-
return $ env `unlessError` results
60-
where
61-
unlessError :: TopStateEx -> [Result] -> Except TopStateEx
62-
result `unlessError` [] = Success result
63-
_ `unlessError` ((Result _ (Failure err)):_) = Failure err
64-
result `unlessError` (_:t ) = result `unlessError` t
52+
let evalConfig = EvalConfig LLVM Nothing Nothing Nothing
53+
cachedEnv <- loadCache
54+
runTopperM evalConfig cachedEnv (evalSourceBlockRepl preludeImportBlock) >>= \case
55+
(Result [] (Success ()), preludeEnv) -> toStablePtr $ Context evalConfig preludeEnv
56+
(Result _ (Failure err), _ ) -> nullPtr <$
57+
setError ("Failed to initialize standard library: " ++ pprint err)
6558

6659
dexDestroyContext :: Ptr Context -> IO ()
6760
dexDestroyContext = freeStablePtr . castPtrToStablePtr . castPtr
6861

6962
dexEval :: Ptr Context -> CString -> IO (Ptr Context)
7063
dexEval ctxPtr sourcePtr = do
71-
Context evalConfig env <- fromStablePtr ctxPtr
64+
Context evalConfig initEnv <- fromStablePtr ctxPtr
7265
source <- peekCString sourcePtr
73-
(results, finalEnv) <- runInterblockM evalConfig env $ evalSourceText source
66+
(results, finalEnv) <- runTopperM evalConfig initEnv $ evalSourceText source
7467
let anyError = asum $ fmap (\case (_, Result _ (Failure err)) -> Just err; _ -> Nothing) results
7568
case anyError of
7669
Nothing -> toStablePtr $ Context evalConfig finalEnv
7770
Just err -> setError (pprint err) $> nullPtr
7871

79-
dexInsert :: Ptr Context -> CString -> Ptr Atom -> IO (Ptr Context)
72+
dexInsert :: Ptr Context -> CString -> Ptr AtomEx -> IO (Ptr Context)
8073
dexInsert ctxPtr namePtr atomPtr = do
81-
Context evalConfig (TopStateEx env) <- fromStablePtr ctxPtr
82-
name <- fromString <$> peekCString namePtr
83-
atom <- fromStablePtr atomPtr
84-
let freshName = genFresh (Name GenName (fromString name) 0) (topBindings $ topStateD env)
85-
let newBinding = AtomBinderInfo (getType atom) (LetBound PlainLet (Atom atom))
86-
let evaluated = EvaluatedModule (freshName @> newBinding) mempty
87-
(SourceMap (M.singleton name (SrcAtomName freshName)))
88-
let envNew = extendTopStateD env evaluated
89-
toStablePtr $ Context evalConfig $ envNew
74+
Context evalConfig initEnv <- fromStablePtr ctxPtr
75+
sourceName <- peekCString namePtr
76+
AtomEx atom <- fromStablePtr atomPtr
77+
(_, finalEnv) <- runTopperM evalConfig initEnv do
78+
-- TODO: Check if atom is compatible with context! Use module name?
79+
name <- emitTopLet (fromString sourceName) PlainLet $ Atom $ unsafeCoerceE atom
80+
emitSourceMap $ SourceMap $ M.singleton sourceName [ModuleVar Main $ Just $ UAtomVar name]
81+
toStablePtr $ Context evalConfig finalEnv
9082

91-
dexEvalExpr :: Ptr Context -> CString -> IO (Ptr Atom)
92-
dexEvalExpr ctxPtr sourcePtr = do
93-
Context evalConfig env <- fromStablePtr ctxPtr
94-
source <- peekCString sourcePtr
95-
case S.parseExpr source of
96-
Success expr -> do
97-
let (v, m) = S.exprAsModule expr
98-
let block = S.SourceBlock 0 0 LogNothing source (S.RunModule m) Nothing
99-
(Result [] maybeErr, newState) <- runInterblockM evalConfig env $ evalSourceBlock block
100-
case maybeErr of
101-
Success () -> do
102-
let Success (AtomBinderInfo _ (LetBound _ (Atom atom))) =
103-
lookupSourceName newState v
104-
toStablePtr atom
105-
Failure err -> setError (pprint err) $> nullPtr
106-
Failure err -> setError (pprint err) $> nullPtr
107-
108-
dexLookup :: Ptr Context -> CString -> IO (Ptr Atom)
83+
dexLookup :: Ptr Context -> CString -> IO (Ptr AtomEx)
10984
dexLookup ctxPtr namePtr = do
110-
Context _ env <- fromStablePtr ctxPtr
85+
Context evalConfig env <- fromStablePtr ctxPtr
11186
name <- peekCString namePtr
112-
case lookupSourceName env (fromString name) of
113-
Success (AtomBinderInfo _ (LetBound _ (Atom atom))) -> toStablePtr atom
114-
Failure _ -> setError "Unbound name" $> nullPtr
115-
Success _ -> setError "Looking up an expression" $> nullPtr
87+
fst <$> runTopperM evalConfig env do
88+
lookupSourceMap name >>= \case
89+
Just (UAtomVar v) -> lookupAtomName v >>= \case
90+
LetBound (DeclBinding _ _ (Atom atom)) -> liftIO $ toStablePtr $ AtomEx atom
91+
_ -> liftIO $ setError "Looking up an unevaluated atom?" $> nullPtr
92+
Just _ -> liftIO $ setError "Only Atom names can be looked up" $> nullPtr
93+
Nothing -> liftIO $ setError "Unbound name" $> nullPtr

0 commit comments

Comments
 (0)