Skip to content

Commit e434995

Browse files
format: ruff format examples directory (#559)
* fix format in examples * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e0d5254 commit e434995

File tree

11 files changed

+77
-87
lines changed

11 files changed

+77
-87
lines changed

examples/config_converter.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
import argparse
33
import os
44
import shlex
5+
from collections.abc import Callable
56
from dataclasses import dataclass, field
6-
from typing import Any, Callable, Optional, Type
7+
from typing import Any
78

89
import yaml
910
from deepmerge import always_merger
@@ -17,11 +18,10 @@
1718

1819
@dataclass
1920
class ArgSpec:
20-
2121
arg: str = ""
22-
arg_type: Type[Any] = field(default=str)
22+
arg_type: type[Any] = field(default=str)
2323
description: str = ""
24-
map_fn: Optional[Callable] = field(default=None)
24+
map_fn: Callable | None = field(default=None)
2525

2626

2727
class Converter(abc.ABC):
@@ -43,7 +43,7 @@ def get_lite_template(self, template_path: str):
4343
"""
4444
Load the areal template from the specified file.
4545
"""
46-
with open(template_path, "r", encoding="utf-8") as f:
46+
with open(template_path, encoding="utf-8") as f:
4747
return yaml.safe_load(f)
4848

4949
def flatten_dict(self, d, parent_key="", sep="."):
@@ -74,7 +74,7 @@ def convert_to_nested_args(self, args, ARG_MAP: dict) -> dict:
7474
argspec = ARG_MAP[k]
7575
if not argspec.arg:
7676
print(
77-
colored(f"## Warning: For ", "yellow")
77+
colored("## Warning: For ", "yellow")
7878
+ colored(f"{k:>40}", "yellow", attrs=["bold"])
7979
+ colored(f", # {argspec.description}!", "yellow")
8080
)
@@ -87,23 +87,23 @@ def convert_to_nested_args(self, args, ARG_MAP: dict) -> dict:
8787
if v is not None:
8888
try:
8989
# type conversion
90-
if arg_type == bool:
90+
if arg_type is bool:
9191
v = (
9292
bool(v)
9393
if isinstance(v, bool)
9494
else v.lower() in ("1", "true", "yes", "on")
9595
)
96-
elif arg_type == int:
96+
elif arg_type is int:
9797
v = int(v)
98-
elif arg_type == float:
98+
elif arg_type is float:
9999
v = float(v)
100-
elif arg_type == str:
100+
elif arg_type is str:
101101
v = str(v)
102102
else:
103103
raise ValueError(f"Unsupported type: {arg_type}")
104104
except Exception as e:
105105
print(
106-
colored(f"## Error: For ", "red")
106+
colored("## Error: For ", "red")
107107
+ colored(f"{k:>40} {v}", "red", attrs=["bold"])
108108
+ colored(f", # {e}!", "red")
109109
)
@@ -117,7 +117,7 @@ def convert_to_nested_args(self, args, ARG_MAP: dict) -> dict:
117117
else:
118118
unmapped[k] = v
119119
print(
120-
colored(f"## Warning: For ", "yellow")
120+
colored("## Warning: For ", "yellow")
121121
+ colored(f"{k:>50}", "yellow", attrs=["bold"])
122122
+ colored(f", # {CVRT_WARNING}!", "yellow")
123123
)
@@ -135,7 +135,6 @@ def set_nested(self, d: dict, keys, value):
135135

136136

137137
class OpenRLHFConverter(Converter):
138-
139138
ARG_MAP = {
140139
# Ray and vLLM
141140
"ref_num_nodes": ArgSpec("", int, CVRT_WARNING),
@@ -361,7 +360,7 @@ def _parse_args_from_script(
361360
in_command_block = False
362361

363362
try:
364-
with open(script_path, "r", encoding="utf-8") as f:
363+
with open(script_path, encoding="utf-8") as f:
365364
for line in f:
366365
stripped_line = line.strip()
367366

@@ -425,7 +424,6 @@ def _parse_args_from_script(
425424

426425

427426
def post_process_args(args: dict):
428-
429427
if "allocation_mode" in args:
430428
# convert allocation_mode to sglang.dX.tY.pZ
431429
dp = args["cluster"]["n_nodes"] * args["cluster"]["n_gpus_per_node"]
@@ -442,7 +440,7 @@ def post_process_args(args: dict):
442440
allocation_mode += "t1"
443441
else:
444442
allocation_mode += f"t{args['allocation_mode']['sglang']['t']}"
445-
allocation_mode += f"p1"
443+
allocation_mode += "p1"
446444
allocation_mode += "+"
447445
if "engine" not in args["allocation_mode"]:
448446
allocation_mode += f"d{dp}t1p1"
@@ -452,7 +450,7 @@ def post_process_args(args: dict):
452450
allocation_mode += "t1"
453451
else:
454452
allocation_mode += f".t{args['allocation_mode']['engine']['t']}"
455-
allocation_mode += f"p1"
453+
allocation_mode += "p1"
456454

457455
args["allocation_mode"] = allocation_mode
458456
args["cluster"]["n_nodes"] = args["cluster"]["n_nodes"] * 2
@@ -830,7 +828,7 @@ def __init__(self, src_config_path: str, template_path: str):
830828
self.template_path = template_path
831829

832830
def parse(self) -> dict:
833-
with open(self.src_config_path, "r", encoding="utf-8") as f:
831+
with open(self.src_config_path, encoding="utf-8") as f:
834832
cfg = yaml.safe_load(f)
835833
return cfg
836834

@@ -895,7 +893,7 @@ def main():
895893
**converter_args[args.convert_src]
896894
)
897895
lite_args = converter.convert()
898-
yaml_str = yaml.dump(lite_args, sort_keys=False, allow_unicode=True)
896+
# yaml_str = yaml.dump(lite_args, sort_keys=False, allow_unicode=True)
899897
with open(args.output_path, "w", encoding="utf-8") as f:
900898
yaml.dump(lite_args, f, sort_keys=False, allow_unicode=True)
901899
print(f"Converted areal config saved to {args.output_path}")

examples/countdown/countdown.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import random
55
import re
66
import sys
7-
from typing import Dict
87

98
from tqdm import tqdm
109
from transformers import AutoTokenizer
@@ -279,19 +278,19 @@ def combine_nums(a, b):
279278
# Implicitly makes assumptions about the order of operations and valid operations
280279
a = int(a)
281280
b = int(b)
282-
possible = [[a + b, f"{a}+{b}={a+b}"], [a * b, f"{a}*{b}={a*b}"]]
281+
possible = [[a + b, f"{a}+{b}={a + b}"], [a * b, f"{a}*{b}={a * b}"]]
283282
if a <= b:
284-
possible.append([b - a, f"{b}-{a}={b-a}"])
283+
possible.append([b - a, f"{b}-{a}={b - a}"])
285284
if a != 0 and b % a == 0:
286-
possible.append([b // a, f"{b}/{a}={round(b//a,0)}"])
285+
possible.append([b // a, f"{b}/{a}={round(b // a, 0)}"])
287286
else:
288-
possible.append([a - b, f"{a}-{b}={a-b}"])
287+
possible.append([a - b, f"{a}-{b}={a - b}"])
289288
if b != 0 and a % b == 0:
290-
possible.append([a // b, f"{a}/{b}={round(a//b,0)}"])
289+
possible.append([a // b, f"{a}/{b}={round(a // b, 0)}"])
291290
return possible
292291

293292

294-
class CountDown(object):
293+
class CountDown:
295294
def __init__(
296295
self,
297296
max_target=25,
@@ -330,7 +329,7 @@ def generate(self, target):
330329
found = True
331330
return nums, solution
332331

333-
def get_task(self, apply_chat_template=False, return_raw=False) -> Dict[str, str]:
332+
def get_task(self, apply_chat_template=False, return_raw=False) -> dict[str, str]:
334333
target = random.randint(self.min_target, self.max_target)
335334
nums, solution = self.generate(target)
336335

examples/countdown/reward_score.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def validate_equation(equation_str, available_numbers):
3434

3535
# Each number should be used exactly once
3636
return numbers_in_eq == available_numbers
37-
except:
37+
except Exception:
3838
return False
3939

4040

@@ -72,28 +72,28 @@ def compute_score(
7272
do_print = random.randint(1, 64) == 1
7373

7474
if do_print:
75-
print(f"--------------------------------")
75+
print("--------------------------------")
7676
print(f"Target: {target} | Numbers: {numbers}")
7777
print(f"Extracted equation: {equation}")
7878
print(f"Solution string: {solution_str}")
7979

8080
if equation is None:
8181
if do_print:
82-
print(f"No equation found")
82+
print("No equation found")
8383
return 0
8484

8585
# Validate equation uses correct numbers
8686
if not validate_equation(equation, numbers):
8787
if do_print:
88-
print(f"Invalid equation")
88+
print("Invalid equation")
8989
return format_score
9090

9191
# Evaluate equation
9292
try:
9393
result = evaluate_equation(equation)
9494
if result is None:
9595
if do_print:
96-
print(f"Could not evaluate equation")
96+
print("Could not evaluate equation")
9797
return format_score
9898

9999
if abs(result - target) < 1e-5: # Account for floating point precision
@@ -104,7 +104,7 @@ def compute_score(
104104
if do_print:
105105
print(f"Wrong result: equation = {result}, target = {target}")
106106
return format_score
107-
except:
107+
except Exception:
108108
if do_print:
109-
print(f"Error evaluating equation")
109+
print("Error evaluating equation")
110110
return format_score

examples/docs/debug/cmp_rollout.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import base64
22
from io import BytesIO
3-
from typing import List
43

54
import datasets
65
import requests
@@ -9,7 +8,7 @@
98
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
109

1110

12-
def image2base64(images: List[ImageObject] | ImageObject) -> List[str] | str:
11+
def image2base64(images: list[ImageObject] | ImageObject) -> list[str] | str:
1312
if isinstance(images, ImageObject):
1413
images = [images]
1514

examples/search-agent/tongyi_deepresearch/react_agent.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
1+
import datetime
12
import json
23
import os
34
import sys
45
import time
5-
from datetime import datetime
66
from pathlib import Path
7-
from typing import Dict, List
87

98
import json5
109
from qwen_agent.agents.fncall_agent import FnCallAgent
1110
from qwen_agent.llm.schema import Message
12-
from qwen_agent.settings import MAX_LLM_CALL_PER_RUN
1311
from transformers import PreTrainedTokenizer
1412

1513
from areal.experimental.openai import ArealOpenAI
1614
from areal.utils import logging
1715

1816
try:
19-
from .prompt import *
20-
from .tool_search import *
21-
from .tool_visit import *
17+
from .prompt import SYSTEM_PROMPT
18+
from .tool_search import Search
19+
from .tool_visit import Visit
2220
except ImportError: # Fallback when executed directly (no package parent known)
2321
module_dir = Path(__file__).parent
2422
if str(module_dir) not in sys.path:
2523
sys.path.insert(0, str(module_dir))
26-
from prompt import *
27-
from tool_search import *
28-
from tool_visit import *
24+
from prompt import SYSTEM_PROMPT
25+
from tool_search import Search
26+
from tool_visit import Visit
2927

3028

3129
logger = logging.getLogger("Tongyi-DeepResearch react agent")
@@ -35,8 +33,6 @@
3533

3634
MAX_LLM_CALL_PER_RUN = int(os.getenv("MAX_LLM_CALL_PER_RUN", 100))
3735

38-
import datetime
39-
4036

4137
def today_date():
4238
return datetime.date.today().strftime("%Y-%m-%d")
@@ -52,7 +48,7 @@ def parse_judge_result(raw_response):
5248
try:
5349
mbe = parse_fn(raw_response.split("```json")[-1].split("```")[0].strip())
5450
break
55-
except:
51+
except Exception:
5652
logger.warning(f"Error parsing judge result with {parse_fn}.")
5753
if mbe is None and '"judgement": "incorrect"' in raw_response:
5854
mbe = dict(judgement="incorrect")
@@ -94,7 +90,7 @@ def count_tokens(self, messages):
9490
return len(prompt_token_ids)
9591

9692
async def call_server(
97-
self, client: ArealOpenAI, messages: List[Dict], max_attempts: int = 100
93+
self, client: ArealOpenAI, messages: list[dict], max_attempts: int = 100
9894
) -> str:
9995
attempts = 0
10096
while attempts < max_attempts:
@@ -119,7 +115,7 @@ async def call_server(
119115

120116
async def run_agent(
121117
self, data, client: ArealOpenAI, save_path: str | None = None
122-
) -> List[List[Message]]:
118+
) -> list[list[Message]]:
123119
start_time = time.time()
124120
data["qid"]
125121
question = data["question"]
@@ -275,7 +271,7 @@ async def custom_call_tool(self, tool_name: str, tool_args: dict, **kwargs):
275271

276272
async def calc_reward_with_llm_judge(
277273
self,
278-
result: Dict[str, str],
274+
result: dict[str, str],
279275
):
280276
# Compute reward with LLM-as-Judge
281277
# judge_client = ArealOpenAI(engine=rollout_engine, tokenizer=tokenizer)
@@ -320,10 +316,10 @@ async def calc_reward_with_llm_judge(
320316

321317
async def make_trajectory(
322318
self,
323-
data: Dict[str, str],
319+
data: dict[str, str],
324320
client: ArealOpenAI,
325321
save_path: str | None = None,
326-
) -> Dict:
322+
) -> dict:
327323
result = await self.run_agent(data, client, save_path=save_path)
328324
reward = await self.calc_reward_with_llm_judge(result)
329325
completions = result["completions"]

examples/search-agent/tongyi_deepresearch/tool_search.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import json
33
import os
4-
from typing import List, Optional, Union
54

65
import aiohttp
76
from qwen_agent.tools.base import BaseTool, register_tool
@@ -25,7 +24,7 @@ class Search(BaseTool):
2524
"required": ["query"],
2625
}
2726

28-
def __init__(self, cfg: Optional[dict] = None):
27+
def __init__(self, cfg: dict | None = None):
2928
super().__init__(cfg)
3029

3130
async def google_search_with_serp(self, query: str):
@@ -72,7 +71,7 @@ def contains_chinese_basic(text: str) -> bool:
7271
snippet = (
7372
f"\n{page['snippet']}" if page.get("snippet") else ""
7473
)
75-
redacted_version = f"{idx}. [{page.get('title','')}]({page.get('link','')}){date_published}{source}\n{snippet}"
74+
redacted_version = f"{idx}. [{page.get('title', '')}]({page.get('link', '')}){date_published}{source}\n{snippet}"
7675
redacted_version = redacted_version.replace(
7776
"Your browser can't play this video.", ""
7877
)
@@ -93,7 +92,7 @@ def contains_chinese_basic(text: str) -> bool:
9392
async def search_with_serp(self, query: str):
9493
return await self.google_search_with_serp(query)
9594

96-
async def call(self, params: Union[str, dict], **kwargs) -> str: # type: ignore[override]
95+
async def call(self, params: str | dict, **kwargs) -> str: # type: ignore[override]
9796
try:
9897
query = params["query"]
9998
except Exception:
@@ -102,7 +101,7 @@ async def call(self, params: Union[str, dict], **kwargs) -> str: # type: ignore
102101
if isinstance(query, str):
103102
return await self.search_with_serp(query)
104103

105-
assert isinstance(query, List)
104+
assert isinstance(query, list)
106105
tasks = [self.search_with_serp(q) for q in query]
107106
responses = await asyncio.gather(*tasks)
108107
return "\n=======\n".join(responses)

0 commit comments

Comments
 (0)