Skip to content

Commit 4cd82d6

Browse files
author
ochafik
committed
tool-call: fix pyright type errors
1 parent 059babd commit 4cd82d6

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

examples/server/tests/features/steps/steps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,8 +1146,8 @@ async def oai_chat_completions(user_prompt,
11461146
max_tokens=n_predict,
11471147
stream=enable_streaming,
11481148
response_format=payload.get('response_format') or openai.NOT_GIVEN,
1149-
tools=payload.get('tools'),
1150-
tool_choice=payload.get('tool_choice'),
1149+
tools=payload.get('tools') or openai.NOT_GIVEN,
1150+
tool_choice=payload.get('tool_choice') or openai.NOT_GIVEN,
11511151
seed=seed,
11521152
temperature=payload['temperature']
11531153
)

tests/update_jinja_goldens.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
1616
'''
1717

18+
import logging
1819
import datetime
1920
import glob
2021
import os
@@ -25,6 +26,8 @@
2526
import re
2627
# import requests
2728

29+
logger = logging.getLogger(__name__)
30+
2831
model_ids = [
2932
"NousResearch/Hermes-3-Llama-3.1-70B",
3033
"NousResearch/Hermes-2-Pro-Llama-3-8B",
@@ -76,19 +79,19 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
7679

7780

7881
def strftime_now(format):
79-
return datetime.now().strftime(format)
82+
return datetime.datetime.now().strftime(format)
8083

8184

8285
def handle_chat_template(model_id, variant, template_src):
83-
print(f"# {model_id} @ {variant}", flush=True)
86+
logger.info(f"# {model_id} @ {variant}")
8487
model_name = model_id.replace("/", "-")
8588
base_name = f'{model_name}-{variant}' if variant else model_name
8689
template_file = f'tests/chat/templates/{base_name}.jinja'
87-
print(f'template_file: {template_file}')
90+
logger.info(f'template_file: {template_file}')
8891
with open(template_file, 'w') as f:
8992
f.write(template_src)
9093

91-
print(f"- {template_file}", flush=True)
94+
logger.info(f"- {template_file}")
9295

9396
env = jinja2.Environment(
9497
trim_blocks=True,
@@ -119,7 +122,7 @@ def handle_chat_template(model_id, variant, template_src):
119122
continue
120123

121124
output_file = f'tests/chat/goldens/{base_name}-{context_name}.txt'
122-
print(f"- {output_file}", flush=True)
125+
logger.info(f"- {output_file}")
123126
try:
124127
output = template.render(**context)
125128
except Exception as e1:
@@ -131,14 +134,12 @@ def handle_chat_template(model_id, variant, template_src):
131134
try:
132135
output = template.render(**context)
133136
except Exception as e2:
134-
print(f" ERROR: {e2} (after first error: {e1})", flush=True)
137+
logger.info(f" ERROR: {e2} (after first error: {e1})")
135138
output = f"ERROR: {e2}"
136139

137140
with open(output_file, 'w') as f:
138141
f.write(output)
139142

140-
print()
141-
142143

143144
def main():
144145
for dir in ['tests/chat/templates', 'tests/chat/goldens']:

0 commit comments

Comments
 (0)