Skip to content

Commit 7b48b4d

Browse files
authored
bug: b64 encoded identifiers (#33)
- if a generator identifier string ends in , the split function (on ) splits into too large of an array - instead, just split on the first as the rest of the string is comma delimtited
1 parent f1b2b45 commit 7b48b4d

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

rigging/generator/base.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626

2727
@t.runtime_checkable
2828
class LazyGenerator(t.Protocol):
29-
def __call__(self) -> type[Generator]:
30-
...
29+
def __call__(self) -> type[Generator]: ...
3130

3231

3332
g_providers: dict[str, type[Generator] | LazyGenerator] = {}
@@ -382,16 +381,14 @@ def chat(
382381
self,
383382
messages: t.Sequence[MessageDict],
384383
params: GenerateParams | None = None,
385-
) -> ChatPipeline:
386-
...
384+
) -> ChatPipeline: ...
387385

388386
@t.overload
389387
def chat(
390388
self,
391389
messages: t.Sequence[Message] | MessageDict | Message | str | None = None,
392390
params: GenerateParams | None = None,
393-
) -> ChatPipeline:
394-
...
391+
) -> ChatPipeline: ...
395392

396393
def chat(
397394
self,
@@ -460,17 +457,15 @@ def chat(
460457
generator: Generator,
461458
messages: t.Sequence[MessageDict],
462459
params: GenerateParams | None = None,
463-
) -> ChatPipeline:
464-
...
460+
) -> ChatPipeline: ...
465461

466462

467463
@t.overload
468464
def chat(
469465
generator: Generator,
470466
messages: t.Sequence[Message] | MessageDict | Message | str | None = None,
471467
params: GenerateParams | None = None,
472-
) -> ChatPipeline:
473-
...
468+
) -> ChatPipeline: ...
474469

475470

476471
def chat(
@@ -597,7 +592,7 @@ def get_generator(identifier: str, *, params: GenerateParams | None = None) -> G
597592
if "," in model:
598593
try:
599594
model, kwargs_str = model.split(",", 1)
600-
kwargs = dict(arg.split("=") for arg in kwargs_str.split(","))
595+
kwargs = dict(arg.split("=", 1) for arg in kwargs_str.split(","))
601596
except Exception as e:
602597
raise InvalidModelSpecifiedError(identifier) from e
603598

tests/test_generator_ids.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,9 @@ def test_register_generator() -> None:
7575
register_generator("echo", EchoGenerator)
7676
generator = get_generator("echo!test")
7777
assert isinstance(generator, EchoGenerator)
78+
79+
80+
def test_get_generator_b64() -> None:
81+
generator = get_generator("litellm!test_model,api_key=ZXhhbXBsZXRleHQ=")
82+
assert isinstance(generator, LiteLLMGenerator)
83+
assert generator.model == "test_model"

0 commit comments

Comments
 (0)