Skip to content

Commit 35a2673

Browse files
authored
[Example]update data generation of transformer4sr to avoid processes getting stuck (#1199)
1 parent d186246 commit 35a2673

File tree

1 file changed

+83
-22
lines changed

1 file changed

+83
-22
lines changed

examples/transformer4sr/generate_datasets.py

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
import concurrent.futures
2121
import gc
2222
import os
23+
import signal
2324
import warnings
2425
from functools import partial
26+
from functools import wraps
2527

2628
import hydra
2729
import numpy as np
@@ -42,6 +44,37 @@
4244
warnings.filterwarnings("ignore")
4345

4446

47+
def timeout(seconds):
    """Build a decorator that aborts the wrapped call after ``seconds``.

    The limit is enforced with ``SIGALRM``: a temporary handler raises
    ``TimeoutError`` inside the running call when the real-time timer
    expires. The previous handler is restored and the timer cancelled once
    the call finishes, whether it returned or raised.

    Args:
        seconds: Time limit in seconds. Fractional values are accepted;
            ``0`` disarms the timer, i.e. no limit is applied.

    Returns:
        A decorator that wraps a callable with the time limit.

    Raises:
        TimeoutError: Raised from inside the wrapped call when the limit
            is exceeded.

    Note:
        ``SIGALRM`` handlers can only be installed from the main thread and
        the signal does not exist on every platform (e.g. Windows). In those
        situations the wrapped function is executed WITHOUT a time limit
        instead of failing with ``ValueError``/``AttributeError`` on every
        call.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            def handler(signum, frame):
                raise TimeoutError(f"Timed out after {seconds}s")

            try:
                old = signal.signal(signal.SIGALRM, handler)
            except (AttributeError, ValueError):
                # AttributeError: platform has no SIGALRM.
                # ValueError: called outside the main thread.
                # Either way no alarm can be armed, so run unguarded.
                return func(*args, **kwargs)

            try:
                # Arm inside the try block so the previous handler is
                # restored even if arming itself raises.
                signal.setitimer(signal.ITIMER_REAL, seconds)
                return func(*args, **kwargs)
            finally:
                # Cancel any pending alarm before putting the old handler
                # back, so a late signal cannot hit the restored handler.
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, old)

        return wrapper

    return decorator
66+
67+
68+
@timeout(30)
def safe_factor(expr):
    """Return ``sympy.factor(expr)``, run under the 30-second ``timeout`` decorator."""
    factored = sympy.factor(expr)
    return factored
71+
72+
73+
@timeout(30)
def safe_simplify(expr):
    """Return ``sympy.simplify(expr)``, run under the 30-second ``timeout`` decorator."""
    simplified = sympy.simplify(expr)
    return simplified
76+
77+
4578
def fliter_nodes(expr, num_nodes):
4679
if num_nodes[0] <= len(expr) <= num_nodes[1]:
4780
return expr
@@ -52,14 +85,18 @@ def fliter_nodes(expr, num_nodes):
5285
def fliter_nested(expr, num_nested_max):
    """Canonicalize one expression and reject invalid or too deeply nested ones.

    The expression is converted with ``from_seq_to_sympy``, factored and
    simplified (each step under its own time limit), checked, passed through
    ``reassign_variables``, factored and simplified again, and finally
    converted back with ``from_sympy_to_seq``.

    Args:
        expr: Expression in the sequence form accepted by
            ``from_seq_to_sympy``.
        num_nested_max: Largest value of ``expr_tree_depth`` that is still
            accepted.

    Returns:
        The result of ``from_sympy_to_seq`` for the processed expression, or
        ``None`` when the expression is rejected, a step times out, or any
        step raises.
    """
    try:
        expr_sympy = from_seq_to_sympy(expr)
        expr_sympy = safe_factor(expr_sympy)
        expr_sympy = safe_simplify(expr_sympy)
        # Explicit checks instead of ``assert``: assertions are stripped
        # under ``python -O``, which would silently disable both filters.
        # "zoo" is how SymPy prints complex infinity.
        if "zoo" in str(expr_sympy):
            return None
        if not (expr_tree_depth(expr_sympy) <= num_nested_max):
            return None
        expr_sympy = reassign_variables(expr_sympy)
        expr_sympy = safe_factor(expr_sympy)
        expr_sympy = safe_simplify(expr_sympy)
        return from_sympy_to_seq(expr_sympy)
    except TimeoutError:
        print("Task timed out")
        return None
    except Exception:
        # Best-effort filter: any failure while processing a candidate
        # simply drops that candidate.
        return None
65102

@@ -107,40 +144,64 @@ def generate_data(cfg: DictConfig):
107144
partial_fliter_nested = partial(fliter_nested, num_nested_max=num_nested_max)
108145
exprs_fliter_nested = []
109146

110-
chunk_size = min(len(exprs_filter_nodes), 10000)
147+
max_workers = min(5, os.cpu_count() or 1)
148+
chunk_size = min(len(exprs_filter_nodes), 100000)
111149
total_chunks = (len(exprs_filter_nodes) - 1) // chunk_size + 1
112-
for i in range(0, len(exprs_filter_nodes), chunk_size):
113-
chunk = exprs_filter_nodes[i : i + chunk_size]
114-
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
115-
future_to_expr = {
116-
executor.submit(partial_fliter_nested, expr): expr for expr in chunk
117-
}
150+
151+
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
152+
for i in range(0, len(exprs_filter_nodes), chunk_size):
153+
chunk = exprs_filter_nodes[i : i + chunk_size]
154+
155+
futures = []
118156
progress = tqdm(
119-
concurrent.futures.as_completed(future_to_expr),
120157
total=len(chunk),
121158
desc=(
122-
f"Check invalid abd very nested (>{num_nested_max}) expressions. "
123-
f"Processing chunk {i//chunk_size + 1}/{total_chunks+1}"
159+
f"Check invalid and very nested (>{num_nested_max}) expressions. "
160+
f"Chunk {i//chunk_size + 1}/{total_chunks}"
124161
),
162+
leave=True,
125163
)
126-
for future in progress:
164+
165+
for expr in chunk:
127166
try:
128-
if (result := future.result()) is not None:
167+
future = executor.submit(partial_fliter_nested, expr)
168+
futures.append(future)
169+
except Exception as e:
170+
print(f"Submit failed: {e}")
171+
continue
172+
173+
completed_count = 0
174+
for future in concurrent.futures.as_completed(futures):
175+
try:
176+
result = future.result(timeout=60)
177+
if result is not None:
129178
exprs_fliter_nested.append(result)
179+
except concurrent.futures.TimeoutError:
180+
print("Task timeout during result")
130181
except Exception as e:
131-
print(f"Skipped error: {str(e)}")
132-
del chunk
133-
gc.collect()
182+
print(f"Task error: {str(e)}")
183+
finally:
184+
completed_count += 1
185+
progress.update(1)
186+
187+
if completed_count < len(futures):
188+
print(
189+
f"Warning: {len(futures) - completed_count} tasks not completed in chunk {i//chunk_size + 1}"
190+
)
191+
192+
progress.close()
193+
del chunk
194+
gc.collect()
195+
print(f"Filtered {len(exprs_fliter_nested)} valid expressions.")
134196

135197
# filter consts/vars/seq_length
136198
num_consts = cfg.DATA_GENERATE.num_consts
137199
num_vars = cfg.DATA_GENERATE.num_vars
138200
seq_length_max = cfg.DATA_GENERATE.seq_length_max
139201
exprs_cvl = []
140202
for i in tqdm(range(len(exprs_fliter_nested)), desc="Check consts and vars."):
141-
expr_seq = from_sympy_to_seq(exprs_fliter_nested[i])
142203
expr_seq = fliter_consts_vars_len(
143-
expr_seq, num_consts, num_vars, seq_length_max
204+
exprs_fliter_nested[i], num_consts, num_vars, seq_length_max
144205
)
145206
if expr_seq is not None:
146207
exprs_cvl.append(expr_seq)

0 commit comments

Comments
 (0)