@@ -378,17 +378,17 @@ def forward(self, indices):
378378 # )
379379
380380 def test_embedding_torchao (self ) -> None :
381- for bit_width , test_dtype_variant , test_per_group in zip (
381+ for bit_width , use_dtype_variant , test_per_group in zip (
382382 [2 , 4 , 8 ], [True , False ], [True , False ]
383383 ):
384- self ._test_embedding_torchao (bit_width , test_dtype_variant , test_per_group )
384+ self ._test_embedding_torchao (bit_width , use_dtype_variant , test_per_group )
385385
386386 def _test_embedding_torchao (
387- self , bit_width : int , test_dtype_variant : bool , test_per_group : bool
387+ self , bit_width : int , use_dtype_variant : bool , test_per_group : bool
388388 ) -> None :
389389 assert bit_width in [2 , 4 , 8 ]
390390 embedding_suffix = f"{ bit_width } bit" if bit_width < 8 else "byte"
391- if test_dtype_variant :
391+ if use_dtype_variant :
392392 embedding_suffix = f"{ embedding_suffix } _dtype"
393393
394394 indices = torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 )
@@ -399,7 +399,7 @@ def _test_embedding_torchao(
399399
400400 # torchao adds a dtype cast to match embeddings original weight type
401401 # this does not happen for float32 because it is the default dtype
402- model = model .to (torch .float16 ) if test_dtype_variant else model
402+ model = model .to (torch .float16 ) if use_dtype_variant else model
403403
404404 # quantize the model
405405 granularity = PerGroup (32 ) if test_per_group else PerAxis (0 )
0 commit comments