@@ -40,9 +40,19 @@ def define_node(
4040 ) -> None :
4141 # Specification (0.80) states that input and output types
4242 # should all be the same
43- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
43+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
44+ raise TypeError (
45+ f"All IO needs to have the same data type, got input 1: "
46+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
47+ f"{ output .dtype } "
48+ )
49+
4450 # Handle int8 (quantized) and int32
45- assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
51+ supported_dtypes = [ts .DType .INT8 , ts .DType .INT32 ]
52+ if inputs [0 ].dtype not in supported_dtypes :
53+ raise TypeError (
54+ f'IO data type needs to be { supported_dtypes } , got "{ inputs [0 ].dtype } "'
55+ )
4656
4757 if inputs [0 ].dtype == ts .DType .INT8 :
4858 rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
@@ -97,15 +107,27 @@ def define_node(
97107 ) -> None :
98108 # Specification (0.80) states that input and output types
99109 # should all be the same
100- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
110+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
111+ raise TypeError (
112+ f"All IO needs to have the same data type, got input 1: "
113+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
114+ f"{ output .dtype } "
115+ )
101116
102117 if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
103118 # Call the inherited define_node for handling integers
104119 super ().define_node (node , tosa_graph , inputs , output )
105120 else :
106121 # FP32 Sub lowering
107- assert inputs [0 ].dtype == ts .DType .FP32
108- assert output .dtype == ts .DType .FP32
122+ if (
123+ inputs [0 ].dtype != ts .DType .FP32
124+ or inputs [1 ].dtype != ts .DType .FP32
125+ or output .dtype != ts .DType .FP32
126+ ):
127+ raise TypeError (
128+ f"All IO needs to have data type fp32. Got: { inputs [0 ].dtype } , "
129+ f"input 2: { inputs [1 ].dtype } and output: { output .dtype } "
130+ )
109131
110132 # MI lowering
111133 tosa_graph .addOperator (
0 commit comments