Skip to content

Commit a59ec91

Browse files
Jared Van Bortel authored
python: fix CalledProcessError on Intel Macs since v2.8.0 (#3045)
Signed-off-by: Jared Van Bortel <[email protected]>
1 parent 8e3108f commit a59ec91

File tree

3 files changed

+56
-45
lines changed

3 files changed

+56
-45
lines changed

gpt4all-bindings/python/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
1212
### Changed
1313
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
1414
- Change the error message when a message is too long ([#3004](https://github.com/nomic-ai/gpt4all/pull/3004))
15+
- Fix CalledProcessError on Intel Macs since v2.8.0 ([#3045](https://github.com/nomic-ai/gpt4all/pull/3045))
1516

1617
## [2.8.2] - 2024-08-14
1718

gpt4all-bindings/python/gpt4all/_pyllmodel.py

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import ctypes
44
import os
55
import platform
6-
import re
76
import subprocess
87
import sys
98
import textwrap
@@ -28,16 +27,25 @@
2827

2928
EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
3029

30+
cuda_found: bool = False
31+
32+
33+
# TODO(jared): use operator.call after we drop python 3.10 support
34+
def _operator_call(obj, /, *args, **kwargs):
35+
return obj(*args, **kwargs)
36+
3137

3238
# Detect Rosetta 2
33-
if platform.system() == "Darwin" and platform.processor() == "i386":
34-
if subprocess.run(
35-
"sysctl -n sysctl.proc_translated".split(), check=True, capture_output=True, text=True,
36-
).stdout.strip() == "1":
37-
raise RuntimeError(textwrap.dedent("""\
38-
Running GPT4All under Rosetta is not supported due to CPU feature requirements.
39-
Please install GPT4All in an environment that uses a native ARM64 Python interpreter.
40-
""").strip())
39+
@_operator_call
40+
def check_rosetta() -> None:
41+
if platform.system() == "Darwin" and platform.processor() == "i386":
42+
p = subprocess.run("sysctl -n sysctl.proc_translated".split(), capture_output=True, text=True)
43+
if p.returncode == 0 and p.stdout.strip() == "1":
44+
raise RuntimeError(textwrap.dedent("""\
45+
Running GPT4All under Rosetta is not supported due to CPU feature requirements.
46+
Please install GPT4All in an environment that uses a native ARM64 Python interpreter.
47+
""").strip())
48+
4149

4250
# Check for C++ runtime libraries
4351
if platform.system() == "Windows":
@@ -53,33 +61,35 @@
5361
"""), file=sys.stderr)
5462

5563

56-
def _load_cuda(rtver: str, blasver: str) -> None:
57-
if platform.system() == "Linux":
58-
cudalib = f"lib/libcudart.so.{rtver}"
59-
cublaslib = f"lib/libcublas.so.{blasver}"
60-
else: # Windows
61-
cudalib = fr"bin\cudart64_{rtver.replace('.', '')}.dll"
62-
cublaslib = fr"bin\cublas64_{blasver}.dll"
63-
64-
# preload the CUDA libs so the backend can find them
65-
ctypes.CDLL(os.path.join(cuda_runtime.__path__[0], cudalib), mode=ctypes.RTLD_GLOBAL)
66-
ctypes.CDLL(os.path.join(cublas.__path__[0], cublaslib), mode=ctypes.RTLD_GLOBAL)
67-
68-
69-
# Find CUDA libraries from the official packages
70-
cuda_found = False
71-
if platform.system() in ("Linux", "Windows"):
72-
try:
73-
from nvidia import cuda_runtime, cublas
74-
except ImportError:
75-
pass # CUDA is optional
76-
else:
77-
for rtver, blasver in [("12", "12"), ("11.0", "11")]:
78-
try:
79-
_load_cuda(rtver, blasver)
80-
cuda_found = True
81-
except OSError: # dlopen() does not give specific error codes
82-
pass # try the next one
64+
@_operator_call
65+
def find_cuda() -> None:
66+
global cuda_found
67+
68+
def _load_cuda(rtver: str, blasver: str) -> None:
69+
if platform.system() == "Linux":
70+
cudalib = f"lib/libcudart.so.{rtver}"
71+
cublaslib = f"lib/libcublas.so.{blasver}"
72+
else: # Windows
73+
cudalib = fr"bin\cudart64_{rtver.replace('.', '')}.dll"
74+
cublaslib = fr"bin\cublas64_{blasver}.dll"
75+
76+
# preload the CUDA libs so the backend can find them
77+
ctypes.CDLL(os.path.join(cuda_runtime.__path__[0], cudalib), mode=ctypes.RTLD_GLOBAL)
78+
ctypes.CDLL(os.path.join(cublas.__path__[0], cublaslib), mode=ctypes.RTLD_GLOBAL)
79+
80+
# Find CUDA libraries from the official packages
81+
if platform.system() in ("Linux", "Windows"):
82+
try:
83+
from nvidia import cuda_runtime, cublas
84+
except ImportError:
85+
pass # CUDA is optional
86+
else:
87+
for rtver, blasver in [("12", "12"), ("11.0", "11")]:
88+
try:
89+
_load_cuda(rtver, blasver)
90+
cuda_found = True
91+
except OSError: # dlopen() does not give specific error codes
92+
pass # try the next one
8393

8494

8595
# TODO: provide a config file to make this more robust
@@ -121,6 +131,7 @@ class LLModelPromptContext(ctypes.Structure):
121131
("context_erase", ctypes.c_float),
122132
]
123133

134+
124135
class LLModelGPUDevice(ctypes.Structure):
125136
_fields_ = [
126137
("backend", ctypes.c_char_p),
@@ -131,6 +142,7 @@ class LLModelGPUDevice(ctypes.Structure):
131142
("vendor", ctypes.c_char_p),
132143
]
133144

145+
134146
# Define C function signatures using ctypes
135147
llmodel.llmodel_model_create.argtypes = [ctypes.c_char_p]
136148
llmodel.llmodel_model_create.restype = ctypes.c_void_p
@@ -540,7 +552,6 @@ def prompt_model(
540552
ctypes.c_char_p(),
541553
)
542554

543-
544555
def prompt_model_streaming(
545556
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
546557
) -> Iterable[str]:
@@ -589,16 +600,16 @@ def _raw_callback(token_id: int, response: bytes) -> bool:
589600
decoded = []
590601

591602
for byte in response:
592-
603+
593604
bits = "{:08b}".format(byte)
594605
(high_ones, _, _) = bits.partition('0')
595606

596-
if len(high_ones) == 1:
607+
if len(high_ones) == 1:
597608
# continuation byte
598609
self.buffer.append(byte)
599610
self.buff_expecting_cont_bytes -= 1
600611

601-
else:
612+
else:
602613
# beginning of a byte sequence
603614
if len(self.buffer) > 0:
604615
decoded.append(self.buffer.decode(errors='replace'))
@@ -608,18 +619,18 @@ def _raw_callback(token_id: int, response: bytes) -> bool:
608619
self.buffer.append(byte)
609620
self.buff_expecting_cont_bytes = max(0, len(high_ones) - 1)
610621

611-
if self.buff_expecting_cont_bytes <= 0:
622+
if self.buff_expecting_cont_bytes <= 0:
612623
# received the whole sequence or an out of place continuation byte
613624
decoded.append(self.buffer.decode(errors='replace'))
614625

615626
self.buffer.clear()
616627
self.buff_expecting_cont_bytes = 0
617-
628+
618629
if len(decoded) == 0 and self.buff_expecting_cont_bytes > 0:
619630
# wait for more continuation bytes
620631
return True
621-
622-
return callback(token_id, ''.join(decoded))
632+
633+
return callback(token_id, ''.join(decoded))
623634

624635
return _raw_callback
625636

gpt4all-bindings/python/gpt4all/gpt4all.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import platform
99
import re
1010
import sys
11-
import time
1211
import warnings
1312
from contextlib import contextmanager
1413
from pathlib import Path

0 commit comments

Comments (0)