Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions backends/arm/operators/op_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,11 @@ def define_node(
validate_same_dtype(self.target, [*inputs, output])

# Handle int8 (quantized) and int32
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
if inputs[0].dtype not in supported_dtypes:
raise TypeError(
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
)

scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
Expand Down Expand Up @@ -228,8 +232,15 @@ def define_node(
super().define_node(node, tosa_graph, inputs, output)
else:
# FP32 Sub lowering
assert inputs[0].dtype == ts.DType.FP32
assert output.dtype == ts.DType.FP32
if (
inputs[0].dtype != ts.DType.FP32
or inputs[1].dtype != ts.DType.FP32
or output.dtype != ts.DType.FP32
):
raise TypeError(
f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, "
f"input 2: {inputs[1].dtype} and output: {output.dtype}"
)

# MI lowering
tosa_graph.addOperator(
Expand Down
9 changes: 6 additions & 3 deletions backends/arm/operators/op_upsample_bilinear2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,12 @@ def define_node(
def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

assert in_int16_range(scale_n_yx)
assert in_int16_range(scale_d_yx)
assert in_int16_range(border_yx)
if not in_int16_range(scale_n_yx):
raise ValueError("scale_n_yx is out of the int16 range")
if not in_int16_range(scale_d_yx):
raise ValueError("scale_d_yx is out of the int16 range")
if not in_int16_range(border_yx):
raise ValueError("border_yx is out of the int16 range")

scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]

Expand Down
14 changes: 8 additions & 6 deletions backends/arm/operators/op_upsample_nearest2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ def define_node(
validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

assert (
inputs[0].shape is not None and output.shape is not None
), "Only static shapes are supported"
if inputs[0].shape is None or output.shape is None:
raise ValueError("Only static shapes are supported")

# tosa_shape output is NHWC, take HW
input_size_yx = torch.tensor(
Expand All @@ -121,9 +120,12 @@ def define_node(
def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

assert in_int16_range(scale_n_yx)
assert in_int16_range(scale_d_yx)
assert in_int16_range(border_yx)
if not in_int16_range(scale_n_yx):
raise ValueError("scale_n_yx is out of the int16 range")
if not in_int16_range(scale_d_yx):
raise ValueError("scale_d_yx is out of the int16 range")
if not in_int16_range(border_yx):
raise ValueError("border_yx is out of the int16 range")

scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]
scales_tensor = tosa_graph.addConst(
Expand Down
Loading