@@ -639,7 +639,7 @@ def test_quantization(self):
639639            ("int8wo" , np .array ([0.0505 , 0.0742 , 0.1367 , 0.0429 , 0.0585 , 0.1386 , 0.0585 , 0.0703 , 0.1367 , 0.0566 , 0.0703 , 0.1464 , 0.0546 , 0.0703 , 0.1425 , 0.0546 , 0.3535 , 0.7578 , 0.5000 , 0.4062 , 0.7656 , 0.5117 , 0.4121 , 0.7656 , 0.5117 , 0.3984 , 0.7578 , 0.5234 , 0.4023 , 0.7382 , 0.5390 , 0.4570 ])),
640640            ("int8dq" , np .array ([0.0546 , 0.0761 , 0.1386 , 0.0488 , 0.0644 , 0.1425 , 0.0605 , 0.0742 , 0.1406 , 0.0625 , 0.0722 , 0.1523 , 0.0625 , 0.0742 , 0.1503 , 0.0605 , 0.3886 , 0.7968 , 0.5507 , 0.4492 , 0.7890 , 0.5351 , 0.4316 , 0.8007 , 0.5390 , 0.4179 , 0.8281 , 0.5820 , 0.4531 , 0.7812 , 0.5703 , 0.4921 ])),
641641        ]
642-   
642+ 
643643        if  TorchAoConfig ._is_cuda_capability_atleast_8_9 ():
644644            QUANTIZATION_TYPES_TO_TEST .extend ([
645645                ("float8wo_e4m3" , np .array ([0.0546 , 0.0722 , 0.1328 , 0.0468 , 0.0585 , 0.1367 , 0.0605 , 0.0703 , 0.1328 , 0.0625 , 0.0703 , 0.1445 , 0.0585 , 0.0703 , 0.1406 , 0.0605 , 0.3496 , 0.7109 , 0.4843 , 0.4042 , 0.7226 , 0.5000 , 0.4160 , 0.7031 , 0.4824 , 0.3886 , 0.6757 , 0.4667 , 0.3710 , 0.6679 , 0.4902 , 0.4238 ])),
@@ -680,7 +680,7 @@ def test_serialization(self):
680680
681681        loaded_output  =  loaded_pipe (** inputs )[0 ].flatten ()
682682        self .assertTrue (np .allclose (output , loaded_output , atol = 1e-3 , rtol = 1e-3 ))
683-      
683+ 
684684    def  test_memory_footprint_int4wo (self ):
685685        # The original checkpoints are in bf16 and about 24 GB 
686686        expected_memory_in_gb  =  6.0 
@@ -693,7 +693,7 @@ def test_memory_footprint_int4wo(self):
693693        )
694694        int4wo_memory_in_gb  =  get_model_size_in_bytes (transformer ) /  1024 ** 3 
695695        self .assertTrue (int4wo_memory_in_gb  <  expected_memory_in_gb )
696-      
696+ 
697697    def  test_memory_footprint_int8wo (self ):
698698        # The original checkpoints are in bf16 and about 24 GB 
699699        expected_memory_in_gb  =  12.0 
0 commit comments