1010R = TypeVar ("R" )
1111
1212
13+ # async def run_concurrent(
14+ # coro_fn: Callable[[T], Awaitable[R]],
15+ # items: List[T],
16+ # *,
17+ # desc: str = "processing",
18+ # unit: str = "item",
19+ # progress_bar: Optional[gr.Progress] = None,
20+ # ) -> List[R]:
21+ # tasks = [asyncio.create_task(coro_fn(it)) for it in items]
22+ #
23+ # results = []
24+ # async for future in tqdm_async(
25+ # tasks, desc=desc, unit=unit
26+ # ):
27+ # try:
28+ # result = await future
29+ # results.append(result)
30+ # except Exception as e: # pylint: disable=broad-except
31+ # logger.exception("Task failed: %s", e)
32+ #
33+ # if progress_bar is not None:
34+ # progress_bar((len(results)) / len(items), desc=desc)
35+ #
36+ # if progress_bar is not None:
37+ # progress_bar(1.0, desc=desc)
38+ # return results
39+
40+ # results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
41+ #
42+ # ok_results = []
43+ # for idx, res in enumerate(results):
44+ # if isinstance(res, Exception):
45+ # logger.exception("Task failed: %s", res)
46+ # if progress_bar:
47+ # progress_bar((idx + 1) / len(items), desc=desc)
48+ # continue
49+ # ok_results.append(res)
50+ # if progress_bar:
51+ # progress_bar((idx + 1) / len(items), desc=desc)
52+ #
53+ # if progress_bar:
54+ # progress_bar(1.0, desc=desc)
55+ # return ok_results
56+
57+ # async def run_concurrent(
58+ # coro_fn: Callable[[T], Awaitable[R]],
59+ # items: List[T],
60+ # *,
61+ # desc: str = "processing",
62+ # unit: str = "item",
63+ # progress_bar: Optional[gr.Progress] = None,
64+ # ) -> List[R]:
65+ # tasks = [asyncio.create_task(coro_fn(it)) for it in items]
66+ #
67+ # results = []
68+ # # 使用同步方式更新进度条,避免异步冲突
69+ # for i, task in enumerate(asyncio.as_completed(tasks)):
70+ # try:
71+ # result = await task
72+ # results.append(result)
73+ # # 同步更新进度条
74+ # if progress_bar is not None:
75+ # # 在同步上下文中更新进度
76+ # progress_bar((i + 1) / len(items), desc=desc)
77+ # except Exception as e:
78+ # logger.exception("Task failed: %s", e)
79+ # results.append(e)
80+ #
81+ # return results
82+
83+
1384async def run_concurrent (
1485 coro_fn : Callable [[T ], Awaitable [R ]],
1586 items : List [T ],
@@ -20,19 +91,36 @@ async def run_concurrent(
2091) -> List [R ]:
2192 tasks = [asyncio .create_task (coro_fn (it )) for it in items ]
2293
23- results = await tqdm_async .gather (* tasks , desc = desc , unit = unit )
24-
25- ok_results = []
26- for idx , res in enumerate (results ):
27- if isinstance (res , Exception ):
28- logger .exception ("Task failed: %s" , res )
29- if progress_bar :
30- progress_bar ((idx + 1 ) / len (items ), desc = desc )
31- continue
32- ok_results .append (res )
33- if progress_bar :
34- progress_bar ((idx + 1 ) / len (items ), desc = desc )
35-
36- if progress_bar :
37- progress_bar (1.0 , desc = desc )
38- return ok_results
94+ completed_count = 0
95+ results = []
96+
97+ pbar = tqdm_async (total = len (items ), desc = desc , unit = unit )
98+
99+ if progress_bar is not None :
100+ progress_bar (0.0 , desc = f"{ desc } (0/{ len (items )} )" )
101+
102+ for future in asyncio .as_completed (tasks ):
103+ try :
104+ result = await future
105+ results .append (result )
106+ except Exception as e : # pylint: disable=broad-except
107+ logger .exception ("Task failed: %s" , e )
108+ # even if failed, record it to keep results consistent with tasks
109+ results .append (e )
110+
111+ completed_count += 1
112+ pbar .update (1 )
113+
114+ if progress_bar is not None :
115+ progress = completed_count / len (items )
116+ progress_bar (progress , desc = f"{ desc } ({ completed_count } /{ len (items )} )" )
117+
118+ pbar .close ()
119+
120+ if progress_bar is not None :
121+ progress_bar (1.0 , desc = f"{ desc } (completed)" )
122+
123+ # filter out exceptions
124+ results = [res for res in results if not isinstance (res , Exception )]
125+
126+ return results
0 commit comments