File tree Expand file tree Collapse file tree 1 file changed +3
-10
lines changed Expand file tree Collapse file tree 1 file changed +3
-10
lines changed Original file line number Diff line number Diff line change 77 print ('triton is not installed, please install by running `pip install triton -U --pre`' )
88 exit ()
99
10- # clone param and exp_avg before autotuning takes place
11- # as those are updated in-place
12-
13- def clone_inplace_updated_params (nargs ):
14- nargs ['p_ptr' ] = nargs ['p_ptr' ].clone ()
15- nargs ['exp_avg_ptr' ] = nargs ['exp_avg_ptr' ].clone ()
16-
1710# triton cuda kernel
1811
1912@triton .autotune (configs = [
20- triton .Config ({'BLOCK_SIZE' : 128 }, num_warps = 4 , pre_hook = clone_inplace_updated_params ),
21- triton .Config ({'BLOCK_SIZE' : 1024 }, num_warps = 8 , pre_hook = clone_inplace_updated_params ),
22- ], key = ['n_elements' ])
13+ triton .Config ({'BLOCK_SIZE' : 128 }, num_warps = 4 ),
14+ triton .Config ({'BLOCK_SIZE' : 1024 }, num_warps = 8 ),
15+ ], key = ['n_elements' ], restore_value = [ 'p_ptr' , 'exp_avg_ptr' ] )
2316@triton .jit
2417def update_fn_kernel (
2518 p_ptr ,
You can’t perform that action at this time.
0 commit comments