|
18 | 18 |
|
19 | 19 |
|
20 | 20 | import concurrent.futures |
| 21 | +import gc |
21 | 22 | import os |
22 | 23 | import warnings |
23 | 24 | from functools import partial |
@@ -105,24 +106,31 @@ def generate_data(cfg: DictConfig): |
105 | 106 | num_nested_max = cfg.DATA_GENERATE.num_nested_max |
106 | 107 | partial_fliter_nested = partial(fliter_nested, num_nested_max=num_nested_max) |
107 | 108 | exprs_fliter_nested = [] |
108 | | - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: |
109 | | - future_to_expr = { |
110 | | - executor.submit(partial_fliter_nested, expr): expr |
111 | | - for expr in exprs_filter_nodes |
112 | | - } |
113 | | - progress = tqdm( |
114 | | - concurrent.futures.as_completed(future_to_expr), |
115 | | - total=len(exprs_filter_nodes), |
116 | | - desc=f"Check invalid abd very nested (>{num_nested_max}) expressions", |
117 | | - ) |
118 | | - for future in progress: |
119 | | - expr = future_to_expr[future] |
120 | | - try: |
121 | | - expr_sympy = future.result() |
122 | | - if expr_sympy is not None: |
123 | | - exprs_fliter_nested.append(expr_sympy) |
124 | | - except Exception: |
125 | | - continue |
| 109 | + |
| 110 | + chunk_size = min(len(exprs_filter_nodes), 10000) |
| 111 | + total_chunks = (len(exprs_filter_nodes) - 1) // chunk_size + 1 |
| 112 | + for i in range(0, len(exprs_filter_nodes), chunk_size): |
| 113 | + chunk = exprs_filter_nodes[i : i + chunk_size] |
| 114 | + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: |
| 115 | + future_to_expr = { |
| 116 | + executor.submit(partial_fliter_nested, expr): expr for expr in chunk |
| 117 | + } |
| 118 | + progress = tqdm( |
| 119 | + concurrent.futures.as_completed(future_to_expr), |
| 120 | + total=len(chunk), |
| 121 | + desc=( |
| 122 | + f"Check invalid abd very nested (>{num_nested_max}) expressions. " |
| 123 | + f"Processing chunk {i//chunk_size + 1}/{total_chunks+1}" |
| 124 | + ), |
| 125 | + ) |
| 126 | + for future in progress: |
| 127 | + try: |
| 128 | + if (result := future.result()) is not None: |
| 129 | + exprs_fliter_nested.append(result) |
| 130 | + except Exception as e: |
| 131 | + print(f"Skipped error: {str(e)}") |
| 132 | + del chunk |
| 133 | + gc.collect() |
126 | 134 |
|
127 | 135 | # filter consts/vars/seq_length |
128 | 136 | num_consts = cfg.DATA_GENERATE.num_consts |
@@ -174,14 +182,16 @@ def generate_data(cfg: DictConfig): |
174 | 182 | ground_truth.append(token) |
175 | 183 |
|
176 | 184 | cur_sympy_expr = from_seq_to_sympy(seq_deformed) |
177 | | - np_y, np_x = gen_samples(cur_sympy_expr, num_samples=1000) |
| 185 | + np_y, np_x = gen_samples( |
| 186 | + cur_sympy_expr, num_samples=max(1000, sampling_times * 2) |
| 187 | + ) |
178 | 188 | assert np.nanmax(np.abs(np_y)) <= order_of_mag_limit |
179 | 189 | mask = np.logical_not(np.isnan(np_y)) |
180 | 190 | num_temp_obs = np.sum(mask) |
181 | 191 | assert num_temp_obs >= sampling_times |
182 | 192 |
|
183 | 193 | idx = np.random.choice(num_temp_obs, size=sampling_times, replace=False) |
184 | | - num_var = count_var_num(sampling_times) |
| 194 | + num_var = count_var_num(cur_sympy_expr) |
185 | 195 | x_values = np_x[mask][idx, :num_var] |
186 | 196 | y_values = np_y[mask][idx] |
187 | 197 | if var_type == "both": |
|
0 commit comments