@@ -85,8 +85,12 @@ def define_node(
8585 ) -> None :
8686 import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
8787
88- input_tensor = inputs [0 ]
89- assert input_tensor .dtype == ts .DType .INT8
88+ supported_dtypes = [ts .DType .INT8 ]
89+ if inputs [0 ].dtype not in supported_dtypes :
90+ raise TypeError (
91+ f"IO data type needs to be one of { supported_dtypes } , got "
92+ f'"{ inputs [0 ].dtype } "'
93+ )
9094
9195 accumulator_type = ts .DType .INT32
9296
@@ -118,9 +122,12 @@ def define_node(
118122 ) -> None :
119123 import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120124
121- assert (
122- inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
123- ), "Only FP32 and INT8 supported"
125+ supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
126+ if inputs [0 ].dtype not in supported_dtypes :
127+ raise TypeError (
128+ f"IO data type needs to be one of { supported_dtypes } , got "
129+ f'"{ inputs [0 ].dtype } "'
130+ )
124131
125132 if inputs [0 ].dtype == ts .DType .INT8 :
126133 super ().define_node (node , tosa_graph , inputs , output )
@@ -205,8 +212,12 @@ def define_node(
205212 ) -> None :
206213 import serializer .tosa_serializer as ts # type: ignore
207214
208- input_tensor = inputs [0 ]
209- assert input_tensor .dtype == ts .DType .INT8
215+ supported_dtypes = [ts .DType .INT8 ]
216+ if inputs [0 ].dtype not in supported_dtypes :
217+ raise TypeError (
218+ f"IO data type needs to be one of { supported_dtypes } , got "
219+ f'"{ inputs [0 ].dtype } "'
220+ )
210221
211222 accumulator_type = ts .DType .INT32
212223
@@ -241,9 +252,12 @@ def define_node(
241252 ) -> None :
242253 import serializer .tosa_serializer as ts # type: ignore
243254
244- assert (
245- inputs [0 ].dtype == ts .DType .INT8 or inputs [0 ].dtype == ts .DType .FP32
246- ), "Only FP32 and INT8 supported"
255+ supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
256+ if inputs [0 ].dtype not in supported_dtypes :
257+ raise TypeError (
258+ f"IO data type needs to be one of { supported_dtypes } , got "
259+ f'"{ inputs [0 ].dtype } "'
260+ )
247261
248262 if inputs [0 ].dtype == ts .DType .INT8 :
249263 super ().define_node (node , tosa_graph , inputs , output )
0 commit comments