Commit 6629519

Merge pull request #26 from lucidrains/pre-hook-autotuner
use pre hook to fix in-place / autotuner issue
2 parents 2226ec8 + 3d1e555

3 files changed: +19, -34 lines changed

lion_pytorch/lion_pytorch.py

Lines changed: 2 additions & 8 deletions
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import Tuple, Optional, Callable
 
 import torch
@@ -34,8 +33,7 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False,
-        triton_block_size: int = 1024
+        use_triton: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
@@ -49,12 +47,10 @@ def __init__(
         super().__init__(params, defaults)
 
         self.update_fn = update_fn
-        self.use_triton = use_triton
-        self.took_first_step = False
 
         if use_triton:
             from lion_pytorch.triton import update_fn as triton_update_fn
-            self.update_fn = partial(triton_update_fn, BLOCK_SIZE = triton_block_size)
+            self.update_fn = triton_update_fn
 
     @torch.no_grad()
     def step(
@@ -67,8 +63,6 @@ def step(
             with torch.enable_grad():
                 loss = closure()
 
-        # update all parameters
-
         for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):
 

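For reference, a minimal usage sketch of the optimizer after this change. The toy module and hyperparameter values below are illustrative and not from the repo; the point is that the `triton_block_size` argument is gone, since the block size is now chosen by the autotuner in lion_pytorch/triton.py:

import torch
from lion_pytorch import Lion

# toy module, purely illustrative
model = torch.nn.Linear(512, 10).cuda()

# before this commit: Lion(model.parameters(), use_triton = True, triton_block_size = 1024)
# after this commit:  no block size argument; the autotuner picks BLOCK_SIZE
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2, use_triton = True)

loss = model(torch.randn(8, 512, device = 'cuda')).sum()
loss.backward()
opt.step()
opt.zero_grad()

Dropping `triton_block_size` is also why `functools.partial` is no longer needed and its import is removed in the diff above.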
lion_pytorch/triton.py

Lines changed: 16 additions & 25 deletions
@@ -1,5 +1,4 @@
 import torch
-from torch import Tensor
 
 try:
     import triton
@@ -8,18 +7,19 @@
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# helper functions
+# clone param and exp_avg before autotuning takes place
+# as those are updated in-place
 
-def calc_num_warps(block_size):
-    num_warps = 4
-    if block_size >= 2048:
-        num_warps = 8
-    if block_size >= 4096:
-        num_warps = 16
-    return num_warps
+def clone_inplace_updated_params(nargs):
+    nargs['p_ptr'] = nargs['p_ptr'].clone()
+    nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()
 
 # triton cuda kernel
 
+@triton.autotune(configs = [
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
+], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -80,35 +80,26 @@ def update_fn_kernel(
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
 
 def update_fn(
-    p: Tensor,
-    grad: Tensor,
-    exp_avg: Tensor,
+    p: torch.Tensor,
+    grad: torch.Tensor,
+    exp_avg: torch.Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float,
-    inplace: bool = True,
-    BLOCK_SIZE: int = 1024
+    beta2: float
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
-
     n_elements = p.numel()
 
-    block_size = triton.next_power_of_2(BLOCK_SIZE)
-    num_warps = calc_num_warps(block_size)
-    n_rows = triton.cdiv(n_elements, block_size)
-
-    # call triton cuda kernel
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 
-    update_fn_kernel[(n_rows,)](
+    update_fn_kernel[grid](
        p,
        grad,
        exp_avg,
        lr,
        wd,
        beta1,
        beta2,
-        n_elements,
-        num_warps = num_warps,
-        BLOCK_SIZE = BLOCK_SIZE
+        n_elements
    )

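Why a pre hook: Triton's autotuner benchmarks each candidate `triton.Config` by launching the kernel several times before the real call, and this kernel updates `p_ptr` and `exp_avg_ptr` in place, so those extra launches would apply the Lion update more than once. The new comment in the diff states the intent: clone the in-place-updated arguments in a config `pre_hook` so tuning works on copies. Below is a standalone sketch of the same pattern on a toy in-place scaling kernel; the kernel and helper names are illustrative, not from this repo, and a Triton version whose `triton.Config` accepts `pre_hook` is assumed:

import torch
import triton
import triton.language as tl

# hypothetical helper mirroring clone_inplace_updated_params above:
# the kernel writes x_ptr in place, so hand the autotuner a copy
def clone_inplace_args(nargs):
    nargs['x_ptr'] = nargs['x_ptr'].clone()

@triton.autotune(configs = [
    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_args),
    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_args),
], key = ['n_elements'])
@triton.jit
def scale_inplace_kernel(x_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis = 0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask = mask)
    tl.store(x_ptr + offsets, x * scale, mask = mask)  # in-place write

def scale_(x: torch.Tensor, scale: float):
    assert x.is_cuda
    n_elements = x.numel()
    # BLOCK_SIZE comes from whichever config the autotuner selects,
    # exactly as in update_fn above
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    scale_inplace_kernel[grid](x, scale, n_elements)

The two block sizes (128 and 1024) mirror the configs added in this commit.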
setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.2',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
