 from typing import Tuple

 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
-    TestCase,
     run_tests,
 )

-from torchao.prototype.moe_quant.utils import MoEQuantConfig
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
+from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     _is_fbgemm_genai_gpu_available,

 torch._dynamo.config.cache_size_limit = 128


-class Experts(nn.Module):
-    def __init__(
-        self,
-        num_local_experts: int,
-        dim: int,
-        hidden_dim: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> None:
-        super().__init__()
-
-        self.num_local_experts = num_local_experts
-        self.dim = dim
-
-        self.w1: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w2: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                hidden_dim,
-                dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w3: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-    def forward(
-        self,
-        routed_in_egD: torch.Tensor,  # noqa: N803
-    ) -> torch.Tensor:
-        e = self.num_local_experts
-        D = self.dim
-
-        x_egD = routed_in_egD.view(e, -1, D)
-
-        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
-        out_egD = torch.bmm(middle_out_egF, self.w2)
-        out_egD = out_egD.view(-1, D)
-
-        return out_egD
-
-
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features):
         super().__init__()
@@ -115,7 +52,7 @@ def forward(self, x):
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
-class TestFloat8Tensor(TestCase):
+class TestFloat8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
@@ -340,45 +277,8 @@ def test_slice_preserves_aliasing(self, granularity):

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_and_copy_similar_to_vllm(self, granularity):
-        # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
-        # the test is similar to the linked code, but with some hardcoded arguments
-        # and does not use tensor parallelism
-
-        dtype = torch.bfloat16
-        device = "cuda"
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
-        quantize_(l, config)
-
-        # high level, we do a narrow for both param.data and the loaded_weights
-        # and do inplace copy_ to copy from the loaded_weights into param.data
-
-        # simulate loaded_weight
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        # making the weight different
-        dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
-            requires_grad=False,
-        )
-        quantize_(dummy_l, config)
-
-        output_dim = 0
-        shard_size = 512
-        for tp_rank in [0, 1]:
-            start_idx = tp_rank * shard_size
-            param = l.weight
-            param_data = param.data
-            param_data = param_data.narrow(output_dim, start_idx, shard_size)
-            orig_value = param_data.qdata[0][0].item()
-            loaded_weight = dummy_l.weight
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-
-            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-            assert orig_value != loaded_weight.qdata[0][0]
-            param_data.copy_(loaded_weight)
-            # making sure param.data is updated to loaded_weight
-            assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
-            assert param_data.scale[0] == loaded_weight.scale[0]
+        self._test_slice_and_copy_similar_to_vllm(config)

     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
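
For context, the deleted `test_slice_and_copy_similar_to_vllm` body boiled down to the narrow-and-`copy_` sequence below; the new `self._test_slice_and_copy_similar_to_vllm(config)` call presumably exercises the equivalent checks from the shared `TorchAOIntegrationTestCase` helper. A minimal sketch reconstructed from the removed lines (requires CUDA sm89+ and torchao; the `qdata`/`scale` attributes are the ones the original test inspected):

```python
import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# target param and a simulated "loaded_weight", made intentionally different
l = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
dummy_l = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
dummy_l.weight = torch.nn.Parameter(
    dummy_l.weight + 2 * torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16),
    requires_grad=False,
)
quantize_(l, config)
quantize_(dummy_l, config)

# vLLM-style shard load: narrow both sides, then copy_ in place
param_data = l.weight.data.narrow(0, 0, 512)
loaded_weight = dummy_l.weight.narrow(0, 0, 512)
assert param_data.qdata[0][0].item() != loaded_weight.qdata[0][0]
param_data.copy_(loaded_weight)
# the quantized payload and scale of the shard now match the loaded weight
assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
assert param_data.scale[0] == loaded_weight.scale[0]
```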
@@ -494,122 +394,9 @@ def test_cat(self, granularity, sizes):

     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_moe_weight_reshape_ops(self):
-        """This is testing the op call sequence in saving and loading quantization
-        checkpoints in llama-models for llama4
-        (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
-        """
-        # only per row quantization is supported for bmm
         granularity = PerRow()
-        dtype = torch.bfloat16
-        device = "cuda"
-
-        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        moe_config = MoEQuantConfig(bmm_config)
-
-        batch_size = 4
-        num_experts = 2
-        input_dim = 64
-        dim = 128
-        hidden_dim = 256
-
-        moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
-        input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
-
-        moes = [moe1, moe2]
-
-        for moe in moes:
-            moe(input)
-
-            def filter_fn(module, fqn):
-                return isinstance(module, Experts)
-
-            # need to transpose before quantizing
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).contiguous(), requires_grad=False
-            )
-
-            quantize_(moe, moe_config, filter_fn=filter_fn)
-
-            # make sure it runs
-            before = moe(input)
-
-            # transposing for resharding support since only 2D resharding is supported
-            new_last_dim = moe.w1.shape[-2]
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w2.shape[-2]
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w3.shape[-2]
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-
-            # transpose again to recover the original weights
-            moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
-            moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
-            moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
-
-            # make sure it runs
-            after = moe(input)
-
-            self.assertEqual(before, after)
-
-        state_dicts = [moe1.state_dict(), moe2.state_dict()]
-        # align the scale parameter so they can be concatenated
-        for key in ["w1", "w2", "w3"]:
-            weights = [st[key] for st in state_dicts]
-            for i in range(1, len(weights)):
-                weights[i].scale = weights[0].scale
-
-        def process_key(key: str) -> torch.Tensor:
-            tensors = [s[key] for s in state_dicts]
-            # Note: we have a hacky implementation for cat in user codebase
-            # since it is not implemented correctly before
-            if key == "w2":
-                return torch.cat(tensors, dim=-1)
-            else:
-                return torch.cat(tensors, dim=-2)
-
-        new_state_dict = {}
-        for key in ["w1", "w2", "w3"]:
-            new_state_dict[key] = process_key(key)
-
-        moe_combined.w1 = torch.nn.Parameter(
-            moe_combined.w1.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w2 = torch.nn.Parameter(
-            moe_combined.w2.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w3 = torch.nn.Parameter(
-            moe_combined.w3.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.load_state_dict(new_state_dict, assign=True)
-        # make sure it runs
-        moe_combined(input)
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        self._test_moe_weight_reshape_ops(config)


 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
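
Similarly, the deleted `test_moe_weight_reshape_ops` body covered the op call sequence used when saving and loading llama4 MoE checkpoints; the new `self._test_moe_weight_reshape_ops(config)` call presumably drives the same sequence from the shared helper. A condensed sketch of that sequence, reconstructed from the removed lines and shown on a plain bfloat16 stand-in for one expert weight (in the real test the same ops run on the Float8 quantized tensors):

```python
import torch

num_experts, dim, hidden_dim = 2, 128, 256

# expert weight as held by the removed Experts module: (experts, dim, hidden_dim)
w1 = torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16)

# 1. transpose (and make contiguous) before quantizing, as in the removed test
w1 = w1.transpose(1, 2).contiguous()                    # (experts, hidden_dim, dim)

# 2. flatten to 2D, since only 2D resharding is supported
new_last_dim = w1.shape[-2]
w1 = w1.transpose(1, 2).reshape(-1, new_last_dim)       # (experts * dim, hidden_dim)

# 3. unflatten back to 3D and transpose again to recover the original layout
w1 = w1.unflatten(0, (num_experts, -1)).squeeze(dim=0)  # (experts, dim, hidden_dim)
w1 = w1.transpose(1, 2)                                  # (experts, hidden_dim, dim)

# 4. expert weights from two checkpoints are concatenated (dim=-1 for w2, dim=-2 otherwise)
w1_combined = torch.cat([w1, w1.clone()], dim=-2)
assert w1_combined.shape == (num_experts, 2 * hidden_dim, dim)
```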