
Commit 8af76a9

refactor: renamed the 'format' parameter to '_format' in internal methods so that mypy can catch accidental uses of the shadowed builtin name
1 parent 689e1a9 commit 8af76a9
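The idea behind the change, as a minimal sketch (illustrative only, not code from this repository): each backend module defines a module-level "format: int = 1" that shadows the builtin format(). The internal helpers now receive the value under the new name "_format", so any leftover reference to the bare name "format" resolves to the module-level int, and mypy reports a type error instead of silently binding to the builtin function.

    # Minimal sketch of the shadowing trick; the names below are illustrative.
    from pydantic import BaseModel

    format: int = 1  # shadows builtins.format for mypy within this module


    def _generate(*, _format: type[BaseModel] | None = None) -> None:
        # Correct: the renamed parameter is used.
        if _format is not None:
            print(_format.model_json_schema())
        # A stale reference to the old name would now resolve to the module-level
        # int, so mypy flags it (e.g. "int" has no attribute "model_json_schema")
        # rather than quietly matching the builtin format() function.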

5 files changed (+66, -53 lines)

mellea/backends/huggingface.py

Lines changed: 14 additions & 12 deletions
@@ -67,6 +67,8 @@
 """
 TransformersTorchConfig = tuple[PreTrainedTokenizer, PreTrainedModel, torch.device]
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 @dataclasses.dataclass
 class HFAloraCacheInfo:
@@ -209,11 +211,11 @@ def generate_from_context(
                 reroute_to_alora = True
         if reroute_to_alora:
             mot = self._generate_from_context_alora(
-                action, ctx, format=format, model_options=model_opts
+                action, ctx, _format=format, model_options=model_opts
             )
             return mot, ctx.add(mot)
         mot = self._generate_from_context_standard(
-            action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls
+            action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
         )
         return mot, ctx.add(action).add(mot)
 
@@ -222,7 +224,7 @@ def _generate_from_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
     ) -> ModelOutputThunk:
         match action:
@@ -245,7 +247,7 @@ def _generate_from_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."
 
         alora_output = alora_for_this_request.generate_using_strings(
             input=user_message,
@@ -269,7 +271,7 @@ def _generate_from_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
@@ -310,7 +312,7 @@ def _generate_from_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -338,10 +340,10 @@ def _generate_from_context_standard(
         ).to(self._device)  # type: ignore
 
         format_kwargs = {}
-        if format:
+        if _format:
             # outlines.generate.json always parses the resulting json into a python dict.
             # We however want to keep it as a json string for later storing it in ModelOutputThunk
-            schema: dict[str, Any] = format.model_json_schema()
+            schema: dict[str, Any] = _format.model_json_schema()
             schema_json: str = json.dumps(schema)
             regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
@@ -406,7 +408,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
-            format=format,
+            _format=_format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,
@@ -463,7 +465,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
-        format: type[BaseModelSubclass] | None,
+        _format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,
@@ -494,7 +496,7 @@ async def post_processing(
             self.cache_put(mot.value, cache_info)
 
         # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
-        if format is None and tool_calls:
+        if _format is None and tool_calls:
             mot.tool_calls = self._extract_model_tool_requests(tools, mot.value)
 
         assert mot._action is not None, (
@@ -514,7 +516,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot.value
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": seed,

mellea/backends/litellm.py

Lines changed: 13 additions & 11 deletions
@@ -40,6 +40,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class LiteLLMBackend(FormatterBackend):
     """A generic LiteLLM compatible backend."""
@@ -121,7 +123,7 @@ def generate_from_context(
         mot = self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -213,7 +215,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -247,12 +249,12 @@ def _generate_from_chat_context_standard(
             [OpenAIBackend.message_to_openai_message(m) for m in messages]
         )
 
-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
@@ -265,7 +267,7 @@ def _generate_from_chat_context_standard(
             thinking = "medium"
 
         # Append tool call information if applicable.
-        tools = self._extract_tools(action, format, model_opts, tool_calls, ctx)
+        tools = self._extract_tools(action, _format, model_opts, tool_calls, ctx)
         formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None
 
         model_specific_options = self._make_backend_specific_and_remove(model_opts)
@@ -295,7 +297,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
-            format=format,
+            _format=_format,
         )
 
         try:
@@ -373,7 +375,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
@@ -418,7 +420,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["litellm_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": thinking,
@@ -429,11 +431,11 @@
 
     @staticmethod
     def _extract_tools(
-        action, format, model_opts, tool_calls, ctx
+        action, _format, model_opts, tool_calls, ctx
     ) -> dict[str, Callable]:
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
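The LiteLLM backend (and the OpenAI backend further down) turns _format into an OpenAI-style structured-output request. Roughly, as in the hunks above, with an illustrative stand-in for _format:

    from pydantic import BaseModel

    class Answer(BaseModel):  # illustrative stand-in for _format
        text: str

    _format = Answer
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": _format.__name__,
            "schema": _format.model_json_schema(),
            "strict": True,
        },
    }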

mellea/backends/ollama.py

Lines changed: 12 additions & 7 deletions
@@ -32,6 +32,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class OllamaModelBackend(FormatterBackend):
     """A model that uses the Ollama Python SDK for local inference."""
@@ -245,7 +247,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -257,7 +259,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
@@ -305,7 +307,7 @@ def generate_from_chat_context(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -331,7 +333,7 @@ def generate_from_chat_context(
             think=model_opts.get(ModelOption.THINKING, None),
             stream=model_opts.get(ModelOption.STREAM, False),
             options=self._make_backend_specific_and_remove(model_opts),
-            format=format.model_json_schema() if format is not None else None,
+            format=_format.model_json_schema() if _format is not None else None,
         )  # type: ignore
 
         output = ModelOutputThunk(None)
@@ -343,7 +345,10 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools, format=format
+            self.post_processing,
+            conversation=conversation,
+            tools=tools,
+            _format=_format,
         )
 
         try:
@@ -506,7 +511,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
-        format,
+        _format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (
@@ -525,7 +530,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
        generate_log.model_output = mot._meta["chat_response"]
        generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": mot._model_options.get(ModelOption.THINKING, None),
             "tools_available": tools,
             "tools_called": mot.tool_calls,

mellea/backends/openai.py

Lines changed: 16 additions & 14 deletions
@@ -51,6 +51,8 @@
 
 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class _ServerType(Enum):
     LOCALHOST = 1
@@ -282,7 +284,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -293,7 +295,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -311,13 +313,13 @@ def generate_from_chat_context(
                 reroute_to_alora = True
         if reroute_to_alora:
             return self._generate_from_chat_context_alora(
-                action, ctx, format=format, model_options=model_options
+                action, ctx, _format=_format, model_options=model_options
             )
 
         return self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=_format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -327,7 +329,7 @@ def _generate_from_chat_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
     ) -> ModelOutputThunk:
@@ -352,7 +354,7 @@ def _generate_from_chat_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."
 
         model_opts = self._simplify_and_merge(model_options, is_chat_context=True)
 
@@ -413,7 +415,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -442,12 +444,12 @@ def _generate_from_chat_context_standard(
             conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([self.message_to_openai_message(m) for m in messages])
 
-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
@@ -457,7 +459,7 @@ def _generate_from_chat_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -506,7 +508,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
-            format=format,
+            _format=_format,
         )
 
         try:
@@ -575,7 +577,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
@@ -613,7 +615,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["oai_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": thinking,
             "tools_available": tools,
            "tools_called": mot.tool_calls,
