-
Notifications
You must be signed in to change notification settings - Fork 508
Expand file tree
/
Copy pathfused_linear_jsd.py
More file actions
228 lines (197 loc) · 9.46 KB
/
fused_linear_jsd.py
File metadata and controls
228 lines (197 loc) · 9.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from typing import Optional
import torch
import triton
from liger_kernel.ops.jsd import _jsd_kernel
from liger_kernel.ops.utils import amp_custom_bwd
from liger_kernel.ops.utils import amp_custom_fwd
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import is_hip
from liger_kernel.utils import infer_device
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 4096 if infer_device() == "npu" else 65536 // 2
def fused_linear_jsd_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
):
device = student_input.device
dtype = student_input.dtype
# inputs have shape: BT x H
# materialized activations will have shape: BT x V
# the increase in memory = BT x V
# reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
# for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
# inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
# for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
BT, H = student_input.shape
V = student_weight.shape[0]
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
grad_input = torch.zeros_like(student_input)
# we use fp32 for loss accumulator
loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
if has_label:
n_non_ignore = (shift_labels != ignore_index).sum().item()
else:
n_non_ignore = BT
for chunk_id in range(num_chunks):
start_idx = chunk_id * chunk_size
end_idx = min((chunk_id + 1) * chunk_size, BT)
# chunk both inputs, shape: chunk_size x H
student_input_chunk = student_input[start_idx:end_idx]
teacher_input_chunk = teacher_input[start_idx:end_idx]
# shape: chunk_size x V
# For anything starting from logits to the final JSD loss, we do computation
# in FP32 to avoid losing numerical stability.
student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
chunk_n_rows = student_logits_chunk.shape[0]
# unreduced loss
loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size
# log-softmax with temperature
student_logits_chunk = student_logits_chunk / temperature
teacher_logits_chunk = teacher_logits_chunk / temperature
student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1)
teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1)
# ensure _input and target are contiguous
student_prob_chunk = student_prob_chunk.contiguous()
teacher_prob_chunk = teacher_prob_chunk.contiguous()
# Here we calculate the gradient of prob_chunk in place so we can save memory.
_jsd_kernel[(chunk_n_rows,)](
X_ptr=student_prob_chunk,
X_stride=student_prob_chunk.stride(-2),
Y_ptr=teacher_prob_chunk,
Y_stride=teacher_prob_chunk.stride(-2),
loss_ptr=loss_1d_slice,
loss_stride=loss_1d_slice.stride(-2),
dX_ptr=student_prob_chunk,
dX_stride=student_prob_chunk.stride(-2),
label_ptr=(
shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
), # dummy ptr if no label
beta=jsd_beta,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
)
loss_1d[start_idx:end_idx] = loss_1d_slice
# gradients of prob_chunk in place, shape: chunk_size x V
# gradients of logits_chunk in place, shape: chunk_size x V
student_logits_chunk = (
student_prob_chunk
- torch.softmax(student_logits_chunk, dim=-1)
* student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
) / temperature
# now we traverse back to grad w.r.t. input to `lm_head` and grad
# w.r.t. `lm_head` which should be computed in original dtype
student_logits_chunk = student_logits_chunk.to(dtype)
grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight
if grad_weight is not None:
grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)
loss = torch.sum(loss_1d)
return loss, grad_input, grad_weight
def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
# If JSD is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
BT, H = grad_input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
element_mul_kernel[(n_rows,)](
grad_input,
grad_input.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
# handle grad_weight
if grad_weight is not None:
V, H = grad_weight.shape
n_rows = V
element_mul_kernel[(n_rows,)](
grad_weight,
grad_weight.stride(-2),
grad_output,
H,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
)
return grad_input, grad_weight
class LigerFusedLinearJSDFunction(torch.autograd.Function):
"""
Fusing the last linear layer with generalized JSD
Handle the forward and backward pass of the final linear layer via JSD by avoiding
the materialization of the large logits tensor. Since JSD is the last layer, we can
compute the gradient at the forward pass.
"""
@staticmethod
@amp_custom_fwd
def forward(
ctx,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
teacher_weight: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
jsd_beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
"""
Args:
student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
Returns:
loss (torch.Tensor): generalized JSD
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (teacher_input.shape[0],), (
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
)
shift_labels = shift_labels.contiguous()
has_label = True
loss, grad_input, grad_weight = fused_linear_jsd_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
)
# downcast to dtype and store for backward
ctx.save_for_backward(
grad_input.detach(),
grad_weight.detach() if grad_weight is not None else None,
)
return loss
@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output):
(grad_input, grad_weight) = ctx.saved_tensors
grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
return (grad_input, grad_weight, None, None, None, None, None, None)