@@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11251125
11261126 # With larger block sizes, we can expect this to blow up.
11271127 # At blocksize>=1024, don't even bother looking at relerr.
1128- if blocksize <= 64 :
1129- assert err .item () < 0.1
1130- assert relerr .item () < 0.28
1131- elif blocksize <= 256 :
1132- assert err .item () < 0.11
1133- assert relerr .item () < 0.30
1134- elif blocksize <= 512 :
1135- assert err .item () < 0.12
1136- assert relerr .item () < 0.31
1137- elif quant_type == "fp4" :
1138- # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
1139- assert err .item () < 0.08 + math .log2 (blocksize ) * 4e-2
1140- else :
1141- # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
1142- assert err .item () < math .log2 (blocksize ) * 8e-2
1128+ #
1129+ # Actually, the above is not true anymore after fixing the integer packing bug.
1130+ # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
1131+ error_dict = dict ()
1132+ error_dict ["fp4" ] = dict ()
1133+ error_dict ["nf4" ] = dict ()
1134+ error_dict ["fp4" ]["err" ] = {
1135+ 64 : 0.096545 ,
1136+ 128 : 0.102947 ,
1137+ 256 : 0.108685 ,
1138+ 512 : 0.114087 ,
1139+ 1024 : 0.119312 ,
1140+ 2048 : 0.124460 ,
1141+ 4096 : 0.129573 ,
1142+ }
1143+ error_dict ["fp4" ]["rel_err" ] = {
1144+ 64 : 0.260130 ,
1145+ 128 : 0.275734 ,
1146+ 256 : 0.289842 ,
1147+ 512 : 0.302852 ,
1148+ 1024 : 0.314982 ,
1149+ 2048 : 0.326402 ,
1150+ 4096 : 0.337228 ,
1151+ }
1152+
1153+ error_dict ["nf4" ]["err" ] = {
1154+ 64 : 0.072792 ,
1155+ 128 : 0.076835 ,
1156+ 256 : 0.080326 ,
1157+ 512 : 0.083535 ,
1158+ 1024 : 0.086603 ,
1159+ 2048 : 0.089592 ,
1160+ 4096 : 0.092537 ,
1161+ }
1162+ error_dict ["nf4" ]["rel_err" ] = {
1163+ 64 : 0.203299 ,
1164+ 128 : 0.215252 ,
1165+ 256 : 0.226044 ,
1166+ 512 : 0.236021 ,
1167+ 1024 : 0.245365 ,
1168+ 2048 : 0.254146 ,
1169+ 4096 : 0.262457 ,
1170+ }
1171+
1172+ assert err < error_dict [quant_type ]["err" ][blocksize ] + 1e-3
1173+ assert relerr < error_dict [quant_type ]["rel_err" ][blocksize ] + 1e-3
11431174
11441175 @pytest .mark .parametrize ("device" , get_available_devices ())
11451176 @pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
0 commit comments