File tree Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -46,13 +46,20 @@ def define_node(
46
46
input_zp = cast (int , node .args [3 ])
47
47
output_zp = cast (int , node .args [4 ])
48
48
49
- if input_dtype != map_dtype (torch .int8 , self .tosa_spec ) and input_zp != 0 :
49
+ if (
50
+ input_dtype
51
+ not in [
52
+ map_dtype (torch .int8 , self .tosa_spec ),
53
+ map_dtype (torch .int16 , self .tosa_spec ),
54
+ ]
55
+ and input_zp != 0
56
+ ):
50
57
raise ValueError (
51
- f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ input_dtype = } , { input_zp = } "
58
+ f"If input dtype is not int8 or int16 , input_zp must be 0. Got input_dtype{ input_dtype = } , { input_zp = } "
52
59
)
53
- if output_dtype != torch .int8 and output_zp != 0 :
60
+ if output_dtype not in [ torch .int8 , torch . int16 ] and output_zp != 0 :
54
61
raise ValueError (
55
- f"If output dtype is not int8, output_zp must be 0. Got { ts .DTypeNames [output_dtype ]} , { output_zp = } "
62
+ f"If output dtype is not int8 or int16 , output_zp must be 0. Got { ts .DTypeNames [output_dtype ]} , { output_zp = } "
56
63
)
57
64
58
65
build_rescale (
You can’t perform that action at this time.
0 commit comments