|
55 | 55 | atexit.register(EXECUTOR.shutdown) |
56 | 56 |
|
57 | 57 |
|
58 | | -def concurrent_run(func): |
59 | | - futures = [EXECUTOR.submit(func) for _ in range(NUM_FUTURES)] |
| 58 | +def concurrent_run(func, /, *args, **kwargs): |
| 59 | + futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)] |
60 | 60 | future2index = {future: i for i, future in enumerate(futures)} |
61 | 61 | completed_futures = sorted(as_completed(futures), key=future2index.get) |
62 | 62 | first_exception = next(filter(None, (future.exception() for future in completed_futures)), None) |
@@ -92,7 +92,7 @@ def test_fn(): |
92 | 92 | for result in concurrent_run(test_fn): |
93 | 93 | assert result == expected |
94 | 94 |
|
95 | | - for result in concurrent_run(lambda: optree.tree_unflatten(treespec, leaves)): |
| 95 | + for result in concurrent_run(optree.tree_unflatten, treespec, leaves): |
96 | 96 | assert result == tree |
97 | 97 |
|
98 | 98 |
|
@@ -353,7 +353,7 @@ def test_tree_iter_thread_safe( |
353 | 353 | namespace=namespace, |
354 | 354 | ) |
355 | 355 |
|
356 | | - results = concurrent_run(lambda: list(it)) |
| 356 | + results = concurrent_run(list, it) |
357 | 357 | for seq in results: |
358 | 358 | assert sorted(seq) == seq |
359 | 359 | assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) |
0 commit comments