@@ -1125,75 +1125,46 @@ def quantized_relu_common(
 
 
 def quantized_relu_variant(
-    per_tensor: bool,
     dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
     """Create a quantized relu variant with type checking."""
 
     def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
         def variant(
             X: torch.Tensor,
-            X_zero_point: torch.Tensor | int,
+            X_zero_point: int,
             out_zero_point: int,
-            out_multiplier: torch.Tensor | int,
-            out_shift: torch.Tensor | int,
+            out_multiplier: int,
+            out_shift: int,
         ) -> torch.Tensor:
-            if per_tensor:
-                if dtype and X.dtype != dtype:
-                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
-
-                assert isinstance(out_shift, int)
-                assert isinstance(out_multiplier, int)
-                _out_shift = out_shift
-                _out_multiplier = out_multiplier
-            else:
-                assert isinstance(out_multiplier, torch.Tensor)
-                if out_multiplier.numel() > 1:
-                    raise ValueError("Only scalar out_multiplier is supported")
-
-                assert isinstance(out_shift, torch.Tensor)
-                if out_shift.numel() > 1:
-                    raise ValueError("Only scalar out_shift is supported")
-
-                assert isinstance(X_zero_point, torch.Tensor)
-                if X_zero_point.shape != X.shape:
-                    raise ValueError(
-                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
-                    )
-
-                _out_multiplier = int(out_multiplier.item())
-                _out_shift = int(out_shift.item())
+            if dtype and X.dtype != dtype:
+                raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
 
             return quantized_relu_common(
                 X,
                 X_zero_point,
                 out_zero_point,
-                _out_multiplier,
-                _out_shift,
+                out_multiplier,
+                out_shift,
             )
 
         return variant
 
     return decorator
 
 
-@impl(m, "quantized_relu")
-@quantized_relu_variant(False)
-def quantized_relu() -> torch.Tensor: ...
-
-
 @impl(m, "quantized_relu.per_tensor")
-@quantized_relu_variant(True)
+@quantized_relu_variant()
 def quantized_relu_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
-@quantized_relu_variant(True, torch.int8)
+@quantized_relu_variant(torch.int8)
 def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
-@quantized_relu_variant(True, torch.uint8)
+@quantized_relu_variant(torch.uint8)
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
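Note: quantized_relu_common is only called in this hunk, not defined. As a reading aid, the sketch below shows what a fixed-point requantized ReLU of this shape typically does with these five arguments. The name quantized_relu_sketch, the Q31 multiplier convention, and the rounding and saturation details are assumptions for illustration, not the repo's actual implementation.

import torch


def quantized_relu_sketch(
    X: torch.Tensor,
    X_zero_point: int,
    out_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    """Assumed semantics of quantized_relu_common, for illustration only."""
    # ReLU in the quantized domain: clamp everything at or below the
    # input zero point, accumulating in int32 to avoid overflow.
    acc = (X.to(torch.int32) - X_zero_point).clamp(min=0)
    # Requantize by the fixed-point factor out_multiplier * 2**out_shift / 2**31
    # (assumed Q31 convention), then re-offset by the output zero point.
    scale = out_multiplier * (2.0 ** out_shift) / (1 << 31)
    out = torch.round(acc.to(torch.float32) * scale) + out_zero_point
    # Saturate to the representable range of the input dtype.
    info = torch.iinfo(X.dtype)
    return out.clamp(info.min, info.max).to(X.dtype)


# Identity rescale on an int8 (asym8s) tensor with both zero points at 0.
x = torch.tensor([-128, -64, 0, 64, 127], dtype=torch.int8)
print(quantized_relu_sketch(x, 0, 0, 1 << 31, 0))
# tensor([  0,   0,   0,  64, 127], dtype=torch.int8)

Under these assumed semantics, the diff's change is purely an interface cleanup: the tensor-accepting branch only ever supported scalar (numel() == 1) multipliers and shifts, so collapsing everything onto the int-only per-tensor path removes the runtime type dispatch without changing what gets computed.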