|
| 1 | +import argparse |
| 2 | +import onnx |
| 3 | +from onnxruntime.quantization.qdq_loss_debug import ( |
| 4 | + collect_activations, compute_activation_error, compute_weight_error, |
| 5 | + create_activation_matching, create_weight_matching, |
| 6 | + modify_model_output_intermediate_tensors) |
| 7 | + |
| 8 | +import resnet50_data_reader |
| 9 | + |
| 10 | + |
def get_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the QDQ debugging script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` falls back to ``sys.argv[1:]``,
            which is the behavior of the original zero-argument call.

    Returns:
        The parsed namespace with the attributes ``float_model``,
        ``qdq_model`` and ``calibrate_dataset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--float_model", required=True, help="Path to original floating point model"
    )
    parser.add_argument("--qdq_model", required=True, help="Path to qdq model")
    parser.add_argument(
        "--calibrate_dataset", default="./test_images", help="calibration data set"
    )
    return parser.parse_args(argv)
| 22 | + |
| 23 | + |
def _generate_aug_model_path(model_path: str) -> str:
    """Derive the file path used to store the augmented copy of a model.

    A trailing ``.onnx`` extension, if present, is stripped before the
    ``.save_tensors.onnx`` suffix is appended; any other path is used as-is.

    Args:
        model_path: Path of the source model file.

    Returns:
        ``model_path`` without a trailing ``.onnx``, followed by
        ``.save_tensors.onnx``.
    """
    extension = ".onnx"
    if model_path.endswith(extension):
        stem = model_path[: -len(extension)]
    else:
        stem = model_path
    return stem + ".save_tensors.onnx"
| 29 | + |
| 30 | + |
def main():
    """Compare a float model with its QDQ counterpart and print the errors.

    Steps, in order:
      1. Match the weights of both models and print the per-weight error.
      2. Build an augmented copy of each model via
         ``modify_model_output_intermediate_tensors`` and save it next to the
         original (see ``_generate_aug_model_path``).
      3. Run both augmented models on the calibration data set to collect
         activations.
      4. Match the activations and print, per activation, the cross-model
         error and the QDQ error.
    """
    # Process input parameters and setup model input data reader
    args = get_args()
    float_model_path = args.float_model
    qdq_model_path = args.qdq_model
    calibration_dataset_path = args.calibrate_dataset

    print("------------------------------------------------\n")
    print("Comparing weights of float model vs qdq model.....")

    matched_weights = create_weight_matching(float_model_path, qdq_model_path)
    weights_error = compute_weight_error(matched_weights)
    for weight_name, err in weights_error.items():
        print(f"Cross model error of '{weight_name}': {err}\n")

    print("------------------------------------------------\n")
    print("Augmenting models to save intermediate activations......")

    aug_float_model = modify_model_output_intermediate_tensors(float_model_path)
    aug_float_model_path = _generate_aug_model_path(float_model_path)
    onnx.save(
        aug_float_model,
        aug_float_model_path,
        save_as_external_data=False,
    )
    # Only the saved file is needed from here on; drop the in-memory model
    # (presumably to keep peak memory down before loading the next one).
    del aug_float_model

    aug_qdq_model = modify_model_output_intermediate_tensors(qdq_model_path)
    aug_qdq_model_path = _generate_aug_model_path(qdq_model_path)
    onnx.save(
        aug_qdq_model,
        aug_qdq_model_path,
        save_as_external_data=False,
    )
    del aug_qdq_model

    print("------------------------------------------------\n")
    print("Running the augmented floating point model to collect activations......")
    input_data_reader = resnet50_data_reader.ResNet50DataReader(
        calibration_dataset_path, float_model_path
    )
    float_activations = collect_activations(aug_float_model_path, input_data_reader)

    print("------------------------------------------------\n")
    print("Running the augmented qdq model to collect activations......")
    # The same reader instance feeds the second run, so it is rewound first.
    input_data_reader.rewind()
    qdq_activations = collect_activations(aug_qdq_model_path, input_data_reader)

    print("------------------------------------------------\n")
    print("Comparing activations of float model vs qdq model......")

    act_matching = create_activation_matching(qdq_activations, float_activations)
    act_error = compute_activation_error(act_matching)
    for act_name, err in act_error.items():
        print(f"Cross model error of '{act_name}': {err['xmodel_err']} \n")
        print(f"QDQ error of '{act_name}': {err['qdq_err']} \n")
| 87 | + |
| 88 | + |
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0 commit comments