@@ -50,7 +50,7 @@ import PPrint
5050import Builder (TopBuilder (.. ), queryObjCache )
5151import Logging
5252import LLVMExec
53- import Util (IsBool (.. ))
53+ import Util (IsBool (.. ), bindM2 )
5454
5555-- === Compile monad ===
5656
@@ -118,7 +118,7 @@ impToLLVM' logger fName f = do
118118 dtorRegEntryTy = L. StructureType False [ i32, hostPtrTy dtorType, hostVoidp ]
119119 makeDtorRegEntry dtorName = C. Struct Nothing False
120120 [ C. Int 32 1
121- , C. GlobalReference (hostPtrTy dtorType) dtorName
121+ , globalReference (hostPtrTy dtorType) dtorName
122122 , C. Null hostVoidp ]
123123 defineGlobalDtors globalDtors =
124124 L. GlobalDefinition $ L. globalVariableDefaults
@@ -158,11 +158,11 @@ compileFunction logger fName fun@(ImpFunction (IFunType cc argTys retTys)
158158 (resultPtrParam, resultPtrOperand) <- freshParamOpPair attrs $ hostPtrTy i64
159159 initializeOutputStream streamFDOperand
160160 argOperands <- forM (zip [0 .. ] argTys) \ (i, ty) ->
161- gep argPtrOperand (i64Lit i) >>= castLPtr (scalarTy ty) >>= load
161+ gep i64 argPtrOperand (i64Lit i) >>= castLPtr (scalarTy ty) >>= load (scalarTy ty)
162162 when (toBool requiresCUDA) ensureHasCUDAContext
163163 results <- extendSubst (bs @@> map opSubstVal argOperands) $ compileBlock body
164164 forM_ (zip [0 .. ] results) \ (i, x) ->
165- gep resultPtrOperand (i64Lit i) >>= castLPtr (typeOf x) >>= flip store x
165+ gep i64 resultPtrOperand (i64Lit i) >>= castLPtr (typeOf x) >>= flip store x
166166 mainFun <- makeFunction (topLevelFunName fName)
167167 [streamFDParam, argPtrParam, resultPtrParam] (Just $ i64Lit 0 )
168168 extraSpecs <- gets funSpecs
@@ -191,21 +191,24 @@ compileFunction logger fName fun@(ImpFunction (IFunType cc argTys retTys)
191191 , L. type' = hostVoidp
192192 , L. linkage = L. Private
193193 , L. initializer = Just $ C. Null hostVoidp }
194- let kernelModuleCache = L. ConstantOperand $ C. GlobalReference (hostPtrTy hostVoidp) kernelModuleCacheName
194+ let kernelModuleCache = L. ConstantOperand $ globalReference (hostPtrTy hostVoidp) kernelModuleCacheName
195195 let kernelFuncCacheName = fromString $ pprint fName ++ " #cuFunction"
196196 let kernelFuncCacheDef = L. globalVariableDefaults
197197 { L. name = kernelFuncCacheName
198198 , L. type' = hostVoidp
199199 , L. linkage = L. Private
200200 , L. initializer = Just $ C. Null hostVoidp }
201- let kernelFuncCache = L. ConstantOperand $ C. GlobalReference (hostPtrTy hostVoidp) kernelFuncCacheName
201+ let kernelFuncCache = L. ConstantOperand $ globalReference (hostPtrTy hostVoidp) kernelFuncCacheName
202202 let textPtr = C. GetElementPtr True
203- (C. GlobalReference (hostPtrTy textArrTy) textGlobalName)
203+ #if MIN_VERSION_llvm_hs(15,0,0)
204+ textArrTy
205+ #endif
206+ (globalReference (hostPtrTy textArrTy) textGlobalName)
204207 [C. Int 32 0 , C. Int 32 0 ]
205208 loaderDef <- liftCompile CPU $ do
206209 emitVoidExternCall kernelLoaderSpec
207210 [ L. ConstantOperand $ textPtr, kernelModuleCache, kernelFuncCache]
208- kernelFunc <- load kernelFuncCache
211+ kernelFunc <- load hostVoidp kernelFuncCache
209212 makeFunction (topLevelFunName fName) [] (Just kernelFunc)
210213 let dtorName = fromString $ pprint fName ++ " #dtor"
211214 dtorDef <- liftCompile CPU $ do
@@ -263,12 +266,13 @@ compileInstr instr = case instr of
263266 return [i32Lit 1 , numThreads]
264267 where queryParallelismMCFun = ExternFunSpec " dex_queryParallelismMC" i32 [] [] [i64]
265268 CUDAKernelLaunch -> do
266- let ptxPtrFun = callableOperand (funTy (hostPtrTy i8) [] ) kernelFuncName
267- kernelPtr <- emitInstr (hostPtrTy i8) $ callInstr ptxPtrFun []
269+ let ptxPtrFunTy = funTy (hostPtrTy i8) []
270+ let ptxPtrFun = callableOperand (hostPtrTy ptxPtrFunTy) kernelFuncName
271+ kernelPtr <- emitInstr (hostPtrTy i8) $ callInstr ptxPtrFunTy ptxPtrFun []
268272 numWorkgroupsPtr <- alloca 1 i32
269273 workgroupSizePtr <- alloca 1 i32
270274 emitVoidExternCall queryParallelismCUDAFun [kernelPtr, n, numWorkgroupsPtr, workgroupSizePtr]
271- traverse load [numWorkgroupsPtr, workgroupSizePtr]
275+ traverse ( load i32) [numWorkgroupsPtr, workgroupSizePtr]
272276 where
273277 queryParallelismCUDAFun = ExternFunSpec " dex_queryParallelismCUDA" L. VoidType [] []
274278 [hostPtrTy i8, i64, hostPtrTy i32, hostPtrTy i32]
@@ -287,11 +291,12 @@ compileInstr instr = case instr of
287291 MCThreadLaunch -> do
288292 kernelParams <- packArgs args'
289293 let funPtr = L. ConstantOperand $ C. BitCast
290- (C. GlobalReference mcKernelPtrType kernelFuncName) hostVoidp
294+ (globalReference mcKernelPtrType kernelFuncName) hostVoidp
291295 emitVoidExternCall runMCKernel [funPtr, size',kernelParams]
292296 CUDAKernelLaunch -> do
293- let ptxPtrFun = callableOperand (funTy (hostPtrTy i8) [] ) kernelFuncName
294- kernelPtr <- emitInstr (hostPtrTy i8) $ callInstr ptxPtrFun []
297+ let ptxPtrFunTy = funTy (hostPtrTy i8) []
298+ let ptxPtrFun = callableOperand (hostPtrTy ptxPtrFunTy) kernelFuncName
299+ kernelPtr <- emitInstr (hostPtrTy i8) $ callInstr ptxPtrFunTy ptxPtrFun []
295300 kernelParams <- packArgs args'
296301 launchCUDAKernel kernelPtr size' kernelParams
297302 _ -> error $ " Not a valid calling convention for a launch: " ++ pprint cc
@@ -302,7 +307,7 @@ compileInstr instr = case instr of
302307 -- TODO: Implement proper error handling on GPUs.
303308 -- For now we generate an invalid memory access, hoping that the
304309 -- runtime will catch it.
305- GPU -> [] <$ load (L. ConstantOperand $ C. Null $ devicePtrTy i8)
310+ GPU -> [] <$ load i8 (L. ConstantOperand $ C. Null $ devicePtrTy i8)
306311 Alloc a t s -> (: [] ) <$> case a of
307312 Stack -> alloca (getIntLit l) elemTy where ILit l = s
308313 Heap dev -> do
@@ -341,7 +346,12 @@ compileInstr instr = case instr of
341346 val' <- compileExpr val
342347 store dest' val'
343348 return Nothing
344- IPrimOp (PtrLoad ptr) -> (: [] ) <$> (compileExpr ptr >>= load)
349+ -- We handle pointer operations explicitly, because we need type information that
350+ -- might get erased by compileExpr.
351+ IPrimOp (PtrLoad ptr) ->
352+ (: [] ) <$> (load (pointeeType $ getIType ptr) =<< compileExpr ptr)
353+ IPrimOp (PtrOffset ptr off) ->
354+ (: [] ) <$> bindM2 (gep (pointeeType $ getIType ptr)) (compileExpr ptr) (compileExpr off)
345355 IPrimOp op -> (: [] ) <$> (traverse compileExpr op >>= compilePrimOp)
346356 ICastOp idt ix -> (: [] ) <$> do
347357 x <- compileExpr ix
@@ -358,10 +368,15 @@ compileInstr instr = case instr of
358368 GT -> emitInstr dt $ L. FPTrunc x dt []
359369 (L. FloatingPointType _, L. IntegerType _) -> emitInstr dt $ L. FPToSI x dt []
360370 (L. IntegerType _, L. FloatingPointType _) -> emitInstr dt $ L. SIToFP x dt []
371+ #if MIN_VERSION_llvm_hs(15,0,0)
372+ (L. IntegerType 64 , ptrTy@ (L. PointerType _)) -> emitInstr ptrTy $ L. IntToPtr x ptrTy []
373+ (L. PointerType _ , L. IntegerType 64 ) -> emitInstr i64 $ L. PtrToInt x i64 []
374+ #else
361375 (L. PointerType _ _, L. PointerType eltTy _) -> castLPtr eltTy x
362376 (L. IntegerType 64 , ptrTy@ (L. PointerType _ _)) ->
363377 emitInstr ptrTy $ L. IntToPtr x ptrTy []
364378 (L. PointerType _ _, L. IntegerType 64 ) -> emitInstr i64 $ L. PtrToInt x i64 []
379+ #endif
365380 _ -> error $ " Unsupported cast"
366381 ICall f args -> do
367382 fImpName <- substM f
@@ -384,12 +399,12 @@ compileInstr instr = case instr of
384399 emitVoidExternCall (makeFunSpec fname ty) (resultPtr : args')
385400 loadMultiResultAlloc resultTys' resultPtr
386401 CEntryFun -> do
387- exitCode <- emitInstr i64 (callInstr fun args') >>= (`asIntWidth` i1)
402+ exitCode <- emitInstr i64 (callInstr fTy fun args') >>= (`asIntWidth` i1)
388403 compileIf exitCode throwRuntimeError (return () )
389404 return []
390405 where
391406 fTy = funTy i64 $ map scalarTy argTys
392- fun = callableOperand fTy $ topLevelFunName fname
407+ fun = callableOperand (hostPtrTy fTy) $ topLevelFunName fname
393408 CInternalFun -> do
394409 exitCode <- emitExternCall (makeFunSpec fname ty) args' >>= (`asIntWidth` i1)
395410 compileIf exitCode throwRuntimeError (return () )
@@ -422,7 +437,7 @@ compileLoop d iBinder n compileBody = do
422437 store i i0
423438 entryCond <- (0 `withWidthOf` n) `ilt` n
424439 finishBlock (L. CondBr entryCond loopBlock nextBlock [] ) loopBlock
425- iVal <- load i
440+ iVal <- load (scalarTy $ iBinderType iBinder) i
426441 extendSubst (iBinder @> opSubstVal iVal) $ compileBody
427442 iValNew <- case d of Fwd -> add iVal (1 `withWidthOf` iVal)
428443 Rev -> sub iVal (1 `withWidthOf` iVal)
@@ -471,9 +486,7 @@ compilePrimOp pop = case pop of
471486 undef = L. ConstantOperand $ C. Undef resTy
472487 VectorIndex v i -> emitInstr resTy $ L. ExtractElement v i []
473488 where (L. VectorType _ resTy) = typeOf v
474- PtrOffset ptr off -> gep ptr off
475489 OutputStreamPtr -> return outputStreamPtr
476-
477490 _ -> error $ " Can't JIT primop: " ++ pprint pop
478491
479492compileUnOp :: LLVMBuilder m => UnOp -> Operand -> m Operand
@@ -563,8 +576,8 @@ impKernelToLLVMGPU (ImpFunction _ (Abs args body)) = do
563576 let numThreadInfoArgs = 4 -- [threadIdParam, nThreadParam, argArrayParam]
564577 let argTypes = drop numThreadInfoArgs $ nestToList (scalarTy . iBinderType) args
565578 let kernelMeta = L. MetadataNodeDefinition kernelMetaId $ L. MDTuple
566- [ Just $ L. MDValue $ L. ConstantOperand $ C. GlobalReference
567- (funTy L. VoidType argTypes) " kernel"
579+ [ Just $ L. MDValue $ L. ConstantOperand $ globalReference
580+ (hostPtrTy $ funTy L. VoidType argTypes) " kernel"
568581 , Just $ L. MDString " kernel"
569582 , Just $ L. MDValue $ L. ConstantOperand $ C. Int 32 1
570583 ]
@@ -654,7 +667,7 @@ _gpuDebugPrint i32Val = do
654667 valPtri8 <- castLPtr i8 valPtr
655668 void $ emitExternCall vprintfSpec [formatStrPtr, valPtri8]
656669 where
657- genericPtrTy ty = L. PointerType ty $ L. AddrSpace 0
670+ genericPtrTy ty = pointerType ty $ L. AddrSpace 0
658671 vprintfSpec = ExternFunSpec " vprintf" i32 [] [] [genericPtrTy i8, genericPtrTy i8]
659672
660673-- Takes a single int64 payload. TODO: implement a varargs version
@@ -699,30 +712,30 @@ packArgs elems = do
699712 forM_ (zip [0 .. ] elems) \ (i, e) -> do
700713 eptr <- alloca 1 $ typeOf e
701714 store eptr e
702- earr <- gep arr $ i32Lit i
715+ earr <- gep hostVoidp arr $ i32Lit i
703716 store earr =<< castVoidPtr eptr
704717 return arr
705718
706719unpackArgs :: LLVMBuilder m => Operand -> [L. Type ] -> m [Operand ]
707720unpackArgs argArrayPtr types =
708721 forM (zip [0 .. ] types) \ (i, ty) -> do
709- argVoidPtr <- gep argArrayPtr $ i64Lit i
722+ argVoidPtr <- gep hostVoidp argArrayPtr $ i64Lit i
710723 argPtr <- castLPtr (hostPtrTy ty) argVoidPtr
711- load =<< load argPtr
724+ load ty =<< load (hostPtrTy ty) argPtr
712725
713726makeMultiResultAlloc :: LLVMBuilder m => [L. Type ] -> m Operand
714727makeMultiResultAlloc tys = do
715728 resultsPtr <- alloca (length tys) hostVoidp
716729 forM_ (zip [0 .. ] tys) \ (i, ty) -> do
717730 ptr <- alloca 1 ty >>= castVoidPtr
718- resultsPtrOffset <- gep resultsPtr $ i32Lit i
731+ resultsPtrOffset <- gep hostVoidp resultsPtr $ i32Lit i
719732 store resultsPtrOffset ptr
720733 return resultsPtr
721734
722735loadMultiResultAlloc :: LLVMBuilder m => [L. Type ] -> Operand -> m [Operand ]
723736loadMultiResultAlloc tys ptr =
724737 forM (zip [0 .. ] tys) \ (i, ty) ->
725- gep ptr (i32Lit i) >>= load >>= castLPtr ty >>= load
738+ gep hostVoidp ptr (i32Lit i) >>= load hostVoidp >>= castLPtr ty >>= load ty
726739
727740runMCKernel :: ExternFunSpec
728741runMCKernel = ExternFunSpec " dex_launchKernelMC" L. VoidType [] [] [hostVoidp, i64, hostPtrTy hostVoidp]
@@ -776,9 +789,13 @@ withWidthOfFP x template = case typeOf template of
776789store :: LLVMBuilder m => Operand -> Operand -> m ()
777790store ptr x = addInstr $ L. Do $ L. Store False ptr x Nothing 0 []
778791
779- load :: LLVMBuilder m => Operand -> m Operand
780- load ptr = emitInstr ty $ L. Load False ptr Nothing 0 []
781- where (L. PointerType ty _) = typeOf ptr
792+ load :: LLVMBuilder m => L. Type -> Operand -> m Operand
793+ load pointeeTy ptr =
794+ #if MIN_VERSION_llvm_hs(15,0,0)
795+ emitInstr pointeeTy $ L. Load False pointeeTy ptr Nothing 0 []
796+ #else
797+ emitInstr pointeeTy $ L. Load False ptr Nothing 0 []
798+ #endif
782799
783800ilt :: LLVMBuilder m => Operand -> Operand -> m Operand
784801ilt x y = emitInstr i1 $ L. ICmp IP. SLT x y []
@@ -795,8 +812,14 @@ sub x y = emitInstr (typeOf x) $ L.Sub False False x y []
795812mul :: LLVMBuilder m => Operand -> Operand -> m Operand
796813mul x y = emitInstr (typeOf x) $ L. Mul False False x y []
797814
798- gep :: LLVMBuilder m => Operand -> Operand -> m Operand
799- gep ptr i = emitInstr (typeOf ptr) $ L. GetElementPtr False ptr [i] []
815+ gep :: LLVMBuilder m => L. Type -> Operand -> Operand -> m Operand
816+ #if MIN_VERSION_llvm_hs(15,0,0)
817+ gep pointeeTy ptr i =
818+ emitInstr (typeOf ptr) $ L. GetElementPtr False pointeeTy ptr [i] []
819+ #else
820+ gep _ ptr i =
821+ emitInstr (typeOf ptr) $ L. GetElementPtr False ptr [i] []
822+ #endif
800823
801824sizeof :: L. Type -> Operand
802825sizeof t = L. ConstantOperand $ C. sizeof 64 t
@@ -822,24 +845,34 @@ free ptr = do
822845 emitVoidExternCall freeFun [ptr']
823846
824847castLPtr :: LLVMBuilder m => L. Type -> Operand -> m Operand
848+ #if MIN_VERSION_llvm_hs(15,0,0)
849+ castLPtr _ = return
850+ #else
825851castLPtr ty ptr = emitInstr newPtrTy $ L. BitCast ptr newPtrTy []
826852 where
827853 L. PointerType _ addr = typeOf ptr
828854 newPtrTy = L. PointerType ty addr
855+ #endif
829856
830857castVoidPtr :: LLVMBuilder m => Operand -> m Operand
831858castVoidPtr = castLPtr i8
832859
833860zeroExtendTo :: LLVMBuilder m => Operand -> L. Type -> m Operand
834861zeroExtendTo x t = emitInstr t $ L. ZExt x t []
835862
836- callInstr :: L. CallableOperand -> [L. Operand ] -> L. Instruction
837- callInstr fun xs = L. Call Nothing L. C [] fun xs' [] []
863+ callInstr :: L. Type -> L. CallableOperand -> [L. Operand ] -> L. Instruction
864+ #if MIN_VERSION_llvm_hs(15,0,0)
865+ callInstr fty fun xs = L. Call Nothing L. C [] fty fun xs' [] []
866+ #else
867+ callInstr _ fun xs = L. Call Nothing L. C [] fun xs' [] []
868+ #endif
838869 where xs' = [(x ,[] ) | x <- xs]
839870
840871externCall :: ExternFunSpec -> [L. Operand ] -> L. Instruction
841- externCall (ExternFunSpec fname retTy _ _ argTys) xs = callInstr fun xs
842- where fun = callableOperand (funTy retTy argTys) fname
872+ externCall (ExternFunSpec fname retTy _ _ argTys) xs = callInstr ft fun xs
873+ where
874+ ft = funTy retTy argTys
875+ fun = callableOperand (hostPtrTy ft) fname
843876
844877emitExternCall :: LLVMBuilder m => ExternFunSpec -> [L. Operand ] -> m Operand
845878emitExternCall f@ (ExternFunSpec _ retTy _ _ _) xs = do
@@ -862,13 +895,18 @@ scalarTy b = case b of
862895 Float64Type -> fp64
863896 Float32Type -> fp32
864897 Vector sb -> L. VectorType (fromIntegral vectorWidth) $ scalarTy $ Scalar sb
865- PtrType (s, t) -> L. PointerType (scalarTy t) (lAddress s)
898+ PtrType (s, t) -> pointerType (scalarTy t) (lAddress s)
899+
900+ pointeeType :: BaseType -> L. Type
901+ pointeeType b = case b of
902+ PtrType (_, t) -> scalarTy t
903+ _ -> error " Not a pointer type!"
866904
867905hostPtrTy :: L. Type -> L. Type
868- hostPtrTy ty = L. PointerType ty $ L. AddrSpace 0
906+ hostPtrTy ty = pointerType ty $ L. AddrSpace 0
869907
870908devicePtrTy :: L. Type -> L. Type
871- devicePtrTy ty = L. PointerType ty $ L. AddrSpace 1
909+ devicePtrTy ty = pointerType ty $ L. AddrSpace 1
872910
873911lAddress :: HasCallStack => AddressSpace -> L. AddrSpace
874912lAddress s = case s of
@@ -877,7 +915,7 @@ lAddress s = case s of
877915 Heap GPU -> L. AddrSpace 1
878916
879917callableOperand :: L. Type -> L. Name -> L. CallableOperand
880- callableOperand ty name = Right $ L. ConstantOperand $ C. GlobalReference ty name
918+ callableOperand ty name = Right $ L. ConstantOperand $ globalReference ty name
881919
882920asIntWidth :: LLVMBuilder m => Operand -> L. Type -> m Operand
883921asIntWidth op ~ expTy@ (L. IntegerType expWidth) = case compare expWidth opWidth of
@@ -890,7 +928,11 @@ freshParamOpPair :: LLVMBuilder m => [L.ParameterAttribute] -> L.Type -> m (Para
890928freshParamOpPair ptrAttrs ty = do
891929 v <- freshName noHint
892930 let attrs = case ty of
931+ #if MIN_VERSION_llvm_hs(15,0,0)
932+ L. PointerType _ -> ptrAttrs
933+ #else
893934 L. PointerType _ _ -> ptrAttrs
935+ #endif
894936 _ -> []
895937 return (L. Parameter ty v attrs, L. LocalReference ty v)
896938
@@ -958,8 +1000,22 @@ outputStreamPtrDef = L.GlobalDefinition $ L.globalVariableDefaults
9581000 , L. initializer = Just $ C. Null hostVoidp }
9591001
9601002outputStreamPtr :: Operand
961- outputStreamPtr = L. ConstantOperand $ C. GlobalReference
962- (hostPtrTy hostVoidp) outputStreamPtrLName
1003+ outputStreamPtr =
1004+ L. ConstantOperand $ globalReference (hostPtrTy hostVoidp) outputStreamPtrLName
1005+
1006+ globalReference :: L. Type -> L. Name -> C. Constant
1007+ #if MIN_VERSION_llvm_hs(15,0,0)
1008+ globalReference = const C. GlobalReference
1009+ #else
1010+ globalReference = C. GlobalReference
1011+ #endif
1012+
1013+ pointerType :: L. Type -> L. AddrSpace -> L. Type
1014+ #if MIN_VERSION_llvm_hs(15,0,0)
1015+ pointerType = const L. PointerType
1016+ #else
1017+ pointerType = L. PointerType
1018+ #endif
9631019
9641020initializeOutputStream :: LLVMBuilder m => Operand -> m ()
9651021initializeOutputStream streamFD = do
@@ -1099,7 +1155,7 @@ deviceVoidp :: L.Type
10991155deviceVoidp = devicePtrTy i8
11001156
11011157funTy :: L. Type -> [L. Type ] -> L. Type
1102- funTy retTy argTys = hostPtrTy $ L. FunctionType retTy argTys False
1158+ funTy retTy argTys = L. FunctionType retTy argTys False
11031159
11041160-- === Module building ===
11051161
0 commit comments