Skip to content

Commit 77d2f08

Browse files
authored
fix model-id error caused by previous refactor (#46)
1 parent dc77222 commit 77d2f08

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

mellea/backends/formatter.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
import sys
77
from collections.abc import Iterable, Mapping
8+
from dataclasses import fields
89
from typing import Any
910

1011
import jinja2
@@ -396,14 +397,13 @@ def _get_model_id(self) -> str:
396397
"model_id was neither a `str` nor `ModelIdentifier`"
397398
)
398399

399-
# Go through the ModelIdentifier's fields, find one that isn't `"None"` or `""`.
400-
ids = [model_id.hf_model_name, model_id.ollama_name]
401-
model_id = ""
402-
for val in ids:
403-
if val != "None" and val != "":
404-
model_id = val # type: ignore
405-
break
406-
return model_id
400+
# Go through the ModelIdentifier's fields, find one that can be matched against.
401+
for field in fields(model_id):
402+
val = getattr(model_id, field.name)
403+
if val is not None and val != "":
404+
return val
405+
406+
return "" # Cannot match against any model identifiers. Will ultimately use default.
407407

408408

409409
def _simplify_model_string(input: str) -> str:

test/test_formatter_baseclasses.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88

99
from mellea.backends.formatter import TemplateFormatter
10-
from mellea.backends.model_ids import IBM_GRANITE_3_2_8B
10+
from mellea.backends.model_ids import ModelIdentifier, IBM_GRANITE_3_2_8B
1111
from mellea.stdlib.base import (
1212
BasicContext,
1313
CBlock,
@@ -284,6 +284,22 @@ def test_fake_model_id(instr: Instruction):
284284
"default" in tmpl.name
285285
), "there should always be a default instruction template"
286286

287+
def test_custom_model_id():
288+
model_id = ModelIdentifier(mlx_name="new-model-here")
289+
tf = TemplateFormatter(model_id=model_id)
290+
assert tf._get_model_id() == "new-model-here", "getting the model id should always give a string if one exists"
291+
292+
def test_empty_model_id(instr: Instruction):
293+
model_id = ModelIdentifier()
294+
tf = TemplateFormatter(model_id=model_id)
295+
assert tf._get_model_id() == ""
296+
297+
tmpl = tf._load_template(instr.format_for_llm())
298+
assert tmpl.name is not None
299+
assert (
300+
"default" in tmpl.name
301+
), "there should always be a default instruction template"
302+
287303

288304
def test_template_caching(instr: Instruction):
289305
"""Caching shouldn't be interacted with this way by users.

0 commit comments

Comments
 (0)