@@ -2262,7 +2262,7 @@ def test_fp4_quant(dtype):
22622262 A2 = F .dequantize_fp4 (qa , SA )
22632263
22642264 err = (A1 - A2 ).abs ().float ()
2265- relerr = (err / A1 .abs ().float ()).mean ()
2265+ relerr = (err / ( A1 .abs ().float () + 1e-8 )).mean ()
22662266 idx = err > 1.0
22672267 err = err .mean ()
22682268
@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant, kind):
    """Compare the 4-bit GEMV inference kernel against the training matmul path.

    For each hidden size, runs 100 random trials of three matmuls:
      C1 = bnb.matmul_4bit(A, qB.t(), state)   (training/autograd path)
      C2 = F.gemv_4bit(A, qB.t(), state)       (fused inference kernel)
      C3 = torch.matmul(A, B.t())              (full-precision reference)

    and accumulates three pairwise error statistics:
      err1 = |C1 - C2|  (inference vs training: the two 4-bit paths must agree)
      err2 = |C3 - C2|  (reference vs inference)
      err3 = |C3 - C1|  (reference vs training)

    The mean/relative/max ratios err2/err3 must be ~1.0, i.e. the inference
    kernel loses no more precision relative to torch than the training path.
    """
    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
        errs1 = []
        errs2 = []
        errs3 = []
        relerrs1 = []
        relerrs2 = []
        relerrs3 = []
        max_errs1 = []
        max_errs2 = []
        max_errs3 = []

        for i in range(100):
            # Shapes mimic transformer sub-layers: fc1 (dim -> 4*dim),
            # fc2 (4*dim -> dim), attention projection (dim -> dim), and
            # packed QKV projection (dim -> 3*dim).
            if kind == 'fc1':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim * 4, dim, dtype=dtype, device='cuda') / math.sqrt(dim)
            elif kind == 'fc2':
                A = torch.randn(1, 4 * dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, 4 * dim, dtype=dtype, device='cuda') / math.sqrt(dim)
            elif kind == 'attn':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim, dim, dtype=dtype, device='cuda') / math.sqrt(dim)
            elif kind == 'attn_packed':
                A = torch.randn(1, dim, dtype=dtype, device='cuda')
                B = torch.randn(dim * 3, dim, dtype=dtype, device='cuda') / math.sqrt(dim)

            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
            C3 = torch.matmul(A, B.t())
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            # requires_grad routes bnb.matmul_4bit through the training path
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)

            err1 = (C1 - C2).abs().float()
            err2 = (C3 - C2).abs().float()
            err3 = (C3 - C1).abs().float()

            # epsilon guards the relative error against near-zero outputs
            mag1 = torch.abs(C1).float() + 1e-5
            mag2 = torch.abs(C3).float() + 1e-5
            mag3 = torch.abs(C3).float() + 1e-5

            relerr1 = err1 / mag1
            relerr2 = err2 / mag2
            relerr3 = err3 / mag3

            max_err1 = err1.max()
            max_err2 = err2.max()
            max_err3 = err3.max()

            errs1.append(err1.mean().item())
            errs2.append(err2.mean().item())
            errs3.append(err3.mean().item())

            relerrs1.append(relerr1.mean().item())
            relerrs2.append(relerr2.mean().item())
            relerrs3.append(relerr3.mean().item())

            max_errs1.append(max_err1.item())
            max_errs2.append(max_err2.item())
            max_errs3.append(max_err3.item())

        # allow a small, dim-proportional count of outlier elements
        # (tensor-core kernels are non-deterministic)
        c = int(C1.numel() * 0.0014 * (dim / 256)) + 1

        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        # normalize by sqrt(dim) so thresholds are comparable across sizes
        err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
        err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
        err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
        relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
        relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
        relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
        maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
        maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
        maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
        absratio = err2 / err3
        relratio = relerr2 / relerr3
        # BUGFIX: this previously recomputed relerr2/relerr3 (copy-paste
        # error), so the maxratio asserts below silently re-tested relratio
        # and the maxerr2/maxerr3 aggregates were never used.
        maxratio = maxerr2 / maxerr3

        # for debugging if the test fails
        #
        #print('='*80)
        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
        #print(C1.flatten()[-20:])
        #print(C2.flatten()[-20:])
        #print(f'inference vs training abs: {err1}')
        #print(f'inference vs training rel: {relerr1}')
        #print(f'inference vs training max: {maxerr1}')
        #print(f'inference vs training vs torch err ratio abs: {absratio}')
        #print(f'inference vs training vs torch err ratio rel: {relratio}')
        #print(f'inference vs training vs torch err ratio max: {maxratio}')
        if dtype == torch.float16:
            if dim <= 512:
                assert err1 < 7e-5
                assert relerr1 < 0.0008
            else:
                assert err1 < 6e-5
                assert relerr1 < 2e-4
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.float32:
            if dim <= 512:
                assert err1 < 5e-8
                assert relerr1 < 1e-6
                assert maxerr1 < 1e-7
            else:
                assert err1 < 5e-8
                assert relerr1 < 8e-6
                assert maxerr1 < 1e-7
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.005 and relratio > 0.995
            assert maxratio < 1.005 and maxratio > 0.995
        elif dtype == torch.bfloat16:
            if dim <= 512:
                assert err1 < 5e-4
                assert relerr1 < 0.007
                assert maxerr1 < 0.015
            else:
                assert err1 < 2e-4
                assert relerr1 < 0.002
                assert maxerr1 < 0.0012
            assert absratio < 1.005 and absratio > 0.995
            assert relratio < 1.04 and relratio > 0.96
            assert maxratio < 1.02 and maxratio > 0.98
24502492@pytest .mark .skip ("Row scale has some bugs for ampere" )
24512493def test_managed ():
0 commit comments