Skip to content

Commit 49f91e8

Browse files
committed
Update
1 parent 35e2381 commit 49f91e8

File tree

1 file changed

+80
-42
lines changed

1 file changed

+80
-42
lines changed

graph_net/paddle/fixed_random_seed_device_runner.py

Lines changed: 80 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from graph_net.paddle import utils
1313
from graph_net import test_compiler_util
14+
from graph_net import path_utils
1415

1516

1617
def set_seed(random_seed):
@@ -125,6 +126,52 @@ def measure_performance(model_call, synchronizer_func, warmup, trials):
125126
return test_compiler_util.get_timing_stats(e2e_times)
126127

127128

129+
def test_single_model(args, model_path):
130+
set_seed(123)
131+
132+
print(f"[Config] device: {args.device}")
133+
print(f"[Config] compiler: {args.compiler}")
134+
print(f"[Config] hardware: {get_hardware_name(args.device)}")
135+
print(f"[Config] framework_version: {paddle.__version__}")
136+
print(f"[Config] warmup: {args.warmup}")
137+
print(f"[Config] trials: {args.trials}")
138+
139+
success = False
140+
try:
141+
synchronizer_func = paddle.device.synchronize
142+
143+
input_dict = get_input_dict(model_path)
144+
model = load_model(model_path)
145+
model.eval()
146+
147+
print(f"Run model with compiler: {args.compiler}")
148+
if args.compiler == "nope":
149+
compiled_model = model
150+
else:
151+
compiled_model = get_compiled_model(model, args.compiler, model_path)
152+
153+
time_stats = measure_performance(
154+
lambda: compiled_model(**input_dict),
155+
synchronizer_func,
156+
args.warmup,
157+
args.trials
158+
)
159+
success = True
160+
161+
print(f"[Result] model_path: {model_path}")
162+
print(f"[Result] compiler: {args.compiler}")
163+
print(f"[Result] device: {args.device}")
164+
print(f"[Result] e2e_mean: {time_stats['mean']:.5f}")
165+
print(f"[Result] e2e_std: {time_stats['std']:.5f}")
166+
167+
except Exception as e:
168+
print(f"Run model failed: {str(e)}")
169+
print(traceback.format_exc())
170+
return False
171+
172+
return success
173+
174+
128175
def main():
129176
parser = argparse.ArgumentParser(description="Test device performance with fixed random seeds")
130177
parser.add_argument(
@@ -161,52 +208,43 @@ def main():
161208
default=10,
162209
help="Number of timing trials"
163210
)
211+
parser.add_argument(
212+
"--allow-list",
213+
type=str,
214+
required=False,
215+
default=None,
216+
help="Path to allow list file"
217+
)
164218

165219
args = parser.parse_args()
166220

167-
set_seed(123)
221+
test_samples = []
222+
if args.allow_list is not None:
223+
assert os.path.isfile(args.allow_list)
224+
graphnet_root = path_utils.get_graphnet_root()
225+
print(f"graphnet_root: {graphnet_root}")
226+
test_samples = []
227+
with open(args.allow_list, "r") as f:
228+
for line in f.readlines():
229+
test_samples.append(os.path.join(graphnet_root, line.strip()))
230+
231+
sample_idx = 0
232+
failed_samples = []
168233

169-
print(f"[Config] device: {args.device}")
170-
print(f"[Config] compiler: {args.compiler}")
171-
print(f"[Config] hardware: {get_hardware_name(args.device)}")
172-
print(f"[Config] framework_version: {paddle.__version__}")
173-
print(f"[Config] warmup: {args.warmup}")
174-
print(f"[Config] trials: {args.trials}")
175-
176-
success = False
177-
try:
178-
synchronizer_func = paddle.device.synchronize
179-
180-
input_dict = get_input_dict(args.model_path)
181-
model = load_model(args.model_path)
182-
model.eval()
183-
184-
print(f"Run model with compiler: {args.compiler}")
185-
if args.compiler == "nope":
186-
compiled_model = model
187-
else:
188-
compiled_model = get_compiled_model(model, args.compiler, args.model_path)
189-
190-
time_stats = measure_performance(
191-
lambda: compiled_model(**input_dict),
192-
synchronizer_func,
193-
args.warmup,
194-
args.trials
195-
)
196-
success = True
197-
198-
print(f"[Result] model_path: {args.model_path}")
199-
print(f"[Result] compiler: {args.compiler}")
200-
print(f"[Result] device: {args.device}")
201-
print(f"[Result] e2e_mean: {time_stats['mean']:.5f}")
202-
print(f"[Result] e2e_std: {time_stats['std']:.5f}")
203-
204-
except Exception as e:
205-
print(f"Run model failed: {str(e)}")
206-
print(traceback.format_exc())
207-
return 1
208-
209-
return 0 if success else 1
234+
for model_path in path_utils.get_recursively_model_path(args.model_path):
235+
if not test_samples or os.path.abspath(model_path) in test_samples:
236+
print(f"[{sample_idx}] fixed_random_seed_device_runner, model_path: {model_path}")
237+
238+
success = test_single_model(args, model_path)
239+
if not success:
240+
failed_samples.append(model_path)
241+
sample_idx += 1
242+
243+
print(f"Totally {sample_idx} verified samples, failed {len(failed_samples)} samples.")
244+
for model_path in failed_samples:
245+
print(f"- {model_path}")
246+
247+
return 0 if len(failed_samples) == 0 else 1
210248

211249

212250
if __name__ == "__main__":

0 commit comments

Comments
 (0)