|
18 | 18 | import shutil
|
19 | 19 |
|
20 | 20 | from paddle.fluid.evaluator import Evaluator
|
21 |
| -from paddle.fluid.framework import Program, Parameter, default_main_program, Variable |
| 21 | +from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable |
22 | 22 | from . import core
|
23 | 23 |
|
24 | 24 | __all__ = [
|
@@ -1374,3 +1374,101 @@ def has_success(checkpoint_dir, cur_dir):
|
1374 | 1374 | if success_num > current_dir:
|
1375 | 1375 | current_dir = success_num
|
1376 | 1376 | return current_dir
|
| 1377 | + |
| 1378 | + |
| 1379 | +def get_test_program(filelist, program=None, startup_program=None): |
| 1380 | + """ |
| 1381 | + Transpile current train program to a program to read test dataset |
| 1382 | + if the program is using reader ops like "open_files_op". |
| 1383 | + """ |
| 1384 | + |
| 1385 | + def _copy_reader_var_(block, var, new_name=None): |
| 1386 | + if new_name == None: |
| 1387 | + new_name = var.name |
| 1388 | + new_var = block.create_var( |
| 1389 | + name=str(new_name), type=core.VarDesc.VarType.READER) |
| 1390 | + new_var.desc.set_shapes(var.desc.shapes()) |
| 1391 | + new_var.desc.set_dtypes(var.desc.dtypes()) |
| 1392 | + new_var.persistable = True |
| 1393 | + return new_var |
| 1394 | + |
| 1395 | + def _get_test_reader_name(train_reader_name): |
| 1396 | + return train_reader_name + "_test" |
| 1397 | + |
| 1398 | + def _is_reader_op(op): |
| 1399 | + block = op.block |
| 1400 | + if "Out" in op.output_names: |
| 1401 | + reader_out = block.vars[op.output("Out")[0]] |
| 1402 | + if reader_out.type == core.VarDesc.VarType.READER: |
| 1403 | + return True |
| 1404 | + return False |
| 1405 | + |
| 1406 | + if program == None: |
| 1407 | + program = default_main_program() |
| 1408 | + if startup_program == None: |
| 1409 | + startup_program = default_startup_program() |
| 1410 | + startup_block = startup_program.global_block() |
| 1411 | + |
| 1412 | + # 1. find out the orignal reader var name |
| 1413 | + startup_reader_op_list = [] |
| 1414 | + |
| 1415 | + for op in startup_block.ops: |
| 1416 | + if _is_reader_op(op): |
| 1417 | + startup_reader_op_list.append(op) |
| 1418 | + |
| 1419 | + if len(startup_reader_op_list) == 0: |
| 1420 | + return program |
| 1421 | + |
| 1422 | + root_reader_op = startup_reader_op_list[0] |
| 1423 | + train_test_reader_map = {} |
| 1424 | + # 2. add operators to startup to read open and read test data files |
| 1425 | + for op in startup_reader_op_list: |
| 1426 | + assert (len(op.output("Out")) == 1) |
| 1427 | + train_reader_name = op.output("Out")[0] |
| 1428 | + train_reader = startup_block.vars[train_reader_name] |
| 1429 | + test_reader = _copy_reader_var_( |
| 1430 | + startup_block, |
| 1431 | + train_reader, |
| 1432 | + new_name=_get_test_reader_name(train_reader_name)) |
| 1433 | + train_test_reader_map[train_reader.name] = test_reader |
| 1434 | + |
| 1435 | + test_op_inputs = {} |
| 1436 | + for name in op.input_names: |
| 1437 | + train_arg_names = op.input(name) |
| 1438 | + test_arg_vars = [] |
| 1439 | + for arg_name in train_arg_names: |
| 1440 | + arg_var = train_test_reader_map[ |
| 1441 | + arg_name] if name == "UnderlyingReader" else startup_block.vars[ |
| 1442 | + arg_name] |
| 1443 | + test_arg_vars.append(arg_var) |
| 1444 | + test_op_inputs[name] = test_arg_vars |
| 1445 | + |
| 1446 | + test_op = startup_block.append_op( |
| 1447 | + type=op.type, |
| 1448 | + inputs=test_op_inputs, |
| 1449 | + outputs={'Out': [test_reader]}, |
| 1450 | + attrs=op.attrs) |
| 1451 | + # root reader op's filelist attr for read test files |
| 1452 | + if op.type == root_reader_op.type: |
| 1453 | + test_op.set_attr("file_names", filelist) |
| 1454 | + if op.type == "create_multi_pass_reader": |
| 1455 | + test_op.set_attr("pass_num", 1) |
| 1456 | + |
| 1457 | + # 3. rename reader vars in inference program to different name |
| 1458 | + # to avoid read from train data. |
| 1459 | + main_block = program.global_block() |
| 1460 | + for var in main_block.vars.values(): |
| 1461 | + if var.type == core.VarDesc.VarType.READER: |
| 1462 | + main_block.rename_var( |
| 1463 | + str(var.name), str(_get_test_reader_name(var.name))) |
| 1464 | + |
| 1465 | + for op in main_block.ops: |
| 1466 | + if op.type == root_reader_op.type: |
| 1467 | + test_op.set_attr("file_names", filelist) |
| 1468 | + if op.type == "create_multi_pass_reader": |
| 1469 | + test_op.set_attr("pass_num", 1) |
| 1470 | + |
| 1471 | + startup_program.sync_with_cpp() |
| 1472 | + program.sync_with_cpp() |
| 1473 | + |
| 1474 | + return program |
0 commit comments