Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions backends/arm/operators/op_upsample_bilinear2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
assert (
inputs[0].shape is not None and output.shape is not None
), "Only static shapes are supported"
if inputs[0].shape is None or output.shape is None:
raise ValueError("Only static shapes are supported")

input_dtype = inputs[0].dtype

Expand All @@ -55,9 +54,12 @@ def define_node(
def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

assert in_int16_range(scale_n_yx)
assert in_int16_range(scale_d_yx)
assert in_int16_range(border_yx)
if not in_int16_range(scale_n_yx):
raise ValueError("scale_n_yx is out of the int16 range")
if not in_int16_range(scale_d_yx):
raise ValueError("scale_d_yx is out of the int16 range")
if not in_int16_range(border_yx):
raise ValueError("border_yx is out of the int16 range")

attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(
Expand Down
14 changes: 8 additions & 6 deletions backends/arm/operators/op_upsample_nearest2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

assert (
inputs[0].shape is not None and output.shape is not None
), "Only static shapes are supported"
if inputs[0].shape is None or output.shape is None:
raise ValueError("Only static shapes are supported")

# tosa_shape output is NHWC, take HW
input_size_yx = torch.tensor(
Expand All @@ -55,9 +54,12 @@ def define_node(
def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

assert in_int16_range(scale_n_yx)
assert in_int16_range(scale_d_yx)
assert in_int16_range(border_yx)
if not in_int16_range(scale_n_yx):
raise ValueError("scale_n_yx is out of the int16 range")
if not in_int16_range(scale_d_yx):
raise ValueError("scale_d_yx is out of the int16 range")
if not in_int16_range(border_yx):
raise ValueError("border_yx is out of the int16 range")

attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(
Expand Down
Loading