33import ctypes
44import os
55import platform
6- import re
76import subprocess
87import sys
98import textwrap
2827
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')

# Records whether a CUDA runtime + cuBLAS pair was successfully preloaded.
# Starts out False; find_cuda() flips it to True when a preload succeeds.
cuda_found: bool = False
31+
32+
33+ # TODO(jared): use operator.call after we drop python 3.10 support
34+ def _operator_call (obj , / , * args , ** kwargs ):
35+ return obj (* args , ** kwargs )
36+
3137
# Detect Rosetta 2
@_operator_call
def check_rosetta() -> None:
    """Abort the import when Python is running under Rosetta 2 on macOS.

    The check only runs when ``platform.system()`` is ``"Darwin"`` and
    ``platform.processor()`` is ``"i386"``.  In that case ``sysctl`` is asked
    for ``sysctl.proc_translated``; an output of ``"1"`` with a zero exit
    status is treated as "running under Rosetta".

    Because of the ``@_operator_call`` decorator this runs once at import
    time, and the module-level name ``check_rosetta`` ends up bound to ``None``.

    Raises:
        RuntimeError: If the process is detected to be running under Rosetta.
    """
    if platform.system() != "Darwin" or platform.processor() != "i386":
        return

    try:
        p = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"], capture_output=True, text=True,
        )
    except OSError:
        # sysctl could not be launched at all; like a non-zero exit status,
        # this leaves the translation state unknown, so do not block the import.
        return

    # A non-zero exit status is deliberately not an error: the query is
    # best-effort and only a definite "1" answer is acted upon.
    if p.returncode == 0 and p.stdout.strip() == "1":
        raise RuntimeError(textwrap.dedent("""\
            Running GPT4All under Rosetta is not supported due to CPU feature requirements.
            Please install GPT4All in an environment that uses a native ARM64 Python interpreter.
            """).strip())
48+
4149
4250# Check for C++ runtime libraries
4351if platform .system () == "Windows" :
5361 """ ), file = sys .stderr )
5462
5563
@_operator_call
def find_cuda() -> None:
    """Preload the CUDA runtime and cuBLAS libraries from the ``nvidia`` packages.

    Only attempted on Linux and Windows, and only when ``nvidia.cuda_runtime``
    and ``nvidia.cublas`` can be imported.  Each known (runtime, cuBLAS)
    version pair is tried in turn; a pair that fails to load is skipped and
    the remaining pairs are still attempted.  The module-level ``cuda_found``
    flag is set to ``True`` when a pair loads successfully.

    Because of the ``@_operator_call`` decorator this runs once at import
    time, and the module-level name ``find_cuda`` ends up bound to ``None``.
    """
    global cuda_found

    def _load_cuda(rtver: str, blasver: str) -> None:
        """Load one CUDA runtime / cuBLAS pair with ``RTLD_GLOBAL``.

        Args:
            rtver: CUDA runtime version suffix, e.g. ``"12"`` or ``"11.0"``.
            blasver: cuBLAS version suffix, e.g. ``"12"`` or ``"11"``.

        Raises:
            OSError: If either shared library cannot be loaded.
        """
        if platform.system() == "Linux":
            cudalib = f"lib/libcudart.so.{rtver}"
            cublaslib = f"lib/libcublas.so.{blasver}"
        else:  # Windows
            # the Windows runtime DLL name carries the version without dots
            cudalib = fr"bin\cudart64_{rtver.replace('.', '')}.dll"
            cublaslib = fr"bin\cublas64_{blasver}.dll"

        # preload the CUDA libs so the backend can find them
        ctypes.CDLL(os.path.join(cuda_runtime.__path__[0], cudalib), mode=ctypes.RTLD_GLOBAL)
        ctypes.CDLL(os.path.join(cublas.__path__[0], cublaslib), mode=ctypes.RTLD_GLOBAL)

    # Find CUDA libraries from the official packages
    if platform.system() not in ("Linux", "Windows"):
        return

    try:
        from nvidia import cuda_runtime, cublas
    except ImportError:
        return  # CUDA is optional

    for rtver, blasver in [("12", "12"), ("11.0", "11")]:
        try:
            _load_cuda(rtver, blasver)
        except OSError:  # dlopen() does not give specific error codes
            continue  # try the next one
        cuda_found = True
8393
8494
8595# TODO: provide a config file to make this more robust
@@ -121,6 +131,7 @@ class LLModelPromptContext(ctypes.Structure):
121131 ("context_erase" , ctypes .c_float ),
122132 ]
123133
134+
124135class LLModelGPUDevice (ctypes .Structure ):
125136 _fields_ = [
126137 ("backend" , ctypes .c_char_p ),
@@ -131,6 +142,7 @@ class LLModelGPUDevice(ctypes.Structure):
131142 ("vendor" , ctypes .c_char_p ),
132143 ]
133144
145+
134146# Define C function signatures using ctypes
135147llmodel .llmodel_model_create .argtypes = [ctypes .c_char_p ]
136148llmodel .llmodel_model_create .restype = ctypes .c_void_p
@@ -540,7 +552,6 @@ def prompt_model(
540552 ctypes .c_char_p (),
541553 )
542554
543-
544555 def prompt_model_streaming (
545556 self , prompt : str , prompt_template : str , callback : ResponseCallbackType = empty_response_callback , ** kwargs
546557 ) -> Iterable [str ]:
@@ -589,16 +600,16 @@ def _raw_callback(token_id: int, response: bytes) -> bool:
589600 decoded = []
590601
591602 for byte in response :
592-
603+
593604 bits = "{:08b}" .format (byte )
594605 (high_ones , _ , _ ) = bits .partition ('0' )
595606
596- if len (high_ones ) == 1 :
607+ if len (high_ones ) == 1 :
597608 # continuation byte
598609 self .buffer .append (byte )
599610 self .buff_expecting_cont_bytes -= 1
600611
601- else :
612+ else :
602613 # beginning of a byte sequence
603614 if len (self .buffer ) > 0 :
604615 decoded .append (self .buffer .decode (errors = 'replace' ))
@@ -608,18 +619,18 @@ def _raw_callback(token_id: int, response: bytes) -> bool:
608619 self .buffer .append (byte )
609620 self .buff_expecting_cont_bytes = max (0 , len (high_ones ) - 1 )
610621
611- if self .buff_expecting_cont_bytes <= 0 :
622+ if self .buff_expecting_cont_bytes <= 0 :
612623 # received the whole sequence or an out of place continuation byte
613624 decoded .append (self .buffer .decode (errors = 'replace' ))
614625
615626 self .buffer .clear ()
616627 self .buff_expecting_cont_bytes = 0
617-
628+
618629 if len (decoded ) == 0 and self .buff_expecting_cont_bytes > 0 :
619630 # wait for more continuation bytes
620631 return True
621-
622- return callback (token_id , '' .join (decoded ))
632+
633+ return callback (token_id , '' .join (decoded ))
623634
624635 return _raw_callback
625636
0 commit comments