File tree Expand file tree Collapse file tree 2 files changed +25
-9
lines changed Expand file tree Collapse file tree 2 files changed +25
-9
lines changed Original file line number Diff line number Diff line change 55import re
66import sys
77from collections .abc import Iterable , Mapping
8+ from dataclasses import fields
89from typing import Any
910
1011import jinja2
@@ -396,14 +397,13 @@ def _get_model_id(self) -> str:
396397 "model_id was neither a `str` nor `ModelIdentifier`"
397398 )
398399
399- # Go through the ModelIdentifier's fields, find one that isn't `"None"` or `""`.
400- ids = [model_id .hf_model_name , model_id .ollama_name ]
401- model_id = ""
402- for val in ids :
403- if val != "None" and val != "" :
404- model_id = val # type: ignore
405- break
406- return model_id
400+ # Go through the ModelIdentifier's fields, find one that can be matched against.
401+ for field in fields (model_id ):
402+ val = getattr (model_id , field .name )
403+ if val is not None and val != "" :
404+ return val
405+
406+ return "" # Cannot match against any model identifiers. Will ultimately use default.
407407
408408
409409def _simplify_model_string (input : str ) -> str :
Original file line number Diff line number Diff line change 77import pytest
88
99from mellea .backends .formatter import TemplateFormatter
10- from mellea .backends .model_ids import IBM_GRANITE_3_2_8B
10+ from mellea .backends .model_ids import ModelIdentifier , IBM_GRANITE_3_2_8B
1111from mellea .stdlib .base import (
1212 BasicContext ,
1313 CBlock ,
@@ -284,6 +284,22 @@ def test_fake_model_id(instr: Instruction):
284284 "default" in tmpl .name
285285 ), "there should always be a default instruction template"
286286
287+ def test_custom_model_id ():
288+ model_id = ModelIdentifier (mlx_name = "new-model-here" )
289+ tf = TemplateFormatter (model_id = model_id )
290+ assert tf ._get_model_id () == "new-model-here" , "getting the model id should always give a string if one exists"
291+
292+ def test_empty_model_id (instr : Instruction ):
293+ model_id = ModelIdentifier ()
294+ tf = TemplateFormatter (model_id = model_id )
295+ assert tf ._get_model_id () == ""
296+
297+ tmpl = tf ._load_template (instr .format_for_llm ())
298+ assert tmpl .name is not None
299+ assert (
300+ "default" in tmpl .name
301+ ), "there should always be a default instruction template"
302+
287303
288304def test_template_caching (instr : Instruction ):
289305 """Caching shouldn't be interacted with this way by users.
You can’t perform that action at this time.
0 commit comments