Skip to content

Commit 43e7751

Browse files
committed
Reenable Python JIT
1 parent 04a3d32 commit 43e7751

File tree

8 files changed

+88
-75
lines changed

8 files changed

+88
-75
lines changed

dex.cabal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ foreign-library Dex
145145
ghc-options: -Wall -fPIC -optP-Wno-nonportable-include-path
146146
default-language: Haskell2010
147147
default-extensions: TypeApplications, ScopedTypeVariables, LambdaCase,
148-
BlockArguments
148+
BlockArguments, DataKinds, GADTs
149149
if flag(optimized)
150150
ghc-options: -O3
151151
else

python/dex/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,6 @@ def __call__(self, *args):
121121
return eval(" ".join(f"python_arg{i}" for i in range(len(args) + 1)), module=self.module, _env=env)
122122

123123
def compile(self):
124-
raise NotImplementedError()
125124
func_ptr = api.compile(api.jit, self.module, self)
126-
if not func_ptr:
127-
api.raise_from_dex()
125+
if not func_ptr: api.raise_from_dex()
128126
return NativeFunction(api.jit, func_ptr)

python/dex/native_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __init__(self, text):
140140
self.text = text
141141

142142
def consume(self, char: str):
143-
assert self.text[self.offset] == ord(char)
143+
assert self.text[self.offset] == ord(char), (self.text, self.offset, char)
144144
self.offset += 1
145145

146146
def maybe_consume(self, char: str) -> bool:

python/tests/jit_test.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ def check_atom(dex_atom, reference, args_iter):
2424
rtol=1e-4, atol=1e-6)
2525
assert ran_any_iter, "Empty argument iterator!"
2626

27-
def expr_test(dex_source, reference, args_iter):
27+
def expr_test(dex_source, reference, args_iter, skip=False):
2828
def test(self):
2929
return check_atom(dex.eval(dex_source), reference, args_iter)
30+
if skip:
31+
test = unittest.skip(test)
3032
return test
3133

32-
@unittest.skip
3334
class JITTest(unittest.TestCase):
3435
test_sigmoid = expr_test(r"\x:Float. 1.0 / (1.0 + exp(-x))",
3536
lambda x: np.float32(1.0) / (np.float32(1.0) + np.exp(-x)),
@@ -46,33 +47,38 @@ class JITTest(unittest.TestCase):
4647

4748
test_array_scalar = expr_test(r"\x:((Fin 10)=>Float). sum x",
4849
np.sum,
49-
[(np.arange(10, dtype=np.float32),)])
50+
[(np.arange(10, dtype=np.float32),)],
51+
skip=True)
5052

5153
test_scalar_array = expr_test(r"\x:Int. for i:(Fin 10). x + ordinal i",
5254
lambda x: x + np.arange(10, dtype=np.int32),
53-
[(i,) for i in range(5)])
55+
[(i,) for i in range(5)],
56+
skip=True)
5457

5558
test_array_array = expr_test(r"\x:((Fin 10)=>Float). for i. exp x.i",
5659
np.exp,
57-
[(np.arange(10, dtype=np.float32),)])
60+
[(np.arange(10, dtype=np.float32),)],
61+
skip=True)
5862

63+
@unittest.skip
5964
def test_polymorphic_array_1d(self):
6065
m = dex.Module(dedent("""
61-
def addTwo (n: Int) ?-> (x: (Fin n)=>Float) : (Fin n)=>Float = for i. x.i + 2.0
66+
def addTwo {n} (x: (Fin n)=>Float) : (Fin n)=>Float = for i. x.i + 2.0
6267
"""))
6368
check_atom(m.addTwo, lambda x: x + 2,
6469
[(np.arange(l, dtype=np.float32),) for l in (2, 5, 10)])
6570

71+
@unittest.skip
6672
def test_polymorphic_array_2d(self):
6773
m = dex.Module(dedent("""
68-
def myTranspose (n: Int) ?-> (m: Int) ?->
69-
(x : (Fin n)=>(Fin m)=>Float) : (Fin m)=>(Fin n)=>Float =
74+
def myTranspose {n m} (x : (Fin n)=>(Fin m)=>Float) : (Fin m)=>(Fin n)=>Float =
7075
for i j. x.j.i
7176
"""))
7277
check_atom(m.myTranspose, lambda x: x.T,
7378
[(np.arange(a*b, dtype=np.float32).reshape((a, b)),)
7479
for a, b in it.product((2, 5, 10), repeat=2)])
7580

