@@ -44,7 +44,7 @@ def define_node(
4444 import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
4545
4646 validate_num_inputs (self .target , inputs , 1 )
47- validate_same_dtype (self .target , [* inputs , output ])
47+ validate_same_dtype (self .target , [* inputs , output ], ts )
4848
4949 # Handle int8 (quantized) and int32
5050 if not (inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]):
@@ -106,7 +106,7 @@ def define_node(
106106 import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
107107
108108 validate_num_inputs (self .target , inputs , 1 )
109- validate_same_dtype (self .target , [* inputs , output ])
109+ validate_same_dtype (self .target , [* inputs , output ], ts )
110110
111111 if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
112112 # Call the inherited define_node for handling integers
@@ -153,7 +153,7 @@ def define_node(
153153 import serializer .tosa_serializer as ts # type: ignore
154154
155155 validate_num_inputs (self .target , inputs , 1 )
156- validate_same_dtype (self .target , [* inputs , output ])
156+ validate_same_dtype (self .target , [* inputs , output ], ts )
157157
158158 # Handle int8 (quantized) and int32
159159 if not (inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]):
@@ -216,7 +216,7 @@ def define_node(
216216 import serializer .tosa_serializer as ts # type: ignore
217217
218218 validate_num_inputs (self .target , inputs , 1 )
219- validate_same_dtype (self .target , [* inputs , output ])
219+ validate_same_dtype (self .target , [* inputs , output ], ts )
220220
221221 if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
222222 # Call the inherited define_node for handling integers
0 commit comments