Merged
51 commits
f5c4ed4  Add lora in vllm & some tests (Nov 20, 2025)
67b9651  add batched method in async + more tests (Nov 21, 2025)
04b6b70  decrease difference error for lora because of precision issues (e.g. … (Nov 21, 2025)
bec6fa5  set lora_request as class attribute (Nov 24, 2025)
b129c83  change hf backend to support lora + add testing (Nov 25, 2025)
78b2039  clean hf lora tests (Nov 25, 2025)
450cd2a  add testing for swapping lora and no-lora (Nov 25, 2025)
e2a6a81  remove unnecessary import (Nov 25, 2025)
3416104  remove double batch method (Nov 25, 2025)
2a5f93e  add comments in the new methods (Dec 3, 2025)
245743a  remove comment (Dec 3, 2025)
ab0860f  add more tests (Dec 5, 2025)
f357d0f  update dependencies (Dec 8, 2025)
31334ff  cleaning (Dec 8, 2025)
b345d53  change model for testing (Dec 9, 2025)
809bf82  add lora dependencies in pytest (Dec 9, 2025)
c00c8d7  fix dependencies lora (Dec 9, 2025)
eae7c8a  change test model (Dec 9, 2025)
797c8d6  fix lora test on transformer (Dec 10, 2025)
998af61  increase gpu memory util (Dec 10, 2025)
18e114e  decrease gpu memory util (Dec 10, 2025)
daddc42  check gpu github (Dec 11, 2025)
a951037  change gpu memory util (Dec 11, 2025)
bff9a75  debug github (Dec 11, 2025)
42f0402  decrease tests (Dec 11, 2025)
51f2e08  downgrade triton (Dec 12, 2025)
ca8bd0f  trition 3.2 (Dec 12, 2025)
ec5ceb1  debug models github (Dec 12, 2025)
d766049  change model on tests (Dec 12, 2025)
5568324  remove test for cache reasons (Dec 12, 2025)
73955d2  free disk space (Dec 12, 2025)
bda3699  triton (Dec 12, 2025)
709d279  add testing for error path (Dec 12, 2025)
cc0ba95  add readme (Dec 12, 2025)
c52faf2  cleaning (Dec 12, 2025)
565f87f  triton reinstall (Dec 15, 2025)
bc151d5  triton 3.2 (Dec 15, 2025)
24c8dee  remove unnecessary reinstall (Dec 15, 2025)
15c0154  adding lora methods in base class (Dec 16, 2025)
9654250  no cover (Dec 16, 2025)
854c1c1  change triton version (Dec 16, 2025)
721b6e2  trition uninstall and install (Dec 17, 2025)
814ab86  rm triton caches (Dec 17, 2025)
ee54495  dependencies explicitly (Dec 17, 2025)
675cef5  dependencies explicitly (Dec 17, 2025)
dfe2bb8  triton reinstall (Dec 17, 2025)
a870bd2  rename methods, add add_new_lora to vllm (Dec 19, 2025)
3fcf25f  deps fix (Jan 12, 2026)
a49d726  Merge branch 'main' into vicky/lora (vicky-xef, Jan 12, 2026)
6c689d0  attributes fix (Jan 12, 2026)
2c853ef  fix tests (Jan 12, 2026)
3 changes: 2 additions & 1 deletion .github/workflows/coverage.yml
@@ -25,7 +25,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
pip install -e .
pip install -e .[lora]
pip install --force-reinstall 'triton==3.2.0'
pip install -r requirements-dev.txt

- name: Run tests
15 changes: 12 additions & 3 deletions .github/workflows/pytest.yml
@@ -17,18 +17,27 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@v1.3.1
with:
tool-cache: false
android: true
dotnet: true
haskell: true
large-packages: true
docker-images: true
swap-storage: false

- uses: actions/setup-python@v4
with:
python-version: 3.11.5
cache: 'pip'

- name: Install dependencies
run: |
python -m pip install -U pip
pip install -e .
pip install -e .[lora]
pip install --force-reinstall 'triton==3.2.0'
pip install -r requirements-dev.txt

- name: Run Tests
run: |
python -m pytest tests --ignore=tests/test_mlx.py
5 changes: 5 additions & 0 deletions README.md
@@ -35,6 +35,11 @@ Or to install with MLX support, run:
pip install genlm-backend[mlx]
```

Or to install with LoRA support, run:

```bash
pip install genlm-backend[lora]
```

## 🧪 Example: Autobatched Sequential Importance Sampling with LLMs

31 changes: 31 additions & 0 deletions genlm/backend/llm/base.py
@@ -72,6 +72,37 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
[self.next_token_logprobs_sync(token_ids) for token_ids in token_ids_list]
)

def add_new_lora(self, lora_path, lora_name):
"""Load a LoRA adapter into the base model.

Args:
lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
lora_name (str): Name to assign to the loaded adapter.

"""
raise NotImplementedError(
"add_new_lora must be implemented by subclasses"
) # pragma: no cover

def set_lora(self, lora_path, lora_name):
"""Activate a previously loaded LoRA adapter.

Args:
lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
lora_name (str): Name of the LoRA adapter to activate.

"""
raise NotImplementedError(
"set_lora must be implemented by subclasses"
) # pragma: no cover

def clear_lora(self):
"""
Deactivate all LoRA adapters.
"""
raise NotImplementedError(
"clear_lora must be implemented by subclasses"
) # pragma: no cover

def clear_cache(self):
"""Clear any caches used by the language model. No-op in base class."""
pass # pragma: no cover
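Taken together, `add_new_lora`, `set_lora`, and `clear_lora` define the adapter lifecycle every backend is expected to support. Below is a minimal sketch of a conforming subclass, assuming only the import path shown in this diff; the class, its storage, and the error message are illustrative and the rest of the `AsyncLM` interface is omitted.

```python
from genlm.backend.llm.base import AsyncLM  # import path as in this diff


class DictLoRABackend(AsyncLM):
    """Hypothetical backend that tracks adapters in a dict (sketch only)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._adapters = {}   # lora_name -> lora_path
        self._active = None   # name of the active adapter, or None

    def add_new_lora(self, lora_path, lora_name):
        # Register (or load) the adapter; activation is deferred to set_lora().
        self._adapters[lora_name] = lora_path

    def set_lora(self, lora_path, lora_name):
        if lora_name not in self._adapters:
            raise ValueError(f"LoRA adapter '{lora_name}' has not been loaded yet.")
        self._active = lora_name

    def clear_lora(self):
        # Subsequent calls should use the base model only.
        self._active = None
```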
34 changes: 34 additions & 0 deletions genlm/backend/llm/hf.py
@@ -162,6 +162,40 @@ def cache_kv(self, prompt_tokens):
result = self.model(torch.tensor([prompt_tokens]).to(self.device))
node = self.cache.extend_cache(0, prompt_tokens, result.logits[0], 0)
node.past_key_values = result.past_key_values

def add_new_lora(self, lora_path, lora_name='lora_1'):
"""Load a LoRA adapter into the base model.

Args:
lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
lora_name (str): Name to assign to the loaded adapter.

Notes:
This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
"""
self.model.load_adapter(lora_path, lora_name)

def set_lora(self, lora_path=None, lora_name='lora_1'):
"""Activate a previously loaded LoRA adapter.

Args:
lora_path (str, optional): Unused by this backend; the adapter is loaded in `add_new_lora()`. Kept for interface compatibility.
lora_name (str): Name of the LoRA adapter to activate.

"""
if lora_name not in list(self.model.peft_config.keys()):
raise ValueError(f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters.")

self.clear_kv_cache()
self.clear_cache()
self.model.set_adapter(lora_name)

def clear_lora(self):
"""
Deactivate all LoRA adapters.
"""
self.clear_kv_cache()
self.clear_cache()
self.model.set_adapter([])

@torch.no_grad()
def batch_evaluate_queries(self):
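On the HuggingFace backend the flow is load, activate, generate, deactivate. A hedged usage sketch follows; the class name `AsyncTransformer`, the `from_name` constructor, the base model, and the adapter id are assumptions rather than names taken from this diff.

```python
from genlm.backend.llm.hf import AsyncTransformer  # class name assumed

llm = AsyncTransformer.from_name("gpt2")  # constructor assumed to mirror the vLLM backend

# Load the adapter weights via peft's load_adapter(); nothing is activated yet.
llm.add_new_lora("some-user/gpt2-lora", lora_name="my_lora")  # hypothetical adapter id

# Activate the adapter; the KV and output caches are cleared so stale
# base-model results are not reused.
llm.set_lora(lora_name="my_lora")
with_lora = llm.next_token_logprobs_sync([0, 1, 2])

# Deactivate all adapters and fall back to the base model.
llm.clear_lora()
base = llm.next_token_logprobs_sync([0, 1, 2])
```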
51 changes: 50 additions & 1 deletion genlm/backend/llm/vllm.py
@@ -1,12 +1,14 @@
import torch
import logging
import warnings
import hashlib

from genlm.backend.llm.base import AsyncLM
from genlm.backend.cache import OutputCache

try:
from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
from vllm.lora.request import LoRARequest
from vllm.utils import Counter
from vllm.inputs import TokensPrompt

@@ -81,6 +83,8 @@ def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
if cache_size > 0
else None
)
self.lora_request = None
self.lora_name_to_ids = {}

async_llm_engine.engine.log_stats = False

@@ -128,6 +132,48 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):
def underlying_model(self):
return self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model

def clear_lora(self):
"""
Disable any active LoRA adapter for the vLLM engine.
"""
self.lora_request = None

def add_new_lora(self, lora_path, lora_name='lora_1'):
"""Load a LoRA adapter into the base model by creating a unique id for it.

Args:
lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
lora_name (str): Name to assign to the loaded adapter.

Notes:
This does not activate the adapter immediately. Call `set_lora()` to enable the adapter.
"""
self.lora_name_to_ids[lora_name] = self.hash_to_int(lora_name)

def hash_to_int(self, value):
"""Generates a deterministic unique id for a LoRA adapter from its name.

Args:
value (str): The name of the LoRA adapter to hash.

Returns:
An integer ID corresponding to the LoRA adapter, in the range 0–255.
"""
hash_bytes = hashlib.shake_128(value.encode("utf-8")).digest(1)
return int.from_bytes(hash_bytes, "big")

def set_lora(self, lora_path, lora_name='lora_1'):
"""Configure a LoRA adapter request for the vLLM engine.

Args:
lora_path (str): Path to the adapter weights directory or identifier in HuggingFace's model hub.
lora_name (str): Name of a previously registered LoRA adapter (see `add_new_lora()`); its id is looked up in `lora_name_to_ids`.
"""
if lora_name not in self.lora_name_to_ids.keys():
raise ValueError(f"A LoRA adapter named '{lora_name}' has not been loaded yet. Please call add_new_lora() first to load and name your LoRA adapters.")
self.lora_request = LoRARequest(lora_name, self.lora_name_to_ids[lora_name], lora_path)

async def next_token_logprobs(self, token_ids):
"""Request log probabilities of next token asynchronously with output caching.

@@ -172,6 +218,7 @@ async def _next_token_logprobs(self, token_ids):
sampling_params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
lora_request=self.lora_request,
request_id=req_id,
):
if output.finished:
@@ -215,11 +262,12 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
params=SamplingParams(
**self.default_params, logits_processors=[processor]
),
lora_request=self.lora_request,
request_id=req_id,
)

while self.async_llm_engine.engine.has_unfinished_requests():
output = self.async_llm_engine.engine.step()
for out in output:
if out.finished:
assert out.request_id in req_id2processors, (
@@ -275,6 +323,7 @@ async def sample(
seed=seed,
stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
),
lora_request=self.lora_request,
request_id=str(next(self.request_counter)),
):
if output.finished:
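On the vLLM side, `add_new_lora()` only records a name-to-id mapping; the path is supplied again to `set_lora()`, which builds the `LoRARequest` attached to every subsequent generation call. A hedged sketch follows; the class name `AsyncVirtualLM`, the base model, and the adapter path are assumptions, and it assumes `engine_opts` is forwarded to vLLM's engine arguments so LoRA support can be enabled.

```python
from genlm.backend.llm.vllm import AsyncVirtualLM  # class name assumed

llm = AsyncVirtualLM.from_name(
    "meta-llama/Llama-3.2-1B",          # hypothetical base model
    engine_opts={"enable_lora": True},  # assumes engine_opts reach vLLM's AsyncEngineArgs
)

# Register the adapter: assigns a deterministic id (0-255) via hash_to_int().
llm.add_new_lora("path/to/adapter", lora_name="my_lora")

# Build the LoRARequest; the path is passed here, not stored by add_new_lora().
llm.set_lora("path/to/adapter", lora_name="my_lora")
with_lora = llm.batch_next_token_logprobs_sync([[0, 1, 2]])

# Detach the adapter; later requests run against the base model.
llm.clear_lora()
base = llm.batch_next_token_logprobs_sync([[0, 1, 2]])
```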
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -24,6 +24,9 @@ mlx = [
"mlx",
"mlx-lm"
]
lora = [
'peft'
]
docs = [
"mkdocs",
"mkdocstrings[python]",
10 changes: 10 additions & 0 deletions tests/conftest.py
@@ -9,6 +9,7 @@
destroy_model_parallel,
destroy_distributed_environment,
)
from vllm.lora.request import LoRARequest

HAS_VLLM = True
except ImportError:
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(self, llm):
stop=None,
ignore_eos=True,
)
self.lora_request = None

self.llm.llm_engine.log_stats = False

@@ -158,11 +160,18 @@ def from_name(cls, model_name, llm_opts=None):
llm = LLM(model=model_name, tokenizer=model_name, **llm_opts)
return cls(llm)

def clear_lora(self):
self.lora_request = None

def set_lora(self, lora_path, lora_name="current_lora", lora_id=1):
self.lora_request = LoRARequest(lora_name, lora_id, lora_path)

def next_token_logprobs_sync(self, token_ids):
outputs = self.llm.generate(
prompts=TokensPrompt(prompt_token_ids=token_ids),
sampling_params=self.DEFAULT_SAMPLING_PARAMS,
use_tqdm=False,
lora_request=self.lora_request
)
logprobs = np.array(
[
@@ -185,6 +194,7 @@ async def batch_next_token_logprobs(self, token_ids_list):
prompts=prompts,
sampling_params=self.DEFAULT_SAMPLING_PARAMS,
use_tqdm=False,
lora_request=self.lora_request
)
logprobs = np.array(
[
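The shared interface also makes it straightforward to check adapter swapping against the base model, which is the behavior the new tests appear to exercise. A hedged sketch of such a check; the helper name, adapter path, and tolerance are illustrative, and it assumes `batch_next_token_logprobs_sync` returns a tensor of shape (batch, vocab).

```python
import torch


def assert_lora_swap(llm, adapter_path, token_ids=(0, 1, 2)):
    """Activating a LoRA should change next-token logprobs; clearing it should restore them."""
    token_ids = list(token_ids)

    llm.clear_lora()
    base = llm.batch_next_token_logprobs_sync([token_ids])[0]

    llm.add_new_lora(adapter_path, lora_name="test_lora")
    llm.set_lora(adapter_path, lora_name="test_lora")
    with_lora = llm.batch_next_token_logprobs_sync([token_ids])[0]

    llm.clear_lora()
    restored = llm.batch_next_token_logprobs_sync([token_ids])[0]

    assert not torch.allclose(base, with_lora)
    # Loose tolerance: the commit history notes precision issues with LoRA comparisons.
    assert torch.allclose(base, restored, atol=1e-3)
```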