
Commit 135fe9c

update subset_transform (#218)
1 parent 5197a64 commit 135fe9c

9 files changed (+40 −37 lines)

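In short, the commit collapses `subset_transform`'s five unpacked parameters into a single `subset` dict across every quantization and sparsification pass. A hedged before/after sketch of the signature change, distilled from the hunks below (bodies elided):

```python
# Before this commit: five separate parameters beyond input_feat.
def subset_transform(self, layers_dict, input_feat, prev_op,
                     input_name, inspect_module, subset_kwargs):
    ...

# After this commit: one dict travels through, and each method
# unpacks only the fields it actually needs.
def subset_transform(self, subset, input_feat, subset_kwargs):
    layers_dict = subset['layers']
    prev_op = subset['prev_op']
    input_name = subset['input'][0]
    inspect_module = subset['inspect']
    ...
```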

llmc/compression/quantization/awq.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -195,13 +195,15 @@ def block_transform(self, block, input_feat, block_kwargs):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+        inspect_module = subset['inspect']
+
         if not check_do_quant(
             self.block_idx,
             list(layers_dict.keys())[0],
```
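Reading the unpacking added above together with the caller in base_blockwise_quantization.py, each `subset` appears to carry at least the keys below. The concrete values are illustrative placeholders, not the repository's real modules:

```python
import torch.nn as nn

proj = nn.Linear(16, 16)
subset = {
    'layers': {'self_attn.q_proj': proj},  # name -> module to transform
    'prev_op': [nn.LayerNorm(16)],         # op(s) feeding this subset
    'input': ['self_attn.q_proj'],         # input_feat keys; element [0] is read
    'inspect': proj,                       # module run during calibration
    'has_kwargs': False,                   # whether 'inspect' needs extra kwargs
}
```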

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -465,10 +465,8 @@ def block_transform(self, block, input_feat, block_kwargs):
 
         for index, subset in enumerate(subsets):
             logger.info(f'subset: {subset}')
-            prev_op = subset['prev_op']
             layers_dict = subset['layers']
             input_name = subset['input'][0]
-            inspect_module = subset['inspect']
             inspect_has_kwargs = subset['has_kwargs']
             if inspect_has_kwargs:
                 if 'sub_keys' in subset:
@@ -478,11 +476,8 @@ def block_transform(self, block, input_feat, block_kwargs):
             else:
                 subset_kwargs = {}
             self.subset_transform(
-                layers_dict,
+                subset,
                 input_feat,
-                prev_op,
-                input_name,
-                inspect_module,
                 subset_kwargs,
             )
             if self.act_static:
```
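One way to read the caller change: since every subclass now shares the `(subset, input_feat, subset_kwargs)` signature, the base-class loop can dispatch uniformly while each algorithm pulls only the keys it needs. A minimal sketch with stand-in classes, not the repository's code:

```python
class WandaLike:
    def subset_transform(self, subset, input_feat, subset_kwargs):
        layers_dict = subset['layers']   # every algorithm reads the layers
        input_name = subset['input'][0]  # wanda also needs the activations
        return list(layers_dict), input_name

class GPTQLike:
    def subset_transform(self, subset, input_feat, subset_kwargs):
        layers_dict = subset['layers']   # gptq needs nothing but the layers
        return list(layers_dict)

subset = {'layers': {'q_proj': object()}, 'input': ['q_proj']}
for algo in (WandaLike(), GPTQLike()):
    algo.subset_transform(subset, input_feat={}, subset_kwargs={})
```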

llmc/compression/quantization/dgq.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -276,13 +276,14 @@ def search_scale_zero_subset(self, layers_dict, input_feat):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+
         layers = list(layers_dict.values())
         if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
             self.smoothquant_transform(prev_op, layers, input_feat[input_name])
```

llmc/compression/quantization/gptq.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -89,7 +89,13 @@ def block_transform(self, block, input_feat, block_kwargs):
         super().block_transform(block, input_feat, block_kwargs)
 
     @torch.no_grad()
-    def subset_transform(self, layers_dict, *args, **kwargs):
+    def subset_transform(
+        self,
+        subset,
+        input_feat,
+        subset_kwargs,
+    ):
+        layers_dict = subset['layers']
         for name in layers_dict:
             layer = layers_dict[name]
             self.layer_transform(layer, name)
```
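Note that gptq previously took `*args, **kwargs`, which swallowed caller mistakes silently; the explicit three-parameter form fails fast instead. A small illustration with hypothetical functions:

```python
def old_style(layers_dict, *args, **kwargs):
    # Accepts any surplus or misordered arguments without complaint;
    # an error surfaces (if ever) only deep inside the body.
    return layers_dict

def new_style(subset, input_feat, subset_kwargs):
    # Wrong arity now raises TypeError at the call site.
    return subset['layers']
```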

llmc/compression/quantization/osplus.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -171,13 +171,15 @@ def search_scale_shift_subset(
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+        inspect_module = subset['inspect']
+
         assert (
             len(prev_op) == 1
         ), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
```

llmc/compression/quantization/rtn.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -19,11 +19,8 @@ def block_opt(self, *opt_kwargs):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
         pass
```

llmc/compression/quantization/smoothquant.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -61,13 +61,14 @@ def search_scale_subset(self, layers, tensors):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+
         if not self.filter_subset(prev_op):
             logger.info('Do not transform this subset.')
             return
```
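For background, the SmoothQuant pass scales each input channel to migrate quantization difficulty from activations to weights. A sketch of the standard per-channel scale from the SmoothQuant paper; the repository's search_scale_subset may differ in detail:

```python
import torch

def smooth_scale(act_absmax: torch.Tensor, w_absmax: torch.Tensor,
                 alpha: float = 0.5) -> torch.Tensor:
    # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), per input channel.
    # Activations are divided by s and weights multiplied by s, so the
    # layer's output is unchanged while activation outliers shrink.
    return (act_absmax.pow(alpha) / w_absmax.pow(1 - alpha)).clamp(min=1e-5)
```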

llmc/compression/sparsification/magnitude.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -14,13 +14,12 @@ def __init__(self, model, sparsity_config, input, padding_mask, config, modality
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
-        subset_kwargs
+        subset_kwargs,
     ):
+        layers_dict = subset['layers']
+
         layers = list(layers_dict.values())
         for layer in layers:
             W = layer.weight.data
```
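The `W = layer.weight.data` context hints at what follows the hunk: magnitude pruning keeps the largest-magnitude entries and zeroes the rest. A self-contained sketch of the classic per-tensor criterion; the repository may prune per-row or per-group instead:

```python
import torch

def magnitude_mask(W: torch.Tensor, sparsity: float) -> torch.Tensor:
    # True where a weight is kept; the smallest-|w| `sparsity` fraction
    # of entries is dropped.
    k = int(W.numel() * sparsity)
    if k == 0:
        return torch.ones_like(W, dtype=torch.bool)
    threshold = W.abs().flatten().kthvalue(k).values
    return W.abs() > threshold

W = torch.randn(4, 8)
W *= magnitude_mask(W, sparsity=0.5)  # zero out the weakest half
```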

llmc/compression/sparsification/wanda.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -33,13 +33,13 @@ def get_row_scale(self, layer, act):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
-        subset_kwargs
+        subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        input_name = subset['input'][0]
+
         layers = list(layers_dict.values())
         for layer in layers:
             scaler_row = self.get_row_scale(layer, input_feat[input_name][0])
```
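Unlike plain magnitude pruning, wanda weighs each weight by the norm of the activations feeding it, which is why this subclass also reads `subset['input'][0]`. A sketch of the Wanda score from the paper; the repository's get_row_scale may compute the norm differently:

```python
import torch

def wanda_score(weight: torch.Tensor, acts: torch.Tensor) -> torch.Tensor:
    # score_ij = |W_ij| * ||X_j||_2, with the L2 norm taken over all
    # calibration tokens feeding input channel j.
    scaler_row = acts.reshape(-1, acts.shape[-1]).norm(p=2, dim=0)
    return weight.abs() * scaler_row.unsqueeze(0)  # (out_features, in_features)
```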
