@@ -202,8 +202,12 @@ def embedding_2bit(
202202 weight_quant_max : int ,
203203 indices : torch .Tensor ,
204204) -> torch .Tensor :
205- assert weight_quant_min == - 2 , "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
206- assert weight_quant_max == 1 , "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
205+ assert (
206+ weight_quant_min == - 2
207+ ), "embedding_2bit in ExecuTorch expects weight_quant_min == -2"
208+ assert (
209+ weight_quant_max == 1
210+ ), "embedding_2bit in ExecuTorch expects weight_quant_max == 1"
207211
208212 embedding_weight_checks (weight , weight_scales , weight_zero_points )
209213 group_size = (4 * weight .size (1 )) // (
@@ -260,8 +264,12 @@ def embedding_2bit_dtype(
260264 indices : torch .Tensor ,
261265 dtype : Optional [torch .dtype ],
262266) -> torch .Tensor :
263- assert weight_quant_min == - 2 , "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
264- assert weight_quant_max == 1 , "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
267+ assert (
268+ weight_quant_min == - 2
269+ ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_min == -2"
270+ assert (
271+ weight_quant_max == 1
272+ ), "embedding_2bit_dtype in ExecuTorch expects weight_quant_max == 1"
265273
266274 embedding_weight_checks (weight , weight_scales , weight_zero_points )
267275 group_size = (4 * weight .size (1 )) // (
@@ -340,8 +348,12 @@ def embedding_4bit(
340348 weight_quant_max : int ,
341349 indices : torch .Tensor ,
342350) -> torch .Tensor :
343- assert weight_quant_min == - 8 , "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
344- assert weight_quant_max == 7 , "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
351+ assert (
352+ weight_quant_min == - 8
353+ ), "embedding_4bit in ExecuTorch expects weight_quant_min == -8"
354+ assert (
355+ weight_quant_max == 7
356+ ), "embedding_4bit in ExecuTorch expects weight_quant_max == 7"
345357
346358 embedding_weight_checks (weight , weight_scales , weight_zero_points )
347359 group_size = (2 * weight .size (1 )) // (
@@ -396,8 +408,12 @@ def embedding_4bit_dtype(
396408 indices : torch .Tensor ,
397409 dtype : Optional [torch .dtype ],
398410) -> torch .Tensor :
399- assert weight_quant_min == - 8 , "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
400- assert weight_quant_max == 7 , "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
411+ assert (
412+ weight_quant_min == - 8
413+ ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_min == -8"
414+ assert (
415+ weight_quant_max == 7
416+ ), "embedding_4bit_dtype in ExecuTorch expects weight_quant_max == 7"
401417
402418 embedding_weight_checks (weight , weight_scales , weight_zero_points )
403419 group_size = (2 * weight .size (1 )) // (
0 commit comments