@@ -41,9 +41,18 @@ def define_node(
4141 ) -> None :
4242 # Specification (0.80) states that input and output types
4343 # should all be the same
44- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
44+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
45+ raise TypeError (
46+ f"All IO needs to have the same data type, got input 1: "
47+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
48+ f"{ output .dtype } "
49+ )
4550 # Handle int8 (quantized) and int32
46- assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
51+ supported_dtypes = [ts .DType .INT8 , ts .DType .INT32 ]
52+ if inputs [0 ].dtype not in supported_dtypes :
53+ raise TypeError (
54+ f'IO data type needs to be { supported_dtypes } , got "{ inputs [0 ].dtype } "'
55+ )
4756
4857 dim_order = (
4958 inputs [0 ].dim_order
@@ -105,15 +114,22 @@ def define_node(
105114 ) -> None :
106115 # Specification (0.80) states that input and output types
107116 # should all be the same
108- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
117+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
118+ raise TypeError (
119+ f"All IO needs to have the same data type, got input 1: "
120+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
121+ f"{ output .dtype } "
122+ )
109123
110124 if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
111125 # Call the inherited define_node for handling integers
112126 super ().define_node (node , tosa_graph , inputs , output )
113127 else :
114128 # FP32 Add lowering
115- assert inputs [0 ].dtype == ts .DType .FP32
116- assert output .dtype == ts .DType .FP32
129+ if inputs [0 ].dtype != ts .DType .FP32 :
130+ raise TypeError (
131+ f"Expected IO data type to be FP32, got { inputs [0 ].dtype } "
132+ )
117133
118134 input1 , input2 = tutils .reshape_for_broadcast (tosa_graph , inputs )
119135
0 commit comments