Commit 3d15fd2

Merge branch 'develop' into unstable_to_stable
2 parents 5c8140d + 7e2e6cb commit 3d15fd2

File tree: 15 files changed, +808 −779 lines changed

graph_net/paddle/test_compiler.py

Lines changed: 27 additions & 9 deletions

@@ -180,7 +180,11 @@ def measure_performance(model_call, args, synchronizer_func, profile=False):
         duration_box = test_compiler_util.DurationBox(-1)
         with test_compiler_util.naive_timer(duration_box, synchronizer_func):
             model_call()
-        print(f"Trial {i + 1}: e2e={duration_box.value:.4f} ms")
+        print(
+            f"Trial {i + 1}: e2e={duration_box.value:.4f} ms",
+            file=sys.stderr,
+            flush=True,
+        )
         e2e_times.append(duration_box.value)
     stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times)

@@ -256,26 +260,34 @@ def test_single_model(args):
     # Run on eager mode
     eager_success = False
     try:
-        print("Run model in eager mode.")
+        print("Run model in eager mode.", file=sys.stderr, flush=True)
         static_model = get_static_model(args, model)
         expected_out, eager_time_stats = measure_performance(
             lambda: static_model(**input_dict), args, synchronizer_func, profile=False
         )
         eager_success = True
     except Exception as e:
-        print(f"Run model in eager mode failed: {str(e)}\n{traceback.format_exc()}")
+        print(
+            f"Run model in eager mode failed: {str(e)}\n{traceback.format_exc()}",
+            file=sys.stderr,
+            flush=True,
+        )

     # Run on compiling mode
     compiled_success = False
     try:
-        print("Run model in compiled mode.")
+        print("Run model in compiled mode.", file=sys.stderr, flush=True)
         compiled_model = get_compiled_model(args, model)
         compiled_out, compiled_time_stats = measure_performance(
             lambda: compiled_model(**input_dict), args, synchronizer_func, profile=False
         )
         compiled_success = True
     except Exception as e:
-        print(f"Run model in compiled mode failed: {str(e)}\n{traceback.format_exc()}")
+        print(
+            f"Run model in compiled mode failed: {str(e)}\n{traceback.format_exc()}",
+            file=sys.stderr,
+            flush=True,
+        )

     test_compiler_util.print_running_status(args, eager_success, compiled_success)
     if eager_success and compiled_success:

@@ -358,7 +370,7 @@ def test_multi_models(args):
     if args.verified_samples_list_path is not None:
         assert os.path.isfile(args.verified_samples_list_path)
         graphnet_root = path_utils.get_graphnet_root()
-        print(f"graphnet_root: {graphnet_root}")
+        print(f"graphnet_root: {graphnet_root}", file=sys.stderr, flush=True)
         verified_samples = []
         with open(args.verified_samples_list_path, "r") as f:
             for line in f.readlines():

@@ -368,7 +380,11 @@ def test_multi_models(args):
     failed_samples = []
     for model_path in path_utils.get_recursively_model_path(args.model_path):
         if verified_samples is None or os.path.abspath(model_path) in verified_samples:
-            print(f"[{sample_idx}] test_compiler, model_path: {model_path}")
+            print(
+                f"[{sample_idx}] test_compiler, model_path: {model_path}",
+                file=sys.stderr,
+                flush=True,
+            )
             cmd = " ".join(
                 [
                     sys.executable,

@@ -388,10 +404,12 @@
             sample_idx += 1

     print(
-        f"Totally {sample_idx} verified samples, failed {len(failed_samples)} samples."
+        f"Totally {sample_idx} verified samples, failed {len(failed_samples)} samples.",
+        file=sys.stderr,
+        flush=True,
     )
     for model_path in failed_samples:
-        print(f"- {model_path}")
+        print(f"- {model_path}", file=sys.stderr, flush=True)


 def main(args):
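
Note on the pattern above: routing progress messages to stderr with flush=True keeps stdout free for machine-readable results and ensures messages are not lost in a buffered pipe when a model run aborts. A minimal self-contained sketch of the idea — the helper name log_progress and the trial loop are illustrative, not part of this commit:

import sys

def log_progress(*values):
    # Mirror the commit's pattern: human-readable progress goes to stderr and
    # is flushed immediately so it survives crashes and pipe buffering.
    print(*values, file=sys.stderr, flush=True)

def run_trials(model_call, num_trials=3):
    # Sketch of the measurement-loop shape used above.
    for i in range(num_trials):
        duration_ms = model_call()  # assumed to return elapsed milliseconds
        log_progress(f"Trial {i + 1}: e2e={duration_ms:.4f} ms")

if __name__ == "__main__":
    run_trials(lambda: 12.3456)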

graph_net/test_compiler_util.py

Lines changed: 6 additions & 4 deletions

@@ -204,14 +204,16 @@ def check_allclose(
     cmp_all_close_func,
     cmp_max_diff_func,
     cmp_mean_diff_func,
-    cmp_max_relative_diff_func,
-    cmp_mean_relative_diff_func,
+    cmp_max_relative_diff_func=None,
+    cmp_mean_relative_diff_func=None,
 ):
     cmp_configs = generate_allclose_configs(cmp_all_close_func)
     cmp_configs.append(("[max_diff]", cmp_max_diff_func, {}))
     cmp_configs.append(("[mean_diff]", cmp_mean_diff_func, {}))
-    cmp_configs.append(("[max_relative_diff]", cmp_max_relative_diff_func, {}))
-    cmp_configs.append(("[mean_relative_diff]", cmp_mean_relative_diff_func, {}))
+    if cmp_max_relative_diff_func is not None:
+        cmp_configs.append(("[max_relative_diff]", cmp_max_relative_diff_func, {}))
+    if cmp_mean_relative_diff_func is not None:
+        cmp_configs.append(("[mean_relative_diff]", cmp_mean_relative_diff_func, {}))

     for key, func, kwargs in cmp_configs:
         print_and_store_cmp(
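
Making the two relative-diff comparators keyword arguments with a None default lets callers that have no relative-diff implementation (such as the torch tester below) omit them entirely; the corresponding configs are then skipped. A self-contained sketch of the new logic — check_allclose_sketch is a simplified stand-in, not the real graph_net.test_compiler_util.check_allclose:

def check_allclose_sketch(
    expected,
    actual,
    cmp_max_diff_func,
    cmp_mean_diff_func,
    cmp_max_relative_diff_func=None,
    cmp_mean_relative_diff_func=None,
):
    cmp_configs = [
        ("[max_diff]", cmp_max_diff_func),
        ("[mean_diff]", cmp_mean_diff_func),
    ]
    # New behavior: relative-diff entries are appended only when provided.
    if cmp_max_relative_diff_func is not None:
        cmp_configs.append(("[max_relative_diff]", cmp_max_relative_diff_func))
    if cmp_mean_relative_diff_func is not None:
        cmp_configs.append(("[mean_relative_diff]", cmp_mean_relative_diff_func))
    for key, func in cmp_configs:
        print(key, func(expected, actual))

# A caller without relative-diff comparators simply omits the last two
# arguments, which would previously have raised a TypeError:
check_allclose_sketch(
    [1.0, 2.0],
    [1.0, 2.5],
    cmp_max_diff_func=lambda e, a: max(abs(x - y) for x, y in zip(e, a)),
    cmp_mean_diff_func=lambda e, a: sum(abs(x - y) for x, y in zip(e, a)) / len(e),
)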

graph_net/torch/test_compiler.py

Lines changed: 33 additions & 27 deletions

@@ -12,6 +12,7 @@
 from contextlib import contextmanager
 import time
 import json
+import random
 import numpy as np
 import platform
 from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend

@@ -23,6 +24,7 @@
 from graph_net.torch.backend.nope_backend import NopeBackend
 from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
 from graph_net.test_compiler_util import generate_allclose_configs
+from graph_net import test_compiler_util


 registry_backend = {

@@ -36,6 +38,15 @@
 }


+def set_seed(random_seed):
+    random.seed(random_seed)
+    np.random.seed(random_seed)
+    torch.manual_seed(random_seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(random_seed)
+        torch.cuda.manual_seed_all(random_seed)
+
+
 def load_class_from_file(
     args: argparse.Namespace, class_name: str, device: str
 ) -> Type[torch.nn.Module]:

@@ -229,6 +240,7 @@ def test_single_model(args):
         flush=True,
     )

+    runtime_seed = 1024
     eager_failure = False
     expected_out = None
     eager_types = []

@@ -242,6 +254,8 @@ def test_single_model(args):
             file=sys.stderr,
             flush=True,
         )
+
+        torch.manual_seed(runtime_seed)
         expected_out = eager_model_call()
         if not isinstance(expected_out, tuple):
             expected_out = (expected_out,)

@@ -273,6 +287,7 @@ def test_single_model(args):
     else:
         compiled_model = compiler(model)

+    torch.manual_seed(runtime_seed)
     compiled_model_call = lambda: compiled_model(**input_dict)
     compiled_stats = measure_performance(compiled_model_call, args, compiler)
     print(

@@ -377,33 +392,21 @@ def print_and_store_cmp(key, cmp_func, args, expected_out, compiled_out, **kwargs):


 def compare_correctness(expected_out, compiled_out, args):
-    # cmp_configs = [
-    #     ("[equal]", get_cmp_equal, {}),
-    #     ("[all_close_atol8_rtol8]", get_cmp_all_close, {"atol": 1e-8, "rtol": 1e-8}),
-    #     ("[all_close_atol8_rtol5]", get_cmp_all_close, {"atol": 1e-8, "rtol": 1e-5}),
-    #     ("[all_close_atol5_rtol5]", get_cmp_all_close, {"atol": 1e-5, "rtol": 1e-5}),
-    #     ("[all_close_atol3_rtol2]", get_cmp_all_close, {"atol": 1e-3, "rtol": 1e-2}),
-    #     ("[all_close_atol2_rtol1]", get_cmp_all_close, {"atol": 1e-2, "rtol": 1e-1}),
-    #     ("[max_diff]", get_cmp_max_diff, {}),
-    #     ("[mean_diff]", get_cmp_mean_diff, {}),
-    #     ("[diff_count_atol8_rtol8]", get_cmp_diff_count, {"atol": 1e-8, "rtol": 1e-8}),
-    #     ("[diff_count_atol8_rtol5]", get_cmp_diff_count, {"atol": 1e-8, "rtol": 1e-5}),
-    #     ("[diff_count_atol5_rtol5]", get_cmp_diff_count, {"atol": 1e-5, "rtol": 1e-5}),
-    #     ("[diff_count_atol3_rtol2]", get_cmp_diff_count, {"atol": 1e-3, "rtol": 1e-2}),
-    #     ("[diff_count_atol2_rtol1]", get_cmp_diff_count, {"atol": 1e-2, "rtol": 1e-1}),
-    # ]
-    cmp_configs = generate_allclose_configs(get_cmp_all_close)
-    cmp_configs.append(("[equal]", get_cmp_equal, {}))
-
-    for key, func, kwargs in cmp_configs:
-        print_and_store_cmp(
-            key=key,
-            cmp_func=func,
-            args=args,
-            expected_out=expected_out,
-            compiled_out=compiled_out,
-            **kwargs,
-        )
+    test_compiler_util.check_equal(
+        args,
+        expected_out,
+        compiled_out,
+        cmp_equal_func=get_cmp_equal,
+    )
+
+    test_compiler_util.check_allclose(
+        args,
+        expected_out,
+        compiled_out,
+        cmp_all_close_func=get_cmp_all_close,
+        cmp_max_diff_func=get_cmp_max_diff,
+        cmp_mean_diff_func=get_cmp_mean_diff,
+    )


 def get_cmp_equal(expected_out, compiled_out):

@@ -495,6 +498,9 @@ def is_single_model_dir(model_dir):

 def main(args):
     assert os.path.isdir(args.model_path)
+
+    initalize_seed = 123
+    set_seed(random_seed=initalize_seed)
     if is_single_model_dir(args.model_path):
         test_single_model(args)
     else:
graph_net/torch/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def extract_dynamic_shapes(example_inputs):
260260

261261

262262
def replay_tensor(info):
263+
name = info["name"]
263264
device = info["info"]["device"]
264265
dtype = info["info"]["dtype"]
265266
shape = info["info"]["shape"]
@@ -270,7 +271,11 @@ def replay_tensor(info):
270271
return info["data"].to(device)
271272
if dtype is torch.bool:
272273
return (torch.randn(size=shape) > 0.5).to(dtype).to(device)
273-
return torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
274+
tensor = torch.randn(size=shape).to(dtype).to(device) * std * 0.2 + mean
275+
# TODO(Xreki): remove this ugly code, and change the weight_meta instead.
276+
if name.startswith("L_self_modules") and "buffers_running_var" in name:
277+
tensor = torch.clip(tensor, min=0)
278+
return tensor
274279

275280

276281
def modify_code_by_device(code, new_device_str):

0 commit comments

Comments
 (0)