Skip to content

Commit b31ccca

Browse files
committed
[draft_model] Fix OCI mount options for draft_model
When supplying a draft model, the old code effectively tried to bind-mount the model with both the source and the destination ultimately resolving to the same location. There was also no support for multi-file models, so it was not possible to use, for example, a 440B model as the main model and a 30B model as the draft model. To prevent file collisions (for example, for mmproj and chat_template), we use a namespaced directory. As of now, draft models won't work with multimodal models, but speculative decoding is still possible when vision is disabled by passing --no-mmproj. Signed-off-by: Lukas Bezdicka <[email protected]>
1 parent 9201b29 commit b31ccca

File tree

5 files changed

+34
-17
lines changed

5 files changed

+34
-17
lines changed

ramalama/command/context.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from typing import Optional
44

5-
from ramalama.common import check_metal, check_nvidia
5+
from ramalama.common import MNT_DIR, check_metal, check_nvidia
66
from ramalama.console import should_colorize
77
from ramalama.transports.transport_factory import CLASS_MODEL_TYPES, New
88

@@ -123,7 +123,15 @@ def chat_template_path(self) -> Optional[str]:
123123
def draft_model_path(self) -> str:
124124
if getattr(self.model, "draft_model", None):
125125
assert self.model.draft_model
126-
return self.model.draft_model._get_entry_model_path(self.is_container, self.should_generate, self.dry_run)
126+
path = self.model.draft_model._get_entry_model_path(self.is_container, self.should_generate, self.dry_run)
127+
if self.is_container and not path.startswith("oci://"):
128+
# Handle container paths by inserting 'drafts' into the MNT_DIR path
129+
if path.startswith(MNT_DIR):
130+
rest = path[len(MNT_DIR) :].lstrip('/')
131+
return f"{MNT_DIR}/drafts/{rest}"
132+
else:
133+
return path
134+
return path
127135
return ""
128136

129137

ramalama/transports/base.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import ramalama.chat as chat
1313
from ramalama.common import (
1414
MNT_DIR,
15-
MNT_FILE_DRAFT,
1615
accel_image,
1716
exec_cmd,
1817
genname,
@@ -158,8 +157,10 @@ def __init__(self, model: str, model_store_path: str):
158157
def artifact(self) -> bool:
159158
return self.is_artifact()
160159

161-
def extract_model_identifiers(self):
162-
model_name = self.model
160+
def extract_model_identifiers(self, model=None):
161+
if model is None:
162+
model = self.model
163+
model_name = model
163164
model_tag = "latest"
164165
model_organization = ""
165166

@@ -391,12 +392,20 @@ def setup_mounts(self, args):
391392
)
392393

393394
if self.draft_model:
394-
draft_model = self.draft_model._get_entry_model_path(args.container, args.generate, args.dryrun)
395-
# Convert path to container-friendly format (handles Windows path conversion)
396-
container_draft_model = get_container_mount_path(draft_model)
397-
mount_opts = f"--mount=type=bind,src={container_draft_model},destination={MNT_FILE_DRAFT}"
398-
mount_opts += f",ro{self.engine.relabel()}"
399-
self.engine.add([mount_opts])
395+
# Get the model tag for draft model
396+
_, draft_tag, _ = self.extract_model_identifiers(model=self.draft_model.model)
397+
ref_file_draft = self.draft_model.model_store.get_ref_file(draft_tag)
398+
if ref_file_draft is None:
399+
raise NoRefFileFound(self.draft_model.model)
400+
401+
# Mount the draft model files
402+
for file in ref_file_draft.files:
403+
blob_path = self.draft_model.model_store.get_blob_file_path(file.hash)
404+
container_blob_path = get_container_mount_path(blob_path)
405+
mount_path = f"{MNT_DIR}/drafts/{file.name}"
406+
self.engine.add(
407+
[f"--mount=type=bind,src={container_blob_path},destination={mount_path},ro{self.engine.relabel()}"]
408+
)
400409

401410
def bench(self, args, cmd: list[str]):
402411
set_accel_env_vars()

ramalama/transports/huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def create_repository(self, name, organization, tag):
163163
def get_cli_download_args(self, directory_path, model):
164164
return ["hf", "download", "--local-dir", directory_path, model]
165165

166-
def extract_model_identifiers(self):
167-
model_name, model_tag, model_organization = super().extract_model_identifiers()
166+
def extract_model_identifiers(self, model=None):
167+
model_name, model_tag, model_organization = super().extract_model_identifiers(model)
168168
if '/' not in model_organization:
169169
# if it is a repo then normalize the case insensitive quantization tag
170170
if model_tag != "latest":

ramalama/transports/ollama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def __init__(self, model, model_store_path) -> None:
146146

147147
self.type = "Ollama"
148148

149-
def extract_model_identifiers(self) -> tuple[str, str, str]:
150-
model_name, model_tag, model_organization = super().extract_model_identifiers()
149+
def extract_model_identifiers(self, model=None) -> tuple[str, str, str]:
150+
model_name, model_tag, model_organization = super().extract_model_identifiers(model)
151151

152152
# use the ollama default namespace if no model organization has been identified
153153
if not model_organization:

ramalama/transports/url.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def __init__(self, model, model_store_path, scheme):
5151
split = self.model.rsplit("/", 1)
5252
self.directory = split[0].removeprefix("/") if len(split) > 1 else ""
5353

54-
def extract_model_identifiers(self):
55-
model_name, model_tag, model_organization = super().extract_model_identifiers()
54+
def extract_model_identifiers(self, model=None):
55+
model_name, model_tag, model_organization = super().extract_model_identifiers(model)
5656

5757
parts = model_organization.split("/")
5858
if len(parts) > 2 and parts[-2] == "blob":

0 commit comments

Comments
 (0)