8
8
9
9
import torch
10
10
from torch .testing ._internal .common_utils import (
11
- TestCase ,
11
+ instantiate_parametrized_tests ,
12
+ parametrize ,
12
13
run_tests ,
13
14
)
14
15
15
- from torchao .quantization import (
16
- Int4WeightOnlyConfig ,
17
- quantize_ ,
18
- )
16
+ from torchao .quantization import Int4WeightOnlyConfig , quantize_
19
17
from torchao .quantization .utils import compute_error
20
- from torchao .utils import (
21
- TORCH_VERSION_AT_LEAST_2_8 ,
22
- is_sm_at_least_90 ,
23
- )
18
+ from torchao .testing .utils import TorchAOIntegrationTestCase
19
+ from torchao .utils import TORCH_VERSION_AT_LEAST_2_8 , is_sm_at_least_90
24
20
25
21
26
22
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_8 , "Need pytorch 2.8+" )
27
23
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
28
24
@unittest .skipIf (not is_sm_at_least_90 (), "Nedd sm90+" )
29
- class TestInt4Tensor (TestCase ):
25
+ class TestInt4Tensor (TorchAOIntegrationTestCase ):
30
26
def setUp (self ):
31
27
self .config = Int4WeightOnlyConfig (
32
28
group_size = 128 ,
@@ -61,50 +57,46 @@ def test_slice(self):
61
57
quantize_ (dummy , self .config )
62
58
weight1 = dummy .weight .narrow (0 , 0 , 64 )
63
59
weight2 = dummy .weight .narrow (1 , 0 , 128 )
64
- self .assertEqual (weight1 ._data , dummy .weight ._data .narrow (0 , 0 , 64 ))
60
+ self .assertEqual (weight1 .qdata , dummy .weight .qdata .narrow (0 , 0 , 64 ))
65
61
self .assertEqual (weight1 .scale , dummy .weight .scale .narrow (1 , 0 , 64 ))
66
- self .assertEqual (weight2 ._data , dummy .weight ._data .narrow (1 , 0 , 64 ))
62
+ self .assertEqual (weight1 .zero_point , dummy .weight .zero_point .narrow (1 , 0 , 64 ))
63
+ self .assertEqual (weight2 .qdata , dummy .weight .qdata .narrow (1 , 0 , 64 ))
67
64
self .assertEqual (weight2 .scale , dummy .weight .scale .narrow (0 , 0 , 1 ))
65
+ self .assertEqual (weight2 .zero_point , dummy .weight .zero_point .narrow (0 , 0 , 1 ))
68
66
69
67
# check for sliced weight, before and after float8 quantization
70
68
# does not differ too much
71
69
input = torch .randn (2 , 256 , dtype = dtype , device = device )
72
70
res_ref = dummy1 (input )
73
- dummy .weight = torch .nn .Parameter (weight1 , requires_grad = False )
71
+ dummy .weight = torch .nn .Parameter (weight1 . contiguous () , requires_grad = False )
74
72
res = dummy (input )
75
73
assert compute_error (res , res_ref ) > 20
76
74
77
75
input = torch .randn (2 , 128 , dtype = dtype , device = device )
78
76
res_ref = dummy2 (input )
79
- dummy .weight = torch .nn .Parameter (weight2 , requires_grad = False )
77
+ dummy .weight = torch .nn .Parameter (weight2 . contiguous () , requires_grad = False )
80
78
res = dummy (input )
81
79
assert compute_error (res , res_ref ) > 15
82
80
83
- def test_slice_and_copy_ (self ):
81
+ def test_slice_preserves_aliasing (self ):
82
+ config = self .config
84
83
l = torch .nn .Linear (1024 , 1024 ).to ("cuda" ).to (torch .bfloat16 )
85
84
l .weight = torch .nn .Parameter (
86
85
torch .zeros (1024 , 1024 , dtype = torch .bfloat16 , device = "cuda" )
87
86
)
88
- quantize_ (l , self . config )
87
+ quantize_ (l , config )
89
88
param = l .weight
90
89
param_data = param .data
91
90
param_data = param_data .narrow (0 , 0 , 512 )
92
- assert param .data ._data .data_ptr () == param_data ._data .data_ptr ()
91
+ # Making sure the aliasing is preserved in sliced quantized Tensor
92
+ assert param .data .qdata .data_ptr () == param_data .qdata .data_ptr ()
93
93
assert param .data .scale .data_ptr () == param_data .scale .data_ptr ()
94
94
assert param .data .zero_point .data_ptr () == param_data .zero_point .data_ptr ()
95
- orig_value = param .data ._data [0 ][0 ].item ()
96
-
97
- # dummy_l has random input (shouldn't be 0)
98
- dummy_l = torch .nn .Linear (1024 , 1024 ).to ("cuda" ).to (torch .bfloat16 )
99
- quantize_ (dummy_l , self .config )
100
- quantized = dummy_l .weight
101
- quantized = quantized .narrow (0 , 0 , 512 )
102
95
103
- param_data .copy_ (quantized )
104
-
105
- # making sure param.data is updated
106
- assert param .data ._data [0 ][0 ] != orig_value
96
+ def test_slice_and_copy_similar_to_vllm (self ):
97
+ self ._test_slice_and_copy_similar_to_vllm (self .config )
107
98
99
+ @unittest .skipIf (not is_sm_at_least_90 (), "Nedd sm90+" )
108
100
def test_bmm (self ):
109
101
class M (torch .nn .Module ):
110
102
def __init__ (self , weight ):
@@ -126,20 +118,103 @@ def forward(self, x):
126
118
quantized = m (input )
127
119
self .assertTrue (compute_error (original , quantized ) > 18 )
128
120
129
- def test_to_device (self ):
121
+ @parametrize (
122
+ "sizes" ,
123
+ [
124
+ ((128 ,), 256 , 128 ),
125
+ ((32 , 128 ), 64 , 256 ),
126
+ ((2 , 32 , 128 ), 64 , 256 ),
127
+ ],
128
+ )
129
+ def test_to_device (self , sizes ):
130
+ config = self .config
131
+ M , N , K = sizes
132
+ dtype = torch .bfloat16
130
133
for device in self .GPU_DEVICES :
131
- linear = torch .nn .Linear (128 , 256 , dtype = torch .bfloat16 )
132
- quantize_ (linear , self .config )
134
+ input_tensor = torch .randn (* M , K , dtype = dtype , device = device )
135
+ linear = torch .nn .Linear (K , N , dtype = dtype )
136
+ quantize_ (linear , config )
133
137
linear .to (device )
138
+ linear (input_tensor )
134
139
135
- linear = torch .nn .Linear (128 , 256 , dtype = torch . bfloat16 )
136
- quantize_ (linear , self . config )
140
+ linear = torch .nn .Linear (K , N , dtype = dtype )
141
+ quantize_ (linear , config )
137
142
linear .to (device = device )
143
+ linear (input_tensor )
138
144
139
- linear = torch .nn .Linear (128 , 256 , dtype = torch . bfloat16 )
140
- quantize_ (linear , self . config )
145
+ linear = torch .nn .Linear (K , N , dtype = dtype )
146
+ quantize_ (linear , config )
141
147
linear .to (device )
148
+ linear (input_tensor )
149
+
150
+ @parametrize (
151
+ "sizes" ,
152
+ [
153
+ ((128 ,), 256 , 128 ),
154
+ ((32 , 128 ), 64 , 256 ),
155
+ ((2 , 32 , 128 ), 64 , 256 ),
156
+ ],
157
+ )
158
+ def test_cat (self , sizes ):
159
+ config = self .config
160
+ dtype = torch .bfloat16
161
+ device = "cuda"
162
+ M , N , K = sizes
163
+ linear1 = torch .nn .Linear (K , N , dtype = dtype , device = device )
164
+ linear2 = torch .nn .Linear (K , N , dtype = dtype , device = device )
165
+ input_cat1 = torch .randn (* M , K , dtype = dtype , device = device )
166
+
167
+ cat_weight1 = torch .cat ([linear1 .weight , linear2 .weight ], dim = 0 )
168
+ dummy_linear1 = torch .nn .Linear (K , N , bias = False , dtype = dtype , device = device )
169
+
170
+ dummy_linear1 .weight = torch .nn .Parameter (cat_weight1 )
171
+ quantize_ (dummy_linear1 , config )
172
+
173
+ quantize_ (linear1 , config )
174
+ quantize_ (linear2 , config )
175
+
176
+ cat_qweight1 = torch .cat ([linear1 .weight , linear2 .weight ], dim = 0 )
177
+ self .assertTrue (cat_qweight1 .shape , (2 * N , K ))
178
+ self .assertEqual (
179
+ dummy_linear1 .weight .qdata ,
180
+ cat_qweight1 .qdata ,
181
+ )
182
+ self .assertEqual (
183
+ dummy_linear1 .weight .scale ,
184
+ cat_qweight1 .scale ,
185
+ )
186
+ self .assertEqual (
187
+ dummy_linear1 .weight .zero_point ,
188
+ cat_qweight1 .zero_point ,
189
+ )
190
+
191
+ # making sure cat_qweight1 can be used for inference
192
+ dummy_linear1 .weight = torch .nn .Parameter (cat_qweight1 , requires_grad = False )
193
+ dummy_linear1 (input_cat1 )
194
+
195
+ # align the scale and zero_point before concatenation
196
+ linear2 .weight .scale = linear1 .weight .scale
197
+ linear2 .weight .zero_point = linear1 .weight .zero_point
198
+ cat_qweight2 = torch .cat ([linear1 .weight , linear2 .weight ], dim = 1 )
199
+ self .assertTrue (cat_qweight2 .shape , (N , 2 * K ))
200
+ ref_data = torch .cat (
201
+ [
202
+ linear1 .weight .qdata ,
203
+ linear2 .weight .qdata ,
204
+ ],
205
+ dim = 1 ,
206
+ )
207
+ ref_scale = linear1 .weight .scale
208
+ ref_zero_point = linear1 .weight .zero_point
209
+ self .assertEqual (cat_qweight2 .qdata , ref_data )
210
+ self .assertEqual (cat_qweight2 .scale , ref_scale )
211
+ self .assertEqual (cat_qweight2 .zero_point , ref_zero_point )
212
+
213
+ def test_moe_weight_reshape_ops (self ):
214
+ self ._test_moe_weight_reshape_ops (self .config )
215
+
# Expand every @parametrize'd test above into concrete test methods.
instantiate_parametrized_tests(TestInt4Tensor)


if __name__ == "__main__":
    run_tests()