
Commit e5800d6

Optimisation/mlx llm (#123)
* exported namespace management, for use in other libraries.
* prep_namespace exposed for use in other libraries.
* merge conflict in uv.lock
* mlx llms working.
* Updating github test cases for AS-M1.
* Updating github test cases for AS-M1..
* Updating github test cases for AS-M1...
1 parent 6c9ee29 commit e5800d6

9 files changed: +608 -42 lines changed

.github/workflows/docs2pages.yml

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ jobs:
       #----------------------------------------------
       - name: Install dependencies
         if: steps.cached-uv-dependencies.outputs.cache-hit != 'true'
-        run: uv sync --dev --group docs
+        run: uv sync --dev --group docs --no-group apple-silicon
       - name: Build documentation
         run: |
           uv run sphinx-build -b html docs/ ./_site

.github/workflows/test-windows.yml

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ jobs:
       #----------------------------------------------
       - name: Install dependencies
         if: steps.cached-uv-dependencies.outputs.cache-hit != 'true'
-        run: uv sync --dev
+        run: uv sync --dev --no-group apple-silicon
       #----------------------------------------------
       # run black
       #----------------------------------------------

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ jobs:
       #----------------------------------------------
       - name: Install dependencies
         if: steps.cached-uv-dependencies.outputs.cache-hit != 'true'
-        run: uv sync --dev
+        run: uv sync --dev --no-group apple-silicon
       #----------------------------------------------
       # run black
       #----------------------------------------------
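All three workflows now pass --no-group apple-silicon to uv sync, so the hosted Linux and Windows runners do not install the macOS-arm64-only packages introduced in the new apple-silicon dependency group (see the pyproject.toml change below).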

llamea/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -1,5 +1,14 @@
 from .llamea import LLaMEA
-from .llm import LLM, Dummy_LLM, Gemini_LLM, Multi_LLM, Ollama_LLM, OpenAI_LLM
+from .llm import (
+    LLM,
+    Dummy_LLM,
+    Gemini_LLM,
+    Multi_LLM,
+    Ollama_LLM,
+    OpenAI_LLM,
+    LMStudio_LLM,
+    MLX_LM_LLM,
+)
 from .loggers import ExperimentLogger
 from .solution import Solution
 from .utils import (

llamea/llamea.py

Lines changed: 1 addition & 0 deletions
@@ -310,6 +310,7 @@ def __call__(self, func):
 
         if self.log:
             modelname = self.model.replace(":", "_")
+            modelname = modelname.replace("/", "_")
             self.logger = ExperimentLogger(f"LLaMEA-{modelname}-{experiment_name}")
             self.llm.set_logger(self.logger)
         else:
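The added replacement keeps logger directory names filesystem-safe when the model id contains a slash, as Hugging Face / MLX repo ids do; the existing ':' replacement presumably covers Ollama-style 'name:tag' ids. A minimal illustration (the model id and experiment name are made up):

model = "mlx-community/Llama-3.2-3B-Instruct-4bit"  # illustrative MLX model id
modelname = model.replace(":", "_")      # Ollama-style 'name:tag' ids
modelname = modelname.replace("/", "_")  # Hugging Face-style 'org/name' ids
print(f"LLaMEA-{modelname}-demo")        # LLaMEA-mlx-community_Llama-3.2-3B-Instruct-4bit-demo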

llamea/llm.py

Lines changed: 176 additions & 0 deletions
@@ -24,6 +24,15 @@
     import openai
 except ModuleNotFoundError:  # pragma: no cover - optional dependency
     openai = None
+try:
+    import lmstudio as lms
+except ModuleNotFoundError:
+    lms = object
+try:
+    from mlx_lm import load, generate
+except ModuleNotFoundError:
+    load = None
+    generate = None
 
 try:
     from ConfigSpace import ConfigurationSpace
@@ -545,6 +554,173 @@ def __init__(self, api_key, model="deepseek-chat", temperature=0.8, **kwargs):
         self.client = openai.OpenAI(**self._client_kwargs)
 
 
+class LMStudio_LLM(LLM):
+    """A manager for running an MLX-optimised LLM locally through LM Studio."""
+
+    def __init__(self, model, config=None, **kwargs):
+        """
+        Initialises the LM Studio LLM interface.
+
+        :param model: Name of the model to initialise for interaction.
+        :param config: Configuration to be set for the LLM chat.
+        :param kwargs: Keyword arguments for setting up the LLM chat.
+        """
+        super().__init__(api_key="", model=model, **kwargs)
+        self.llm = lms.llm(model)
+        self.config = config
+
+    def query(
+        self, session: list[dict[str, str]], default_delay: int = 5, max_tries: int = 5
+    ) -> str:
+        """
+        Query the LM Studio model with the last message of a session.
+
+        ## Parameters
+        `session: list[dict[str, str]]`: A session is a list of {'role': 'user'|'system', 'content': 'content'} messages used to make the LLM request.
+        `default_delay: int`: Time to wait before retrying a prompt when an exception occurs.
+        `max_tries: int`: Maximum number of attempts to get a response.
+        """
+        request = session[-1]["content"]
+        for _ in range(max_tries):
+            try:
+                if self.config is not None:
+                    response = self.llm.respond(request, config=self.config)
+                else:
+                    response = self.llm.respond(request)
+                response = re.sub(  # Remove the thinking section, if present.
+                    r"<think>.*?</think>", "", str(response), flags=re.DOTALL
+                )
+                return response
+            except Exception:
+                time.sleep(default_delay)
+        return ""
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("llm", None)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.llm = lms.llm(self.model)
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        new = cls.__new__(cls)
+        memo[id(self)] = new
+        for k, v in self.__dict__.items():
+            if k == "llm":
+                continue
+            setattr(new, k, copy.deepcopy(v, memo))
+        new.llm = self.llm
+        return new
+
+
+class MLX_LM_LLM(LLM):
+    """An mlx_lm implementation for running large LLMs locally."""
+
+    def __init__(
+        self,
+        model,
+        config=None,
+        max_tokens: int = 12000,
+        chat_template_style=None,
+        **kwargs,
+    ):
+        """
+        Initialises the MLX-LM interface.
+
+        :param model: Name of the model to initialise for interaction.
+        :param config: Configuration to be set for the LLM chat.
+        :param max_tokens: Maximum number of tokens to generate for a request.
+        :param chat_template_style: Some models require chat_template_style to be specified; refer to the model's Hugging Face documentation to set this parameter.
+        :param kwargs: Keyword arguments for setting up the LLM chat.
+        """
+        super().__init__(api_key="", model=model, **kwargs)
+        if config is not None:
+            llm, tokenizer = load(model, model_config=config)
+        else:
+            llm, tokenizer = load(model)
+        self.llm = llm
+        self.tokenizer = tokenizer
+        self.chat_template_style = chat_template_style
+        print(f"Init tokeniser object: {self.tokenizer}.")
+
+        self.config = config
+        self.max_tokens = max_tokens
+
+    def __getstate__(self) -> object:
+        state = self.__dict__.copy()
+        state.pop("tokenizer", None)
+        state.pop("llm", None)
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if self.config is None:
+            llm, tokenizer = load(self.model)
+        else:
+            llm, tokenizer = load(self.model, model_config=self.config)
+        self.llm = llm
+        self.tokenizer = tokenizer
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        new = cls.__new__(cls)
+        memo[id(self)] = new
+        for k, v in self.__dict__.items():
+            if k in ["llm", "tokenizer"]:
+                continue
+            setattr(new, k, copy.deepcopy(v, memo))
+        new.llm = self.llm  # Share the large `llm` object by reference instead of deep-copying it.
+        new.tokenizer = self.tokenizer
+        return new
+
+    def query(
+        self,
+        session: list,
+        max_tries: int = 5,
+        default_delay: int = 5,
+        add_generation_prompt: bool = False,
+    ):
+        """
+        Query the MLX-LM model with a chat session.
+
+        ## Parameters
+        `session: list[dict[str, str]]`: A session is a list of {'role': 'user'|'system', 'content': 'content'} messages used to make the LLM request.
+        `max_tries: int`: Maximum number of attempts to get a response.
+        `default_delay: int`: Time to wait before retrying a prompt when an exception occurs.
+        `add_generation_prompt: bool`: Whether to append the model's generation prompt when applying the chat template.
+        """
+        if self.chat_template_style is not None:
+            prompt = self.tokenizer.apply_chat_template(
+                session,
+                add_generation_prompt=add_generation_prompt,
+                chat_template=self.chat_template_style,
+            )
+        else:
+            prompt = self.tokenizer.apply_chat_template(
+                session, add_generation_prompt=add_generation_prompt
+            )
+        for _ in range(max_tries):
+            try:
+                response = generate(
+                    self.llm,
+                    self.tokenizer,
+                    prompt,
+                    max_tokens=self.max_tokens,  # Cap the number of generated tokens.
+                )
+                response = re.sub(  # Remove the thinking section, if present.
+                    r"<think>.*?</think>", "", str(response), flags=re.DOTALL
+                )
+                return response
+            except Exception:
+                time.sleep(default_delay)
+        return ""
+
+
 class Dummy_LLM(LLM):
     def __init__(self, model="DUMMY", **kwargs):
         """

pyproject.toml

Lines changed: 5 additions & 1 deletion
@@ -65,7 +65,11 @@ examples = [
 ]
 llm-extras = [
     "torch>=2.6.0,<3",
-    "transformers>=4.49.0,<5",
+    "transformers>=4.49.0,<=5",
+]
+apple-silicon = [
+    "lmstudio>=1.5.0,<2; platform_system == 'Darwin' and platform_machine == 'arm64'",
+    "mlx_lm>=0.29.1,<1; platform_system == 'Darwin' and platform_machine == 'arm64'",
 ]
 
 [tool.uv]
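On an Apple-Silicon Mac the new group can be pulled in with uv sync --group apple-silicon (assuming apple-silicon is a uv dependency group, consistent with the --no-group apple-silicon flag used in the workflows above). On other platforms the platform_system/platform_machine markers make both requirements inert, and the CI workflows skip the group explicitly.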
