@@ -105,7 +105,7 @@ def define_node(
105105 import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
106106
107107 validate_num_inputs (self .target , inputs , [3 , 4 , 6 ])
108- validate_same_dtype (self .target , [inputs [0 ], output ])
108+ validate_same_dtype (self .target , [inputs [0 ], output ], ts )
109109
110110 supported_dtypes = [ts .DType .INT8 ]
111111 if inputs [0 ].dtype not in supported_dtypes :
@@ -145,7 +145,7 @@ def define_node(
145145 import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
146146
147147 validate_num_inputs (self .target , inputs , [3 , 4 , 6 ])
148- validate_same_dtype (self .target , [inputs [0 ], output ])
148+ validate_same_dtype (self .target , [inputs [0 ], output ], ts )
149149
150150 supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
151151 if inputs [0 ].dtype not in supported_dtypes :
@@ -252,7 +252,7 @@ def define_node(
252252 import serializer .tosa_serializer as ts # type: ignore
253253
254254 validate_num_inputs (self .target , inputs , [3 , 4 , 6 ])
255- validate_same_dtype (self .target , [inputs [0 ], output ])
255+ validate_same_dtype (self .target , [inputs [0 ], output ], ts )
256256
257257 supported_dtypes = [ts .DType .INT8 ]
258258 if inputs [0 ].dtype not in supported_dtypes :
@@ -295,7 +295,7 @@ def define_node(
295295 import serializer .tosa_serializer as ts # type: ignore
296296
297297 validate_num_inputs (self .target , inputs , [3 , 4 , 6 ])
298- validate_same_dtype (self .target , [inputs [0 ], output ])
298+ validate_same_dtype (self .target , [inputs [0 ], output ], ts )
299299
300300 supported_dtypes = [ts .DType .INT8 , ts .DType .FP32 ]
301301 if inputs [0 ].dtype not in supported_dtypes :
0 commit comments