81+
@unittest.skip
7682
def test_tuple_return(self):
7783
dex_func = dex.eval(r"\x: ((Fin 10) => Float). (x, 2. .* x, 3. .* x)")
7884
reference = lambda x: (x, 2 * x, 3 * x)

src/Dex/Foreign/API.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ foreign export ccall "dexCreateJIT" dexCreateJIT :: IO (Ptr JIT)
3737
foreign export ccall "dexDestroyJIT" dexDestroyJIT :: Ptr JIT -> IO ()
3838
foreign export ccall "dexCompile" dexCompile :: Ptr JIT -> Ptr Context -> Ptr AtomEx -> IO (Ptr NativeFunction)
3939
foreign export ccall "dexUnload" dexUnload :: Ptr JIT -> Ptr NativeFunction -> IO ()
40-
foreign export ccall "dexGetFunctionSignature" dexGetFunctionSignature :: Ptr JIT -> Ptr NativeFunction -> IO (Ptr ExportedSignature)
41-
foreign export ccall "dexFreeFunctionSignature" dexFreeFunctionSignature :: Ptr ExportedSignature -> IO ()
40+
foreign export ccall "dexGetFunctionSignature" dexGetFunctionSignature :: Ptr JIT -> Ptr NativeFunction -> IO (Ptr ClosedExportedSignature)
41+
foreign export ccall "dexFreeFunctionSignature" dexFreeFunctionSignature :: Ptr ClosedExportedSignature -> IO ()

src/Dex/Foreign/Context.hs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,11 @@ import Foreign.C.String
2020

2121
import Control.Monad.IO.Class
2222
import Data.String
23-
import Data.Int
2423
import Data.Functor
2524
import Data.Foldable
2625
import qualified Data.Map.Strict as M
2726

28-
import Resources
2927
import Syntax hiding (sizeOf)
30-
import Type
3128
import TopLevel
3229
import Name
3330
import PPrint
@@ -42,11 +39,6 @@ data Context = Context EvalConfig TopStateEx
4239
data AtomEx where
4340
AtomEx :: Atom n -> AtomEx
4441

45-
foreign import ccall "_internal_dexSetError" internalSetErrorPtr :: CString -> Int64 -> IO ()
46-
setError :: String -> IO ()
47-
setError msg = withCStringLen msg $ \(ptr, len) ->
48-
internalSetErrorPtr ptr (fromIntegral len)
49-
5042
dexCreateContext :: IO (Ptr Context)
5143
dexCreateContext = do
5244
let evalConfig = EvalConfig LLVM Nothing Nothing Nothing

src/Dex/Foreign/JIT.hs

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
-- https://developers.google.com/open-source/licenses/bsd
66

