Skip to content

Commit f198c27

Browse files
Arm backend: Convert remaining asserts that should be exceptions (#14590)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent fd5f946 commit f198c27

File tree

4 files changed

+50
-15
lines changed

4 files changed

+50
-15
lines changed

backends/arm/_passes/fuse_batchnorm2d_pass.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
create_node,
1313
get_first_fake_tensor,
1414
)
15+
from executorch.backends.arm.common.debug import get_node_debug_info
1516
from executorch.backends.transforms.utils import (
1617
create_constant_placeholder,
1718
delete_constant_placeholder,
@@ -60,8 +61,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
6061
input_node = node.all_input_nodes[0]
6162
is_single_user = len(input_node.users) == 1
6263
bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = node.args[1:5]
63-
assert bn_mean_node is not None, "Batchnorm mean node cannot be None."
64-
assert bn_var_node is not None, "Batchnorm var node cannot be None."
64+
if bn_mean_node is None:
65+
raise RuntimeError(
66+
"BatchNorm mean buffer missing for node: "
67+
f"{get_node_debug_info(node, graph_module)}"
68+
)
69+
if bn_var_node is None:
70+
raise RuntimeError(
71+
"BatchNorm variance buffer missing for node: "
72+
f"{get_node_debug_info(node, graph_module)}"
73+
)
6574

6675
epsilon = node.args[-1]
6776

@@ -133,14 +142,23 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
133142
input_node = new_input_node
134143
else:
135144
input_weight_node, input_bias_node = input_node.args[1:3]
136-
assert (
145+
if not (
137146
isinstance(input_weight_node, Node)
138147
and input_weight_node.op == "placeholder"
139-
), "Parameter weight of convolution must be a placeholder"
140-
assert (input_bias_node is None) or (
141-
isinstance(input_weight_node, Node)
142-
and input_weight_node.op == "placeholder"
143-
), "Parameter bias of convolution must be a placeholder or None"
148+
):
149+
raise RuntimeError(
150+
"Parameter weight of convolution must be a placeholder"
151+
)
152+
if not (
153+
(input_bias_node is None)
154+
or (
155+
isinstance(input_weight_node, Node)
156+
and input_weight_node.op == "placeholder"
157+
)
158+
):
159+
raise RuntimeError(
160+
"Parameter bias of convolution must be a placeholder or None"
161+
)
144162

145163
input_weight_tensor = torch.Tensor(
146164
get_param(self.exported_program, input_weight_node)

backends/arm/arm_vela.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ def vela_bin_pack_io(prefix, data):
3434
io_elem_size = data[prefix + "_elem_size"][i]
3535
io_offset = data[prefix + "_offset"][i]
3636
io_region = data[prefix + "_region"][i]
37-
assert len(io_shape) == vela_io_shape_dims
37+
if len(io_shape) != vela_io_shape_dims:
38+
raise ValueError(
39+
f"Expected {vela_io_shape_dims}D shape, got {len(io_shape)}D"
40+
)
3841
inp_pad = io_shape.tolist()
3942
io_struct = struct.pack(
4043
"<iiiiiiiii", *inp_pad, io_elem_size, io_offset, io_region

backends/arm/common/arm_compile_spec.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def validate(self):
126126

127127
def to_list(self):
128128
"""Get the ArmCompileSpec in list form."""
129-
assert self.tosa_spec
129+
if not self.tosa_spec:
130+
raise ValueError("tosa_spec must be set before calling to_list()")
130131

131132
# Always supply a TOSA version
132133
compile_spec = [

backends/arm/process_node.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,16 @@ def process_inputs_to_parameters(
106106
) from e
107107
parameter_data = get_param(edge_program, node)
108108

109-
assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
109+
if not isinstance(parameter_data, torch.Tensor):
110+
raise TypeError(
111+
f"Expected parameter '{node.name}' to be a torch.Tensor, got "
112+
f"{type(parameter_data).__name__}"
113+
)
110114
parameter_values = parameter_data.detach().numpy()
111115

112116
if tosa_arg.dtype == torch.float32:
113-
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
117+
if not tosa_spec.support_float():
118+
raise ValueError(f"{tosa_spec} doesn't support float operations")
114119

115120
# Handle special case for INT48 tensors
116121
special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
@@ -142,7 +147,11 @@ def process_inputs_to_buffers(
142147
) from e
143148
buffer_data = get_buffer(edge_program, node)
144149

145-
assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"
150+
if not isinstance(buffer_data, torch.Tensor):
151+
raise TypeError(
152+
f"Expected buffer '{node.name}' to be a torch.Tensor, got "
153+
f"{type(buffer_data).__name__}"
154+
)
146155
buffer_values = buffer_data.detach().numpy()
147156

148157
# TODO: fragile code for temporary fix
@@ -183,8 +192,12 @@ def process_placeholder(
183192
tosa_spec: TosaSpecification,
184193
):
185194
"""Wrapper for processing and serializing all types of placeholders"""
186-
assert node.name == node.target, "Expect placeholder name and target to match"
187-
assert 0 == len(node.args), "Can't handle default input values"
195+
if node.name != node.target:
196+
raise ValueError(
197+
f"Placeholder name '{node.name}' does not match target '{node.target}'"
198+
)
199+
if len(node.args) != 0:
200+
raise ValueError(f"Placeholder '{node.name}' must not have default values")
188201

189202
if node.name in edge_program.graph_signature.user_inputs:
190203
process_inputs(node, tosa_graph, tosa_spec)

0 commit comments

Comments (0)