Skip to content

Commit 28834fc

Browse files
committed
test: reorgainize test files
1 parent d93dbc6 commit 28834fc

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@
5555
atexit.register(EXECUTOR.shutdown)
5656

5757

58-
def concurrent_run(func):
59-
futures = [EXECUTOR.submit(func) for _ in range(NUM_FUTURES)]
58+
def concurrent_run(func, /, *args, **kwargs):
59+
futures = [EXECUTOR.submit(func, *args, **kwargs) for _ in range(NUM_FUTURES)]
6060
future2index = {future: i for i, future in enumerate(futures)}
6161
completed_futures = sorted(as_completed(futures), key=future2index.get)
6262
first_exception = next(filter(None, (future.exception() for future in completed_futures)), None)
@@ -92,7 +92,7 @@ def test_fn():
9292
for result in concurrent_run(test_fn):
9393
assert result == expected
9494

95-
for result in concurrent_run(lambda: optree.tree_unflatten(treespec, leaves)):
95+
for result in concurrent_run(optree.tree_unflatten, treespec, leaves):
9696
assert result == tree
9797

9898

@@ -353,7 +353,7 @@ def test_tree_iter_thread_safe(
353353
namespace=namespace,
354354
)
355355

356-
results = concurrent_run(lambda: list(it))
356+
results = concurrent_run(list, it)
357357
for seq in results:
358358
assert sorted(seq) == seq
359359
assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves))

0 commit comments

Comments
 (0)