|
@@ -1,25 +1,25 @@
 import torch
-from torch import Tensor
 
 try:
     import triton
     import triton.language as tl
 except ImportError as e:
     print('triton is not installed, please install by running `pip install triton -U --pre`')
     exit()
 
-# helper functions
+# clone param and exp_avg before autotuning takes place
+# as those are updated in-place
 
-def calc_num_warps(block_size):
-    num_warps = 4
-    if block_size >= 2048:
-        num_warps = 8
-    if block_size >= 4096:
-        num_warps = 16
-    return num_warps
+def clone_inplace_updated_params(nargs):
+    nargs['p_ptr'] = nargs['p_ptr'].clone()
+    nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone()
 
 # triton cuda kernel
 
+@triton.autotune(configs = [
+    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
+], key = ['n_elements'])
 @triton.jit
 def update_fn_kernel(
     p_ptr,
@@ -80,35 +80,26 @@ def update_fn_kernel( |
     tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
 
 def update_fn(
-    p: Tensor,
-    grad: Tensor,
-    exp_avg: Tensor,
+    p: torch.Tensor,
+    grad: torch.Tensor,
+    exp_avg: torch.Tensor,
     lr: float,
     wd: float,
     beta1: float,
-    beta2: float,
-    inplace: bool = True,
-    BLOCK_SIZE: int = 1024
+    beta2: float
 ):
     assert all([t.is_cuda for t in (p, grad, exp_avg)])
-
     n_elements = p.numel()
 
-    block_size = triton.next_power_of_2(BLOCK_SIZE)
-    num_warps = calc_num_warps(block_size)
-    n_rows = triton.cdiv(n_elements, block_size)
-
-    # call triton cuda kernel
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 
-    update_fn_kernel[(n_rows,)](
+    update_fn_kernel[grid](
         p,
         grad,
         exp_avg,
         lr,
         wd,
         beta1,
         beta2,
-        n_elements,
-        num_warps = num_warps,
-        BLOCK_SIZE = BLOCK_SIZE
+        n_elements
     )
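
Why the new pre_hook matters: the autotuner times each candidate config by launching the kernel repeatedly, and update_fn_kernel writes p and exp_avg in place, so without the clone_inplace_updated_params hook those benchmark launches would compound the optimizer update before the real run (this is what the new comment in the diff refers to). Below is a minimal usage sketch; the tensor size and hyperparameter values are illustrative assumptions, not part of this commit:

    import torch

    # illustrative values only -- any CUDA tensors of matching shape work
    p = torch.randn(65536, device = 'cuda')
    grad = torch.randn_like(p)
    exp_avg = torch.zeros_like(p)

    # the first call for a given n_elements triggers autotuning over the two
    # BLOCK_SIZE configs (key = ['n_elements']); later calls reuse the winner
    update_fn(p, grad, exp_avg, lr = 1e-4, wd = 1e-2, beta1 = 0.9, beta2 = 0.99)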
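
The launch grid changes shape too: instead of precomputing n_rows from a hard-coded BLOCK_SIZE, the grid is now a callable that Triton evaluates with the winning config's meta-parameters, so the number of programs launched tracks whichever BLOCK_SIZE the autotuner picks. A small sketch of how it resolves, with hypothetical values:

    import triton

    n_elements = 65536
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    print(grid({'BLOCK_SIZE': 1024}))  # (64,)  -- as if the 1024 config won
    print(grid({'BLOCK_SIZE': 128}))   # (512,) -- as if the 128 config won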