@@ -2355,47 +2355,62 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for i in range(100):
-        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-        A = torch.randn(1, 128+32, dtype=dtype, device='cuda')
-        B = torch.randn(4096, 128+32, dtype=dtype, device='cuda')/math.sqrt(128)
-
-        #print('')
-        #print(A)
-        #print(B.t())
-        #A[:, :-3] = 0
-        #B[:, :-3] = 0
-
-
-        C1 = torch.matmul(A, B.t())
-        C2 = F.cutlass3_gemm(A, B.t())
-        err = C1-C2
-
-        # tensor cores are non-deterministic
-        # so we need to analyze errors around the mean
-        # to test our implementation
-        err = torch.abs(err.mean()).item()
-        mag = torch.abs(C1).mean()
-        relerr = err/mag
-
-        if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            print('')
-            print(i, err, mag.item(), relerr.item())
-            print(A.flatten()[-6:])
-            print(B.flatten()[-6:])
-            out = A.flatten()[-6:]*B.flatten()[-6:]
-            print(out)
-            print(out[:-1].sum())
-            print('='*80)
-            print(C1.flatten()[-6:])
-            print(C2.flatten()[-6:])
-            #assert False, 'ERROR'
-
-        c = int(C1.numel()*0.001)
-        assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(100):
+            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*496, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-3] = 0
+            #B[:, :-3] = 0
+
+
+            C1 = torch.matmul(A, B.t())
+            C2 = F.cutlass3_gemm(A, B.t())
+
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
+
+            errs.append(err)
+            relerrs.append(relerr)
+
+            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #    print('')
+            #    print(i, err, mag.item(), relerr.item())
+            #    print(A.flatten()[-6:])
+            #    print(B.flatten()[-6:])
+            #    out = A.flatten()[-6:]*B.flatten()[-6:]
+            #    print(out)
+            #    print(out[:-1].sum())
+            #    print('='*80)
+            #    print(C1.flatten()[-6:])
+            #    print(C2.flatten()[-6:])
+            #    #assert False, 'ERROR'
+
+            c = int(C1.numel()*0.00125*(dim/256))+1
+            assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c)
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
 
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
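Note on the assertion: rather than requiring every element of `C1` and `C2` to match, the test budgets a bounded number of outliers, since the tensor-core accumulation order can perturb individual outputs. The real `assert_all_approx_close` is defined elsewhere in the test suite; the following is only a minimal sketch of what such a count-tolerant checker could look like, and its exact behavior is an assumption:

```python
import torch

def assert_all_approx_close(a, b, rtol, atol, count):
    # Elements outside the rtol/atol envelope count as outliers
    # (sketch; the real helper in the test suite may differ).
    num_not_close = (~torch.isclose(a, b, rtol=rtol, atol=atol)).sum().item()
    # Only fail when more than `count` elements diverge; a handful of
    # outliers is expected from non-deterministic tensor-core accumulation.
    if num_not_close > count:
        print(f"{num_not_close} values not close, tolerated {count}")
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
```

With `c = int(C1.numel()*0.00125*(dim/256))+1`, the outlier budget scales with both the output size and `dim`: since `B` has `4*496 = 1984` rows, `C1` is `1 x 1984`, so at `dim=256` the test tolerates `int(1984*0.00125)+1 = 3` mismatched elements.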