Skip to content

Commit f65c53e

Browse files
committed
Fix batched testing.
1 parent 07e4340 commit f65c53e

File tree

2 files changed

+353
-98
lines changed

2 files changed

+353
-98
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def check_outputs(args, expected_out, compiled_out):
206206
args, eager_dtypes, compiled_dtypes
207207
)
208208

209-
def regular_outputs(origin_outputs):
209+
def transfer_to_float(origin_outputs):
210210
outputs = []
211211
for item in origin_outputs:
212212
if (
@@ -219,14 +219,19 @@ def regular_outputs(origin_outputs):
219219
return outputs
220220

221221
if type_match:
222-
expected_out = regular_outputs(expected_out)
223-
compiled_out = regular_outputs(compiled_out)
224-
225-
test_compiler_util.check_correctness(
222+
test_compiler_util.check_equal(
226223
args,
227224
expected_out,
228225
compiled_out,
229226
cmp_equal_func=get_cmp_equal,
227+
)
228+
229+
expected_out_fp32 = transfer_to_float(expected_out)
230+
compiled_out_fp32 = transfer_to_float(compiled_out)
231+
test_compiler_util.check_allclose(
232+
args,
233+
expected_out_fp32,
234+
compiled_out_fp32,
230235
cmp_all_close_func=get_cmp_all_close,
231236
cmp_max_diff_func=get_cmp_max_diff,
232237
cmp_mean_diff_func=get_cmp_mean_diff,
@@ -240,8 +245,6 @@ def test_single_model(args):
240245
model = get_model(args)
241246
model.eval()
242247

243-
# num_eager_ops = count_number_of_ops(args, model, eager_mode=True)
244-
245248
test_compiler_util.print_basic_config(
246249
args, get_hardward_name(args), get_compile_framework_version(args)
247250
)
@@ -314,8 +317,11 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
314317

315318

316319
def test_multi_models(args):
320+
sample_idx = 0
321+
failed_samples = []
317322
for model_path in path_utils.get_recursively_model_path(args.model_path):
318-
cmd = "".join(
323+
print(f"[{sample_idx}] test_compiler, model_path: {model_path}")
324+
cmd = " ".join(
319325
[
320326
sys.executable,
321327
"-m graph_net.paddle.test_compiler",
@@ -329,7 +335,14 @@ def test_multi_models(args):
329335
]
330336
)
331337
cmd_ret = os.system(cmd)
332-
assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
338+
# assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
339+
if cmd_ret != 0:
340+
failed_samples.append(model_path)
341+
sample_idx += 1
342+
343+
print(f"Totally {sample_idx} samples, failed {len(failed_samples)} samples.")
344+
for model_path in failed_samples:
345+
print(f"- {model_path}")
333346

334347

335348
def main(args):
@@ -380,12 +393,5 @@ def main(args):
380393
default="graph-net-test-compiler-log",
381394
help="Log prompt for performance log filtering.",
382395
)
383-
parser.add_argument(
384-
"--output-dir",
385-
type=str,
386-
required=False,
387-
default=None,
388-
help="Directory to save the structured JSON result file.",
389-
)
390396
args = parser.parse_args()
391397
main(args=args)

0 commit comments

Comments
 (0)