-
Notifications
You must be signed in to change notification settings - Fork 501
Expand file tree
/
Copy pathtest_swiglu.py
More file actions
492 lines (412 loc) · 16.2 KB
/
test_swiglu.py
File metadata and controls
492 lines (412 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
import tempfile
import pytest
import torch
import torch.multiprocessing as mp
import transformers
from packaging import version
from test.utils import supports_bfloat16
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3MLP
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
from liger_kernel.transformers.functional import liger_swiglu
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
from liger_kernel.transformers.swiglu import LigerExperts
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.utils import infer_comm_backend
from liger_kernel.utils import infer_device
# Compare only the release segment of the installed transformers version.
# Under PEP 440 a plain comparison against "5.0.0" ranks pre-releases such as
# "5.0.0rc0" or "5.0.0.dev0" *below* 5.0.0, which would send those installs
# down the v4 branch and attempt the v4-only import below (presumably absent
# from v5 pre-releases, given the v5 branch imports MixtralExperts instead).
IS_TRANSFORMERS_V5_OR_LATER = version.parse(transformers.__version__).release >= (5,)
if IS_TRANSFORMERS_V5_OR_LATER:
    from transformers.models.mixtral.modeling_mixtral import MixtralExperts
else:
    from transformers.models.mixtral.modeling_mixtral import MixtralBlockSparseTop2MLP
# Device every tensor and module in this file is placed on.
device = infer_device()

# Constructor arguments shared by the two module-level reference configs.
_REFERENCE_MLP_KWARGS = dict(
    hidden_size=4096,
    intermediate_size=11008,
    hidden_act="silu",
)

LLAMA_CONFIG = LlamaConfig(**_REFERENCE_MLP_KWARGS)
PHI3_CONFIG = Phi3Config(**_REFERENCE_MLP_KWARGS)

# Not referenced by any test in this file.
SLEEP_SECONDS = 0.1
@pytest.mark.parametrize(
    "bsz, seq_len, hidden_size, intermediate_size",
    [
        (2, 256, 256, 512),
        # weird shapes
        (6, 42, 123, 431),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # atol is for small values: they have more difference, so set atol higher
        # rtol is for larger values: they are very close, so set rtol lower
        (torch.float32, 1e-0, 1e-5),
        # TODO: we should find a better way to tune this. 1e4 is too large apparently
        pytest.param(
            torch.bfloat16,
            1e4,
            1e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_correctness_llamamlp(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
    """Check LigerSwiGLUMLP against HF LlamaMLP: forward output and all gradients.

    Both modules are built from the same config and then receive identical
    gate/up/down weights, so any mismatch comes from the SwiGLU implementation.
    """
    # Size the config to the parametrized shapes. Using the module-level
    # LLAMA_CONFIG (4096 x 11008) would allocate and randomly initialise large
    # weight matrices in each module only for them to be overwritten below.
    config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act="silu",
    )

    _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
    x1 = _input.clone().requires_grad_(True)
    x2 = _input.clone().requires_grad_(True)

    # initialize weights (stored transposed; nn.Linear weights are (out, in))
    G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)
    U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)
    D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)

    llama_mlp = LlamaMLP(config=config).to(device).to(dtype)
    llama_mlp.gate_proj.weight.data = G.T
    llama_mlp.up_proj.weight.data = U.T
    llama_mlp.down_proj.weight.data = D.T

    liger_mlp = LigerSwiGLUMLP(config=config).to(device).to(dtype)
    liger_mlp.gate_proj.weight.data = G.T
    liger_mlp.up_proj.weight.data = U.T
    liger_mlp.down_proj.weight.data = D.T

    y1 = llama_mlp(x1)
    y2 = liger_mlp(x2)
    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

    dy = torch.randn_like(y1)
    y1.backward(dy.clone(), retain_graph=True)
    y2.backward(dy.clone(), retain_graph=True)

    assert torch.allclose(
        llama_mlp.gate_proj.weight.grad,
        liger_mlp.gate_proj.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    assert torch.allclose(
        llama_mlp.up_proj.weight.grad,
        liger_mlp.up_proj.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    assert torch.allclose(
        llama_mlp.down_proj.weight.grad,
        liger_mlp.down_proj.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
@pytest.mark.skipif(IS_TRANSFORMERS_V5_OR_LATER, reason="Skip for transformers >= v5.0.0")
@pytest.mark.parametrize(
    "bsz, seq_len, hidden_size, intermediate_size",
    [
        (2, 256, 256, 512),
        # weird shapes
        (6, 42, 123, 431),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # atol is for small values: they have more difference, so set atol higher
        # rtol is for larger values: they are very close, so set rtol lower
        (torch.float32, 1e-0, 1e-5),
        # TODO: we should find a better way to tune this. 1e4 is too large apparently
        pytest.param(
            torch.bfloat16,
            1e4,
            1e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_correctness_mixtralblocksparsetop2mlp(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
    """Compare LigerBlockSparseTop2MLP with HF MixtralBlockSparseTop2MLP.

    Both experts get the same input and the same w1/w2/w3 weights; the forward
    output, every weight gradient and the input gradient must agree.
    """
    config = MixtralConfig(
        num_local_experts=8,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act="silu",
        num_experts_per_tok=2,
    )

    source = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
    x_ref = source.clone().requires_grad_(True)
    x_liger = source.clone().requires_grad_(True)

    # Transposed weights keyed by the projection they feed (nn.Linear stores
    # (out, in), so each is assigned as `.T`). Drawn in w1, w2, w3 order.
    transposed = {
        "w1": torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype),
        "w2": torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype),
        "w3": torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype),
    }

    reference = MixtralBlockSparseTop2MLP(config=config).to(device).to(dtype)
    for proj_name, weight_t in transposed.items():
        getattr(reference, proj_name).weight.data = weight_t.T

    liger = LigerBlockSparseTop2MLP(config=config).to(device).to(dtype)
    for proj_name, weight_t in transposed.items():
        getattr(liger, proj_name).weight.data = weight_t.T

    out_ref = reference(x_ref)
    out_liger = liger(x_liger)
    assert torch.allclose(out_ref, out_liger, atol=atol, rtol=rtol)

    upstream = torch.randn_like(out_ref)
    out_ref.backward(upstream.clone(), retain_graph=True)
    out_liger.backward(upstream.clone(), retain_graph=True)

    for proj_name in transposed:
        assert torch.allclose(
            getattr(reference, proj_name).weight.grad,
            getattr(liger, proj_name).weight.grad,
            atol=atol,
            rtol=rtol,
        )
    assert torch.allclose(x_ref.grad, x_liger.grad, atol=atol, rtol=rtol)
@pytest.mark.skipif(not IS_TRANSFORMERS_V5_OR_LATER, reason="Skip for transformers < v5.0.0")
@pytest.mark.parametrize(
    "bsz, seq_len, hidden_size, intermediate_size",
    [
        (2, 256, 256, 512),
        # weird shapes
        (6, 42, 123, 431),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # atol is for small values: they have more difference, so set atol higher
        # rtol is for larger values: they are very close, so set rtol lower
        (torch.float32, 1e-0, 1e-5),
        # TODO: we should find a better way to tune this. 1e4 is too large apparently
        pytest.param(
            torch.bfloat16,
            1e4,
            1e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_correctness_mixtralexperts(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
    """Compare LigerExperts with HF MixtralExperts under a shared random routing.

    Both expert banks get identical fused gate/up and down weights plus the
    same top-k indices/weights; outputs, parameter gradients and input
    gradients must agree.
    """
    config = MixtralConfig(
        num_local_experts=8,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        experts_implementation="eager",
        hidden_act="silu",
        num_experts_per_tok=2,
    )
    num_experts = config.num_local_experts
    num_tokens = bsz * seq_len

    tokens = torch.randn(num_tokens, hidden_size, device=device, dtype=dtype)
    x_ref = tokens.clone().requires_grad_(True)
    x_liger = tokens.clone().requires_grad_(True)

    # Fused gate/up weights: (num_experts, 2 * intermediate_dim, hidden_dim)
    gate_up = torch.randn(
        num_experts, 2 * intermediate_size, hidden_size, device=device, dtype=dtype, requires_grad=True
    )
    # Down weights: (num_experts, hidden_dim, intermediate_dim)
    down = torch.randn(
        num_experts, hidden_size, intermediate_size, device=device, dtype=dtype, requires_grad=True
    )

    # Random routing: softmax over experts, keep the top-k, renormalise them.
    probs = torch.randn(num_tokens, num_experts, device=device, dtype=dtype).softmax(dim=-1)
    top_k_weights, top_k_index = probs.topk(k=config.num_experts_per_tok, dim=-1)
    top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

    def _build(experts_cls):
        # Instantiate, load private copies of the shared weights, enable grads.
        experts = experts_cls(config=config).to(device).to(dtype)
        experts.gate_up_proj.data = gate_up.clone().detach()
        experts.down_proj.data = down.clone().detach()
        experts.gate_up_proj.requires_grad_()
        experts.down_proj.requires_grad_()
        return experts

    reference = _build(MixtralExperts)
    liger = _build(LigerExperts)

    out_ref = reference(x_ref, top_k_index, top_k_weights)
    out_liger = liger(x_liger, top_k_index, top_k_weights)
    assert torch.allclose(out_ref, out_liger, atol=atol, rtol=rtol)

    upstream = torch.randn_like(out_ref)
    out_ref.backward(upstream.clone(), retain_graph=True)
    out_liger.backward(upstream.clone(), retain_graph=True)

    for param_name in ("gate_up_proj", "down_proj"):
        assert torch.allclose(
            getattr(reference, param_name).grad,
            getattr(liger, param_name).grad,
            atol=atol,
            rtol=rtol,
        )
    assert torch.allclose(x_ref.grad, x_liger.grad, atol=atol, rtol=rtol)
@pytest.mark.parametrize(
    "bsz, seq_len, hidden_size, intermediate_size",
    [
        (2, 256, 256, 512),
        # weird shapes
        (6, 42, 123, 431),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # atol is for small values: they have more difference, so set atol higher
        # rtol is for larger values: they are very close, so set rtol lower
        (torch.float32, 1e-0, 1e-5),
        # TODO: we should find a better way to tune this. 1e4 is too large apparently
        pytest.param(
            torch.bfloat16,
            1e4,
            1e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_correctness_phi3mlp(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol):
    """Check LigerPhi3SwiGLUMLP against HF Phi3MLP: forward output and all gradients.

    Both modules use a single fused ``gate_up_proj`` of width
    ``2 * intermediate_size`` plus ``down_proj``; both get identical weights.
    """
    # Size the config to the parametrized shapes. Using the module-level
    # PHI3_CONFIG (4096 x 11008) would allocate and randomly initialise large
    # weight matrices in each module only for them to be overwritten below.
    config = Phi3Config(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act="silu",
    )

    _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
    x1 = _input.clone().requires_grad_(True)
    x2 = _input.clone().requires_grad_(True)

    # initialize weights (stored transposed; nn.Linear weights are (out, in))
    GU = torch.randn(hidden_size, intermediate_size * 2, device=device, dtype=dtype)
    D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)

    phi3_mlp = Phi3MLP(config=config).to(device).to(dtype)
    phi3_mlp.gate_up_proj.weight.data = GU.T
    phi3_mlp.down_proj.weight.data = D.T

    liger_mlp = LigerPhi3SwiGLUMLP(config=config).to(device).to(dtype)
    liger_mlp.gate_up_proj.weight.data = GU.T
    liger_mlp.down_proj.weight.data = D.T

    y1 = phi3_mlp(x1)
    y2 = liger_mlp(x2)
    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

    dy = torch.randn_like(y1)
    y1.backward(dy.clone(), retain_graph=True)
    y2.backward(dy.clone(), retain_graph=True)

    assert torch.allclose(
        phi3_mlp.gate_up_proj.weight.grad,
        liger_mlp.gate_up_proj.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    assert torch.allclose(
        phi3_mlp.down_proj.weight.grad,
        liger_mlp.down_proj.weight.grad,
        atol=atol,
        rtol=rtol,
    )
    assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
@pytest.mark.parametrize(
    "bsz, seq_len, size",
    [
        (2, 8, 8),
        (9, 7, 41),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # atol is for small values: they have more difference, so set atol higher
        # rtol is for larger values: they are very close, so set rtol lower
        (torch.float32, 1e-0, 1e-5),
        # TODO: we should find a better way to tune this. 1e4 is too large apparently
        # Guarded like every other bfloat16 case in this file so the test is
        # skipped, not failed, on hardware without bfloat16 support.
        pytest.param(
            torch.bfloat16,
            1e4,
            1e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
    """Check that the functional ``liger_swiglu`` matches ``LigerSiLUMulFunction.apply``.

    Compares the forward output and the gradients w.r.t. both inputs ``a`` and ``b``.
    """
    _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype)
    _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype)

    x1 = _input.clone().requires_grad_(True)
    x2 = _input.clone().requires_grad_(True)
    b1 = _b.clone().requires_grad_(True)
    b2 = _b.clone().requires_grad_(True)

    y1 = liger_swiglu(a=x1, b=b1)
    y2 = LigerSiLUMulFunction.apply(x2, b2)
    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

    # Test backward pass
    grad_output = torch.randn_like(y1)
    y1.backward(grad_output)
    y2.backward(grad_output)

    # Check that gradients are close for both inputs, x and b
    assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
    assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol)
def _test_dtensor_liger_silumul(rank, world_size, bsz, seq_len, hidden_size, dtype, atol, rtol, file_name):
    """Per-rank worker: compare LigerSiLUMulFunction on sharded DTensors vs plain tensors.

    Inputs and the upstream gradient are broadcast from rank 0 so every rank
    holds identical data. The DTensor path shards them along dim 2 over a 1-D
    "tp" device mesh; its gathered output and input gradients must match the
    regular-tensor path.

    Args:
        rank: this process's rank (supplied first by ``mp.spawn``).
        world_size: number of ranks; must divide ``hidden_size``.
        bsz, seq_len, hidden_size: input tensor shape.
        dtype, atol, rtol: tensor dtype and ``assert_close`` tolerances.
        file_name: path used as the ``file://`` rendezvous for the process group.

    The process group is destroyed on exit even if an assertion fails.
    """
    # Explicit imports so the DTensor / device-mesh APIs are guaranteed loaded
    # rather than relying on attribute access to already-imported submodules.
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard
    from torch.distributed.tensor import distribute_tensor

    # Check before joining the process group so a bad parametrization fails
    # fast on every rank instead of after collectives have started.
    assert hidden_size % world_size == 0, f"hidden_size ({hidden_size}) must be divisible by world_size ({world_size})"

    torch.distributed.init_process_group(
        backend=infer_comm_backend(),
        init_method=f"file://{file_name}",
        rank=rank,
        world_size=world_size,
    )
    try:
        device_type = infer_device()
        device = f"{device_type}:{rank}" if device_type != "cpu" else "cpu"
        device_mesh = init_device_mesh(device_type, mesh_shape=(world_size,), mesh_dim_names=("tp",))
        # Shard along the last (hidden) dimension of the (bsz, seq_len, hidden) tensors.
        shard_hidden = (Shard(2),)

        _a = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
        _b = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype)
        # Broadcast from rank 0 so all ranks operate on identical tensors
        torch.distributed.broadcast(_a, src=0)
        torch.distributed.broadcast(_b, src=0)

        # DTensor path: shard inputs along the hidden dim
        a1 = _a.clone().detach().requires_grad_(True)
        b1 = _b.clone().detach().requires_grad_(True)
        da = distribute_tensor(a1, device_mesh=device_mesh, placements=shard_hidden)
        db = distribute_tensor(b1, device_mesh=device_mesh, placements=shard_hidden)

        # Regular tensor path
        a2 = _a.clone().detach().requires_grad_(True)
        b2 = _b.clone().detach().requires_grad_(True)

        c1 = LigerSiLUMulFunction.apply(da, db)
        c2 = LigerSiLUMulFunction.apply(a2, b2)
        torch.testing.assert_close(c1.full_tensor(), c2, atol=atol, rtol=rtol)

        grad = torch.randn_like(c2)
        torch.distributed.broadcast(grad, src=0)
        dgrad = distribute_tensor(grad, device_mesh=device_mesh, placements=shard_hidden)

        c1.backward(dgrad)
        c2.backward(grad)
        torch.testing.assert_close(da.grad.full_tensor(), a2.grad, atol=atol, rtol=rtol)
        torch.testing.assert_close(db.grad.full_tensor(), b2.grad, atol=atol, rtol=rtol)
    finally:
        # Release communicator resources; without this PyTorch warns at exit
        # that destroy_process_group() was not called.
        torch.distributed.destroy_process_group()
@pytest.mark.xfail(
    torch.cuda.device_count() < 8,
    reason="Pending multi-GPU host support. This test is expected to pass when run with multi-GPU host.",
)
@pytest.mark.parametrize(
    "world_size, bsz, seq_len, hidden_size",
    [
        (4, 2, 2, 8),
        (8, 9, 7, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        (torch.bfloat16, 2e-1, 2e-2),
    ],
)
def test_dtensor_liger_silumul(world_size, bsz, seq_len, hidden_size, dtype, atol, rtol):
    """Run ``_test_dtensor_liger_silumul`` in ``world_size`` spawned processes.

    A fresh temporary file is the ``file://`` rendezvous for the process group;
    it stays open until ``mp.spawn`` has joined every rank.
    """
    worker_args = (world_size, bsz, seq_len, hidden_size, dtype, atol, rtol)
    with tempfile.NamedTemporaryFile() as rendezvous:
        mp.spawn(
            _test_dtensor_liger_silumul,
            args=(*worker_args, rendezvous.name),
            nprocs=world_size,
            join=True,
        )