From a9f6dffb4fef6a32bbcbb015c235d93b4656cd75 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Mon, 27 Jan 2025 11:46:49 +0100
Subject: [PATCH] Don't set requires_grad in Arm-backend fold_qdq-pass

requires_grad was previously manually set to False for all placeholders
as a workaround for an issue where some params did not have
requires_grad properly set. This caused issues for placeholders which
were not leaf variables, and since the workaround is not needed anymore
we can just remove it.

Change-Id: I258cc6ddf7205ab1948e1127f2fceadd2f942beb
---
 .../fold_qdq_with_annotated_qparams_pass.py | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index 683e3969e79..dc019df1908 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -105,21 +105,6 @@ def fold_and_annotate_arg(
     for arg in arg_list:
         if not isinstance(arg, Node):
             return
-        """
-        Make sure arg has requires_grad set to False
-        For parameters that are not quantized, sometimes (i.e. convolution)
-        the Parameter(FakeTensor(...)) has requires_grad set to True, which
-        causes the retracing of the graph to fail with:
-
-        E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
-        E
-        E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
-        E       Original traceback:
-        E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
-        E           x = conv(x)
-        """
-        if arg.op == "placeholder":
-            arg.meta["val"].requires_grad = False
 
         arg_quant_params = None
         if arg.target == dq_op:
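
Note (not part of the patch): a minimal sketch of why the removed workaround broke for non-leaf
placeholders. It uses plain torch tensors rather than the FakeTensor values carried in node meta,
but illustrates the same autograd rule: requires_grad can only be toggled on leaf tensors.

    import torch

    leaf = torch.randn(2, 2, requires_grad=True)
    leaf.requires_grad = False          # fine: leaf tensor

    non_leaf = torch.randn(2, 2, requires_grad=True) * 2.0
    try:
        non_leaf.requires_grad = False  # fails: non-leaf tensor
    except RuntimeError as err:
        print(err)  # "you can only change requires_grad flags of leaf variables..."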