|  | 
| 12 | 12 |     create_node, | 
| 13 | 13 |     get_first_fake_tensor, | 
| 14 | 14 | ) | 
|  | 15 | +from executorch.backends.arm.common.debug import get_node_debug_info | 
| 15 | 16 | from executorch.backends.transforms.utils import ( | 
| 16 | 17 |     create_constant_placeholder, | 
| 17 | 18 |     delete_constant_placeholder, | 
| @@ -60,8 +61,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901 | 
| 60 | 61 |             input_node = node.all_input_nodes[0] | 
| 61 | 62 |             is_single_user = len(input_node.users) == 1 | 
| 62 | 63 |             bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = node.args[1:5] | 
| 63 |  | -            assert bn_mean_node is not None, "Batchnorm mean node cannot be None." | 
| 64 |  | -            assert bn_var_node is not None, "Batchnorm var node cannot be None." | 
|  | 64 | +            if bn_mean_node is None: | 
|  | 65 | +                raise RuntimeError( | 
|  | 66 | +                    "BatchNorm mean buffer missing for node: " | 
|  | 67 | +                    f"{get_node_debug_info(node, graph_module)}" | 
|  | 68 | +                ) | 
|  | 69 | +            if bn_var_node is None: | 
|  | 70 | +                raise RuntimeError( | 
|  | 71 | +                    "BatchNorm variance buffer missing for node: " | 
|  | 72 | +                    f"{get_node_debug_info(node, graph_module)}" | 
|  | 73 | +                ) | 
| 65 | 74 | 
 | 
| 66 | 75 |             epsilon = node.args[-1] | 
| 67 | 76 | 
 | 
| @@ -133,14 +142,23 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901 | 
| 133 | 142 |                     input_node = new_input_node | 
| 134 | 143 |             else: | 
| 135 | 144 |                 input_weight_node, input_bias_node = input_node.args[1:3] | 
| 136 |  | -                assert ( | 
|  | 145 | +                if not ( | 
| 137 | 146 |                     isinstance(input_weight_node, Node) | 
| 138 | 147 |                     and input_weight_node.op == "placeholder" | 
| 139 |  | -                ), "Parameter weight of convolution must be a placeholder" | 
| 140 |  | -                assert (input_bias_node is None) or ( | 
| 141 |  | -                    isinstance(input_weight_node, Node) | 
| 142 |  | -                    and input_weight_node.op == "placeholder" | 
| 143 |  | -                ), "Parameter bias of convolution must be a placeholder or None" | 
|  | 148 | +                ): | 
|  | 149 | +                    raise RuntimeError( | 
|  | 150 | +                        "Parameter weight of convolution must be a placeholder" | 
|  | 151 | +                    ) | 
|  | 152 | +                if not ( | 
|  | 153 | +                    (input_bias_node is None) | 
|  | 154 | +                    or ( | 
|  | 155 | +                        isinstance(input_weight_node, Node) | 
|  | 156 | +                        and input_weight_node.op == "placeholder" | 
|  | 157 | +                    ) | 
|  | 158 | +                ): | 
|  | 159 | +                    raise RuntimeError( | 
|  | 160 | +                        "Parameter bias of convolution must be a placeholder or None" | 
|  | 161 | +                    ) | 
| 144 | 162 | 
 | 
| 145 | 163 |                 input_weight_tensor = torch.Tensor( | 
| 146 | 164 |                     get_param(self.exported_program, input_weight_node) | 
|  | 
0 commit comments