|
61 | 61 | default="",
|
62 | 62 | help="Generate and save an ETRecord to the given file location",
|
63 | 63 | )
|
| 64 | + parser.add_argument( |
| 65 | + "-t", |
| 66 | + "--test_after_export", |
| 67 | + action="store_true", |
| 68 | + required=False, |
| 69 | + default=False, |
| 70 | + help="Test the pte with pybindings", |
| 71 | + ) |
64 | 72 | parser.add_argument("-o", "--output_dir", default=".", help="output directory")
|
65 | 73 |
|
66 | 74 | args = parser.parse_args()
|
|
117 | 125 | quant_tag = "q8" if args.quantize else "fp32"
|
118 | 126 | model_name = f"{args.model_name}_xnnpack_{quant_tag}"
|
119 | 127 | save_pte_program(exec_prog, model_name, args.output_dir)
|
| 128 | + |
| 129 | + if args.test_after_export: |
| 130 | + logging.info("Testing the pte with pybind") |
| 131 | + from executorch.extension.pybindings.portable_lib import ( |
| 132 | + _load_for_executorch_from_buffer, |
| 133 | + ) |
| 134 | + |
| 135 | + # Import custom ops. This requires portable_lib to be loaded first. |
| 136 | + from executorch.extension.llm.custom_ops import ( # noqa: F401, F403 |
| 137 | + custom_ops, |
| 138 | + ) # usort: skip |
| 139 | + |
| 140 | + # Import quantized ops. This requires portable_lib to be loaded first. |
| 141 | + from executorch.kernels import quantized # usort: skip # noqa: F401, F403 |
| 142 | + from torch.utils._pytree import tree_flatten |
| 143 | + |
| 144 | + m = _load_for_executorch_from_buffer(exec_prog.buffer) |
| 145 | + logging.info("Successfully loaded the model") |
| 146 | + flattened = tree_flatten(example_inputs)[0] |
| 147 | + res = m.run_method("forward", flattened) |
| 148 | + logging.info("Successfully ran the model") |
0 commit comments