2020import concurrent .futures
2121import gc
2222import os
23+ import signal
2324import warnings
2425from functools import partial
26+ from functools import wraps
2527
2628import hydra
2729import numpy as np
4244warnings .filterwarnings ("ignore" )
4345
4446
47+ def timeout (seconds ):
48+ def decorator (func ):
49+ @wraps (func )
50+ def wrapper (* args , ** kwargs ):
51+ def handler (signum , frame ):
52+ raise TimeoutError (f"Timed out after { seconds } s" )
53+
54+ old = signal .signal (signal .SIGALRM , handler )
55+ signal .alarm (seconds )
56+ try :
57+ result = func (* args , ** kwargs )
58+ finally :
59+ signal .alarm (0 )
60+ signal .signal (signal .SIGALRM , old )
61+ return result
62+
63+ return wrapper
64+
65+ return decorator
66+
67+
68+ @timeout (30 )
69+ def safe_factor (expr ):
70+ return sympy .factor (expr )
71+
72+
73+ @timeout (30 )
74+ def safe_simplify (expr ):
75+ return sympy .simplify (expr )
76+
77+
4578def fliter_nodes (expr , num_nodes ):
4679 if num_nodes [0 ] <= len (expr ) <= num_nodes [1 ]:
4780 return expr
@@ -52,14 +85,18 @@ def fliter_nodes(expr, num_nodes):
5285def fliter_nested (expr , num_nested_max ):
5386 try :
5487 expr_sympy = from_seq_to_sympy (expr )
55- expr_sympy = sympy . factor (expr_sympy )
56- expr_sympy = sympy . simplify (expr_sympy )
88+ expr_sympy = safe_factor (expr_sympy )
89+ expr_sympy = safe_simplify (expr_sympy )
5790 assert "zoo" not in str (expr_sympy )
5891 assert expr_tree_depth (expr_sympy ) <= num_nested_max
5992 expr_sympy = reassign_variables (expr_sympy )
60- expr_sympy = sympy .factor (expr_sympy )
61- expr_sympy = sympy .simplify (expr_sympy )
62- return expr_sympy
93+ expr_sympy = safe_factor (expr_sympy )
94+ expr_sympy = safe_simplify (expr_sympy )
95+ expr_seq = from_sympy_to_seq (expr_sympy )
96+ return expr_seq
97+ except TimeoutError :
98+ print ("Task timed out" )
99+ return None
63100 except Exception :
64101 return None
65102
@@ -107,40 +144,64 @@ def generate_data(cfg: DictConfig):
107144 partial_fliter_nested = partial (fliter_nested , num_nested_max = num_nested_max )
108145 exprs_fliter_nested = []
109146
110- chunk_size = min (len (exprs_filter_nodes ), 10000 )
147+ max_workers = min (5 , os .cpu_count () or 1 )
148+ chunk_size = min (len (exprs_filter_nodes ), 100000 )
111149 total_chunks = (len (exprs_filter_nodes ) - 1 ) // chunk_size + 1
112- for i in range ( 0 , len ( exprs_filter_nodes ), chunk_size ):
113- chunk = exprs_filter_nodes [ i : i + chunk_size ]
114- with concurrent . futures . ThreadPoolExecutor ( max_workers = 5 ) as executor :
115- future_to_expr = {
116- executor . submit ( partial_fliter_nested , expr ): expr for expr in chunk
117- }
150+
151+ with concurrent . futures . ProcessPoolExecutor ( max_workers = max_workers ) as executor :
152+ for i in range ( 0 , len ( exprs_filter_nodes ), chunk_size ) :
153+ chunk = exprs_filter_nodes [ i : i + chunk_size ]
154+
155+ futures = []
118156 progress = tqdm (
119- concurrent .futures .as_completed (future_to_expr ),
120157 total = len (chunk ),
121158 desc = (
122- f"Check invalid abd very nested (>{ num_nested_max } ) expressions. "
123- f"Processing chunk { i // chunk_size + 1 } /{ total_chunks + 1 } "
159+ f"Check invalid and very nested (>{ num_nested_max } ) expressions. "
160+ f"Chunk { i // chunk_size + 1 } /{ total_chunks } "
124161 ),
162+ leave = True ,
125163 )
126- for future in progress :
164+
165+ for expr in chunk :
127166 try :
128- if (result := future .result ()) is not None :
167+ future = executor .submit (partial_fliter_nested , expr )
168+ futures .append (future )
169+ except Exception as e :
170+ print (f"Submit failed: { e } " )
171+ continue
172+
173+ completed_count = 0
174+ for future in concurrent .futures .as_completed (futures ):
175+ try :
176+ result = future .result (timeout = 60 )
177+ if result is not None :
129178 exprs_fliter_nested .append (result )
179+ except concurrent .futures .TimeoutError :
180+ print ("Task timeout during result" )
130181 except Exception as e :
131- print (f"Skipped error: { str (e )} " )
132- del chunk
133- gc .collect ()
182+ print (f"Task error: { str (e )} " )
183+ finally :
184+ completed_count += 1
185+ progress .update (1 )
186+
187+ if completed_count < len (futures ):
188+ print (
189+ f"Warning: { len (futures ) - completed_count } tasks not completed in chunk { i // chunk_size + 1 } "
190+ )
191+
192+ progress .close ()
193+ del chunk
194+ gc .collect ()
195+ print (f"Filtered { len (exprs_fliter_nested )} valid expressions." )
134196
135197 # filter consts/vars/seq_length
136198 num_consts = cfg .DATA_GENERATE .num_consts
137199 num_vars = cfg .DATA_GENERATE .num_vars
138200 seq_length_max = cfg .DATA_GENERATE .seq_length_max
139201 exprs_cvl = []
140202 for i in tqdm (range (len (exprs_fliter_nested )), desc = "Check consts and vars." ):
141- expr_seq = from_sympy_to_seq (exprs_fliter_nested [i ])
142203 expr_seq = fliter_consts_vars_len (
143- expr_seq , num_consts , num_vars , seq_length_max
204+ exprs_fliter_nested [ i ] , num_consts , num_vars , seq_length_max
144205 )
145206 if expr_seq is not None :
146207 exprs_cvl .append (expr_seq )
0 commit comments