77
{-# LANGUAGE RecordWildCards #-}
8+
{-# LANGUAGE FlexibleInstances #-}
89
{-# OPTIONS_GHC -Wno-orphans #-}
910

1011
module Dex.Foreign.JIT (
11-
JIT, NativeFunction, ExportedSignature,
12+
JIT, NativeFunction, ClosedExportedSignature,
1213
dexCreateJIT, dexDestroyJIT,
1314
dexGetFunctionSignature, dexFreeFunctionSignature,
1415
dexCompile, dexUnload
@@ -32,96 +33,95 @@ import qualified LLVM.CodeGenOpt as CGO
3233
import qualified LLVM.JIT
3334
import qualified LLVM.Shims
3435

36+
import Name
3537
import Logging
38+
import Builder
3639
import LLVMExec
3740
import TopLevel
3841
import JIT
42+
import Export
3943
import Syntax hiding (sizeOf)
4044

4145
import Dex.Foreign.Util
4246
import Dex.Foreign.Context
4347

44-
-- TODO: Update Export to safer names
45-
-- import Export
46-
newtype ExportedSignature = ExportedSignature ()
47-
exportedSignatureDesc :: ExportedSignature -> (String, String, String)
48-
exportedSignatureDesc = undefined
49-
48+
type ClosedExportedSignature = ExportedSignature 'VoidS
5049
data NativeFunction =
5150
NativeFunction { nativeModule :: LLVM.JIT.NativeModule
52-
, nativeSignature :: ExportedSignature }
51+
, nativeSignature :: ExportedSignature 'VoidS }
5352
type NativeFunctionAddr = Ptr NativeFunction
5453

5554
data JIT = ForeignJIT { jit :: LLVM.JIT.JIT
5655
, jitTargetMachine :: TargetMachine
5756
, addrTableRef :: IORef (M.Map NativeFunctionAddr NativeFunction)
5857
}
5958

60-
instance Storable ExportedSignature where
59+
instance Storable (ExportedSignature 'VoidS) where
6160
sizeOf _ = 3 * sizeOf (undefined :: Ptr ())
6261
alignment _ = alignment (undefined :: Ptr ())
6362
peek _ = error "peek not implemented for ExportedSignature"
6463
poke addr sig = do
65-
let strAddr = castPtr @ExportedSignature @CString addr
64+
let strAddr = castPtr @(ExportedSignature 'VoidS) @CString addr
6665
let (arg, res, ccall) = exportedSignatureDesc sig
6766
pokeElemOff strAddr 0 =<< newCString arg
6867
pokeElemOff strAddr 1 =<< newCString res
6968
pokeElemOff strAddr 2 =<< newCString ccall
7069

7170
dexCreateJIT :: IO (Ptr JIT)
7271
dexCreateJIT = do
73-
setError "currently disabled" $> nullPtr
74-
--jitTargetMachine <- LLVM.Shims.newHostTargetMachine R.PIC CM.Large CGO.Aggressive
75-
--jit <- LLVM.JIT.createJIT jitTargetMachine
76-
--addrTableRef <- newIORef mempty
77-
--toStablePtr ForeignJIT{..}
72+
jitTargetMachine <- LLVM.Shims.newHostTargetMachine R.PIC CM.Large CGO.Aggressive
73+
jit <- LLVM.JIT.createJIT jitTargetMachine
74+
addrTableRef <- newIORef mempty
75+
toStablePtr ForeignJIT{..}
7876

7977
dexDestroyJIT :: Ptr JIT -> IO ()
8078
dexDestroyJIT jitPtr = do
81-
return ()
82-
--ForeignJIT{..} <- fromStablePtr jitPtr
83-
--addrTable <- readIORef addrTableRef
84-
--forM_ (M.toList addrTable) $ \(_, m) -> LLVM.JIT.unloadNativeModule $ nativeModule m
85-
--LLVM.JIT.destroyJIT jit
86-
--LLVM.Shims.disposeTargetMachine jitTargetMachine
79+
ForeignJIT{..} <- fromStablePtr jitPtr
80+
addrTable <- readIORef addrTableRef
81+
forM_ (M.toList addrTable) $ \(_, m) -> LLVM.JIT.unloadNativeModule $ nativeModule m
82+
LLVM.JIT.destroyJIT jit
83+
LLVM.Shims.disposeTargetMachine jitTargetMachine
8784

8885
dexCompile :: Ptr JIT -> Ptr Context -> Ptr AtomEx -> IO NativeFunctionAddr
89-
dexCompile jitPtr ctxPtr funcAtomPtr = do
90-
setError "currently disabled" $> nullPtr
91-
--ForeignJIT{..} <- fromStablePtr jitPtr
92-
--Context _ (TopStateEx env) <- fromStablePtr ctxPtr
93-
--funcAtom <- fromStablePtr funcAtomPtr
94-
--let (impMod, nativeSignature) = prepareFunctionForExport
95-
--(topBindings $ topStateD env) "userFunc" funcAtom
96-
--nativeModule <- execLogger Nothing $ \logger -> do
97-
--llvmAST <- impToLLVM logger impMod
98-
--LLVM.JIT.compileModule jit llvmAST
99-
--(standardCompilationPipeline logger ["userFunc"] jitTargetMachine)
100-
--funcPtr <- castFunPtrToPtr <$> LLVM.JIT.getFunctionPtr nativeModule "userFunc"
101-
--modifyIORef addrTableRef $ M.insert funcPtr NativeFunction{..}
102-
--return $ funcPtr
103-
104-
dexGetFunctionSignature :: Ptr JIT -> NativeFunctionAddr -> IO (Ptr ExportedSignature)
86+
dexCompile jitPtr ctxPtr funcAtomPtr = catchErrors $ do
87+
ForeignJIT{..} <- fromStablePtr jitPtr
88+
Context evalConfig initEnv <- fromStablePtr ctxPtr
89+
AtomEx funcAtom <- fromStablePtr funcAtomPtr
90+
fst <$> runTopperM evalConfig initEnv do
91+
-- TODO: Check if atom is compatible with context! Use module name?
92+
(impFunc, nativeSignature) <- prepareFunctionForExport (unsafeCoerceE funcAtom)
93+
(_, llvmAST) <- impToLLVM "userFunc" impFunc
94+
logger <- getLogger
95+
objFileNames <- getAllRequiredObjectFiles
96+
objFiles <- forM objFileNames \objFileName -> do
97+
ObjectFileBinding (ObjectFile bytes _ _) <- lookupEnv objFileName
98+
return bytes
99+
liftIO do
100+
nativeModule <- LLVM.JIT.compileModule jit objFiles llvmAST
101+
(standardCompilationPipeline logger ["userFunc"] jitTargetMachine)
102+
funcPtr <- castFunPtrToPtr <$> LLVM.JIT.getFunctionPtr nativeModule "userFunc"
103+
modifyIORef addrTableRef $ M.insert funcPtr NativeFunction{..}
104+
return $ funcPtr
105+
106+
dexGetFunctionSignature :: Ptr JIT -> NativeFunctionAddr -> IO (Ptr (ExportedSignature 'VoidS))
105107
dexGetFunctionSignature jitPtr funcPtr = do
106-
setError "currently disabled" $> nullPtr
107-
--ForeignJIT{..} <- fromStablePtr jitPtr
108-
--addrTable <- readIORef addrTableRef
109-
--case M.lookup funcPtr addrTable of
110-
--Nothing -> setError "Invalid function address" $> nullPtr
111-
--Just NativeFunction{..} -> putOnHeap nativeSignature
112-
113-
dexFreeFunctionSignature :: Ptr ExportedSignature -> IO ()
108+
ForeignJIT{..} <- fromStablePtr jitPtr
109+
addrTable <- readIORef addrTableRef
110+
case M.lookup funcPtr addrTable of
111+
Nothing -> setError "Invalid function address" $> nullPtr
112+
Just NativeFunction{..} -> putOnHeap nativeSignature
113+
114+
dexFreeFunctionSignature :: Ptr (ExportedSignature 'VoidS) -> IO ()
114115
dexFreeFunctionSignature sigPtr = do
115-
let strPtr = castPtr @ExportedSignature @CString sigPtr
116+
let strPtr = castPtr @(ExportedSignature 'VoidS) @CString sigPtr
116117
free =<< peekElemOff strPtr 0
117118
free =<< peekElemOff strPtr 1
118119
free =<< peekElemOff strPtr 2
119120
free sigPtr
120121

121122
dexUnload :: Ptr JIT -> NativeFunctionAddr -> IO ()
122123
dexUnload jitPtr funcPtr = do
123-
return ()
124-
--ForeignJIT{..} <- fromStablePtr jitPtr
125-
--addrTable <- readIORef addrTableRef
126-
--LLVM.JIT.unloadNativeModule $ nativeModule $ addrTable M.! funcPtr
127-
--modifyIORef addrTableRef $ M.delete funcPtr
124+
ForeignJIT{..} <- fromStablePtr jitPtr
125+
addrTable <- readIORef addrTableRef
126+
LLVM.JIT.unloadNativeModule $ nativeModule $ addrTable M.! funcPtr
127+
modifyIORef addrTableRef $ M.delete funcPtr

src/Dex/Foreign/Util.hs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@
44
-- license that can be found in the LICENSE file or at
55
-- https://developers.google.com/open-source/licenses/bsd
66

7-
module Dex.Foreign.Util (fromStablePtr, toStablePtr, putOnHeap) where
7+
module Dex.Foreign.Util (fromStablePtr, toStablePtr, putOnHeap, setError, catchErrors) where
8+
9+
import Data.Int
10+
import Data.Functor
811

912
import Foreign.Ptr
1013
import Foreign.StablePtr
1114
import Foreign.Storable
15+
import Foreign.C.String
1216
import Foreign.Marshal.Alloc
1317

18+
import Err
19+
1420
fromStablePtr :: Ptr a -> IO a
1521
fromStablePtr = deRefStablePtr . castPtrToStablePtr . castPtr
1622

@@ -22,3 +28,14 @@ putOnHeap x = do
2228
ptr <- malloc
2329
poke ptr x
2430
return ptr
31+
32+
catchErrors :: IO (Ptr a) -> IO (Ptr a)
33+
catchErrors m = catchIOExcept m >>= \case
34+
Success ans -> return ans
35+
Failure err -> setError (pprint err) $> nullPtr
36+
37+
foreign import ccall "_internal_dexSetError" internalSetErrorPtr :: CString -> Int64 -> IO ()
38+
39+
setError :: String -> IO ()
40+
setError msg = withCStringLen msg $ \(ptr, len) ->
41+
internalSetErrorPtr ptr (fromIntegral len)

0 commit comments

Comments
 (0)