@@ -202,6 +202,9 @@ def embedding_2bit(
202202 weight_quant_max : int ,
203203 indices : torch .Tensor ,
204204) -> torch .Tensor :
205+ assert weight_quant_min == - 2 , "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
206+ assert weight_quant_max == 1 , "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
207+
205208 embedding_weight_checks (weight , weight_scales , weight_zero_points )
206209 group_size = (4 * weight .size (1 )) // (
207210 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -257,6 +260,9 @@ def embedding_2bit_dtype(
257260 indices : torch .Tensor ,
258261 dtype : Optional [torch .dtype ],
259262) -> torch .Tensor :
263+ assert weight_quant_min == - 2 , "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
264+ assert weight_quant_max == 1 , "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
265+
260266 embedding_weight_checks (weight , weight_scales , weight_zero_points )
261267 group_size = (4 * weight .size (1 )) // (
262268 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -334,6 +340,9 @@ def embedding_4bit(
334340 weight_quant_max : int ,
335341 indices : torch .Tensor ,
336342) -> torch .Tensor :
343+ assert weight_quant_min == - 8 , "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
344+ assert weight_quant_max == 7 , "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
345+
337346 embedding_weight_checks (weight , weight_scales , weight_zero_points )
338347 group_size = (2 * weight .size (1 )) // (
339348 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
@@ -387,6 +396,9 @@ def embedding_4bit_dtype(
387396 indices : torch .Tensor ,
388397 dtype : Optional [torch .dtype ],
389398) -> torch .Tensor :
399+ assert weight_quant_min == - 8 , "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
400+ assert weight_quant_max == 7 , "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
401+
390402 embedding_weight_checks (weight , weight_scales , weight_zero_points )
391403 group_size = (2 * weight .size (1 )) // (
392404 weight_scales .size (1 ) if weight_scales .dim () == 2 else 1
0 commit comments