
Commit 536c1a7

refactor: renamed 'format' parameter to '_format' in internal methods so that mypy detects accidental references to the builtin 'format'
1 parent 3fa2bbc commit 536c1a7
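
The rename works together with a module-level `format: int = 1` sentinel added near the top of each backend module. A minimal sketch of the idea, with an illustrative module and model class that are not part of this commit: once the internal parameters are called `_format`, any leftover bare reference to `format` resolves to the module-level int, and mypy flags it instead of silently picking up the builtin `format` function.

    from pydantic import BaseModel

    format: int = 1  # shadows the builtin `format`, so stray references are typed as int


    def _generate(_format: type[BaseModel] | None = None) -> None:
        # A call site that was not renamed now refers to the module-level int,
        # and mypy reports: '"int" has no attribute "model_json_schema"'.
        schema = format.model_json_schema()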

5 files changed (+66, -53 lines)


mellea/backends/huggingface.py

Lines changed: 14 additions & 12 deletions
@@ -67,6 +67,8 @@
 """
 TransformersTorchConfig = tuple[PreTrainedTokenizer, PreTrainedModel, torch.device]
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 @dataclasses.dataclass
 class HFAloraCacheInfo:
@@ -209,11 +211,11 @@ def generate_from_context(
                 reroute_to_alora = True
             if reroute_to_alora:
                 mot = self._generate_from_context_alora(
-                    action, ctx, format=format, model_options=model_opts
+                    action, ctx, _format=format, model_options=model_opts
                 )
                 return mot, ctx.add(mot)
         mot = self._generate_from_context_standard(
-            action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls
+            action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
         )
         return mot, ctx.add(action).add(mot)
 
@@ -222,7 +224,7 @@ def _generate_from_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
     ) -> ModelOutputThunk:
         match action:
@@ -245,7 +247,7 @@ def _generate_from_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."
 
         alora_output = alora_for_this_request.generate_using_strings(
             input=user_message,
@@ -269,7 +271,7 @@ def _generate_from_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
@@ -310,7 +312,7 @@ def _generate_from_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -338,10 +340,10 @@ def _generate_from_context_standard(
         ).to(self._device)  # type: ignore
 
         format_kwargs = {}
-        if format:
+        if _format:
             # outlines.generate.json always parses the resulting json into a python dict.
             # We however want to keep it as a json string for later storing it in ModelOutputThunk
-            schema: dict[str, Any] = format.model_json_schema()
+            schema: dict[str, Any] = _format.model_json_schema()
             schema_json: str = json.dumps(schema)
             regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
@@ -402,7 +404,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
-            format=format,
+            _format=_format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,
@@ -459,7 +461,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
-        format: type[BaseModelSubclass] | None,
+        _format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,
@@ -490,7 +492,7 @@ async def post_processing(
             self.cache_put(mot.value, cache_info)
 
         # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
-        if format is None and tool_calls:
+        if _format is None and tool_calls:
             mot.tool_calls = self._extract_model_tool_requests(tools, mot.value)
 
         assert mot._action is not None, (
@@ -510,7 +512,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot.value
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": seed,

mellea/backends/litellm.py

Lines changed: 13 additions & 11 deletions
@@ -40,6 +40,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class LiteLLMBackend(FormatterBackend):
     """A generic LiteLLM compatible backend."""
@@ -121,7 +123,7 @@ def generate_from_context(
         mot = self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -213,7 +215,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -247,12 +249,12 @@ def _generate_from_chat_context_standard(
             [OpenAIBackend.message_to_openai_message(m) for m in messages]
         )
 
-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
@@ -265,7 +267,7 @@ def _generate_from_chat_context_standard(
             thinking = "medium"
 
         # Append tool call information if applicable.
-        tools = self._extract_tools(action, format, model_opts, tool_calls, ctx)
+        tools = self._extract_tools(action, _format, model_opts, tool_calls, ctx)
         formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None
 
         model_specific_options = self._make_backend_specific_and_remove(model_opts)
@@ -295,7 +297,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
-            format=format,
+            _format=_format,
         )
 
         try:
@@ -372,7 +374,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
@@ -417,7 +419,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["litellm_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": thinking,
@@ -428,11 +430,11 @@ async def post_processing(
 
     @staticmethod
     def _extract_tools(
-        action, format, model_opts, tool_calls, ctx
+        action, _format, model_opts, tool_calls, ctx
    ) -> dict[str, Callable]:
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
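
Both this backend and the OpenAI backend forward the renamed `_format` to the provider as an OpenAI-style `json_schema` response format. A short sketch of the payload built above, assuming a hypothetical `Answer` pydantic model:

    from pydantic import BaseModel


    class Answer(BaseModel):  # hypothetical structured-output model
        city: str
        population: int


    # Mirrors the response_format construction in _generate_from_chat_context_standard.
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": Answer.__name__,
            "schema": Answer.model_json_schema(),
            "strict": True,
        },
    }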

mellea/backends/ollama.py

Lines changed: 12 additions & 7 deletions
@@ -32,6 +32,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class OllamaModelBackend(FormatterBackend):
     """A model that uses the Ollama Python SDK for local inference."""
@@ -245,7 +247,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -257,7 +259,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
@@ -305,7 +307,7 @@ def generate_from_chat_context(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -331,7 +333,7 @@ def generate_from_chat_context(
             think=model_opts.get(ModelOption.THINKING, None),
             stream=model_opts.get(ModelOption.STREAM, False),
             options=self._make_backend_specific_and_remove(model_opts),
-            format=format.model_json_schema() if format is not None else None,
+            format=_format.model_json_schema() if _format is not None else None,
         )  # type: ignore
 
         output = ModelOutputThunk(None)
@@ -343,7 +345,10 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools, format=format
+            self.post_processing,
+            conversation=conversation,
+            tools=tools,
+            _format=_format,
         )
 
         try:
@@ -506,7 +511,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
-        format,
+        _format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (
@@ -525,7 +530,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": mot._model_options.get(ModelOption.THINKING, None),
             "tools_available": tools,
             "tools_called": mot.tool_calls,

mellea/backends/openai.py

Lines changed: 16 additions & 14 deletions
@@ -51,6 +51,8 @@
 
 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"
 
+format: int = 1  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class _ServerType(Enum):
     LOCALHOST = 1
@@ -279,7 +281,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -290,7 +292,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -308,13 +310,13 @@ def generate_from_chat_context(
                 reroute_to_alora = True
         if reroute_to_alora:
             return self._generate_from_chat_context_alora(
-                action, ctx, format=format, model_options=model_options
+                action, ctx, _format=_format, model_options=model_options
             )
 
         return self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=_format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -324,7 +326,7 @@ def _generate_from_chat_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
     ) -> ModelOutputThunk:
@@ -349,7 +351,7 @@ def _generate_from_chat_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."
 
         model_opts = self._simplify_and_merge(model_options, is_chat_context=True)
 
@@ -409,7 +411,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -438,12 +440,12 @@ def _generate_from_chat_context_standard(
             conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([self.message_to_openai_message(m) for m in messages])
 
-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
@@ -453,7 +455,7 @@ def _generate_from_chat_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -502,7 +504,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
-            format=format,
+            _format=_format,
         )
 
         try:
@@ -570,7 +572,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
@@ -608,7 +610,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["oai_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": thinking,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
