Skip to content

Commit 974b754

Browse files
authored
Fix in-place modification when autotuning triton Lion update
1 parent 6629519 commit 974b754

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

lion_pytorch/triton.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,12 @@
77
print('triton is not installed, please install by running `pip install triton -U --pre`')
88
exit()
99

10-
# clone param and exp_avg before autotuning takes place
11-
# as those are updated in-place
12-
13-
def clone_inplace_updated_params(nargs):
14-
nargs['p_ptr'] = nargs['p_ptr'].clone()
15-
nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()
16-
1710
# triton cuda kernel
1811

1912
@triton.autotune(configs = [
20-
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
21-
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
22-
], key = ['n_elements'])
13+
triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
14+
triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
15+
], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr'])
2316
@triton.jit
2417
def update_fn_kernel(
2518
p_ptr,

0 commit comments

Comments
 (0)