Skip to content

Commit 02ba0d2

Browse files
authored
[Example & Doc] update data generation of transformer4sr (#1198)
1 parent 2c5a6ff commit 02ba0d2

File tree

3 files changed

+42
-31
lines changed

3 files changed

+42
-31
lines changed

docs/zh/examples/transformer4sr.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
``` sh
66
# linux
7-
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer4sr/data_generated.tar
7+
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer4sr/data_generated.tar.gz
88
# windows
9-
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer4sr/data_generated.tar -o data_generated.tar
9+
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/transformer4sr/data_generated.tar.gz -o data_generated.tar.gz
1010
# unzip it
11-
tar -xvf data_generated.tar
11+
tar -xzvf data_generated.tar.gz
1212
python transformer4sr.py
1313
```
1414

examples/transformer4sr/generate_datasets.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
import concurrent.futures
21+
import gc
2122
import os
2223
import warnings
2324
from functools import partial
@@ -105,24 +106,31 @@ def generate_data(cfg: DictConfig):
105106
num_nested_max = cfg.DATA_GENERATE.num_nested_max
106107
partial_fliter_nested = partial(fliter_nested, num_nested_max=num_nested_max)
107108
exprs_fliter_nested = []
108-
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
109-
future_to_expr = {
110-
executor.submit(partial_fliter_nested, expr): expr
111-
for expr in exprs_filter_nodes
112-
}
113-
progress = tqdm(
114-
concurrent.futures.as_completed(future_to_expr),
115-
total=len(exprs_filter_nodes),
116-
desc=f"Check invalid abd very nested (>{num_nested_max}) expressions",
117-
)
118-
for future in progress:
119-
expr = future_to_expr[future]
120-
try:
121-
expr_sympy = future.result()
122-
if expr_sympy is not None:
123-
exprs_fliter_nested.append(expr_sympy)
124-
except Exception:
125-
continue
109+
110+
chunk_size = min(len(exprs_filter_nodes), 10000)
111+
total_chunks = (len(exprs_filter_nodes) - 1) // chunk_size + 1
112+
for i in range(0, len(exprs_filter_nodes), chunk_size):
113+
chunk = exprs_filter_nodes[i : i + chunk_size]
114+
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
115+
future_to_expr = {
116+
executor.submit(partial_fliter_nested, expr): expr for expr in chunk
117+
}
118+
progress = tqdm(
119+
concurrent.futures.as_completed(future_to_expr),
120+
total=len(chunk),
121+
desc=(
122+
f"Check invalid abd very nested (>{num_nested_max}) expressions. "
123+
f"Processing chunk {i//chunk_size + 1}/{total_chunks+1}"
124+
),
125+
)
126+
for future in progress:
127+
try:
128+
if (result := future.result()) is not None:
129+
exprs_fliter_nested.append(result)
130+
except Exception as e:
131+
print(f"Skipped error: {str(e)}")
132+
del chunk
133+
gc.collect()
126134

127135
# filter consts/vars/seq_length
128136
num_consts = cfg.DATA_GENERATE.num_consts
@@ -174,14 +182,16 @@ def generate_data(cfg: DictConfig):
174182
ground_truth.append(token)
175183

176184
cur_sympy_expr = from_seq_to_sympy(seq_deformed)
177-
np_y, np_x = gen_samples(cur_sympy_expr, num_samples=1000)
185+
np_y, np_x = gen_samples(
186+
cur_sympy_expr, num_samples=max(1000, sampling_times * 2)
187+
)
178188
assert np.nanmax(np.abs(np_y)) <= order_of_mag_limit
179189
mask = np.logical_not(np.isnan(np_y))
180190
num_temp_obs = np.sum(mask)
181191
assert num_temp_obs >= sampling_times
182192

183193
idx = np.random.choice(num_temp_obs, size=sampling_times, replace=False)
184-
num_var = count_var_num(sampling_times)
194+
num_var = count_var_num(cur_sympy_expr)
185195
x_values = np_x[mask][idx, :num_var]
186196
y_values = np_y[mask][idx]
187197
if var_type == "both":

examples/transformer4sr/utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import sympy
2424
import yaml
2525
import zss
26+
from sympy import S
2627
from typing_extensions import Literal
2728

2829
with open("./conf/transformer4sr.yaml", "r") as file:
@@ -160,23 +161,23 @@ def from_sympy_power_to_seq(exponent):
160161
return ["inv", "cb"]
161162
elif exponent == (-2):
162163
return ["inv", "sq"]
163-
elif exponent == (-3 / 2):
164+
elif exponent == -S(3) / 2:
164165
return ["inv", "cb", "sqrt"]
165166
elif exponent == (-1):
166167
return ["inv"]
167-
elif exponent == (-1 / 2):
168+
elif exponent in (-S.Half, -S(1) / 2):
168169
return ["inv", "sqrt"]
169-
elif exponent == (-1 / 3):
170+
elif exponent == -S(1) / 3:
170171
return ["inv", "cbrt"]
171-
elif exponent == (-1 / 4):
172+
elif exponent == -S(1) / 4:
172173
return ["inv", "sqrt", "sqrt"]
173-
elif exponent == (1 / 4):
174+
elif exponent == S(1) / 4:
174175
return ["sqrt", "sqrt"]
175-
elif exponent == (1 / 3):
176+
elif exponent == S(1) / 3:
176177
return ["cbrt"]
177-
elif exponent == (1 / 2):
178+
elif exponent in (S.Half, S(1) / 2):
178179
return ["sqrt"]
179-
elif exponent == (3 / 2):
180+
elif exponent == S(3) / 2:
180181
return ["cb", "sqrt"]
181182
elif exponent == (2):
182183
return ["sq"]

0 commit comments

Comments
 (0)