
Commit a7d5efe

Fixed a training error that could be caused by parameters not being constants (#229)
1 parent af0fb25 commit a7d5efe

File tree

1 file changed: +7 −4 lines


ppq/quantization/optim/training.py

Lines changed: 7 additions & 4 deletions
@@ -287,8 +287,10 @@ def cache_fn(data: torch.Tensor):
         quant_graph.restore_quantize_state()
         for data in dataloader:
             if collate_fn is not None: data = collate_fn(data)
-            qt_input = executor.forward(data, [block.sp.inputs[0].name])
-            qt_input = {block.sp.inputs[0].name: cache_fn(qt_input[0])}
+            # PATCH 20220829, some computing ops have weights that are not constant values
+            non_constant_input = [var for var in block.sp.inputs if not var.is_parameter]
+            qt_input = executor.forward(data, [var.name for var in non_constant_input])
+            qt_input = {var.name: cache_fn(value) for var, value in zip(non_constant_input, qt_input)}
             qt_inputs.append(qt_input)
             cur_iter += 1
             if steps is not None and cur_iter > steps: break
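
The hunk above widens what gets cached per batch: instead of assuming only block.sp.inputs[0] varies at runtime, every non-parameter input of the block's starting op is fetched and cached. A minimal sketch of that selection logic, using a toy Variable stand-in and hypothetical names (block_sp_inputs, conv.weight, etc.) rather than PPQ's actual Variable/executor API:

from dataclasses import dataclass
from typing import List

@dataclass
class Variable:
    # toy stand-in: the real PPQ variable carries an is_parameter flag for baked-in constants
    name: str
    is_parameter: bool

# hypothetical inputs of a block's starting computing op: the weight here is
# produced at runtime by another op, so it is NOT a stored parameter
block_sp_inputs: List[Variable] = [
    Variable('input.1',     is_parameter=False),  # activation
    Variable('conv.weight', is_parameter=False),  # runtime-generated weight
    Variable('conv.bias',   is_parameter=True),   # ordinary constant
]

# old behaviour: fetch only the first input, silently skipping the runtime weight
old_fetch = [block_sp_inputs[0].name]                      # ['input.1']

# patched behaviour: fetch every input that is not a constant parameter
non_constant_input = [v for v in block_sp_inputs if not v.is_parameter]
new_fetch = [v.name for v in non_constant_input]           # ['input.1', 'conv.weight']

print(old_fetch, new_fetch)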
@@ -322,7 +324,7 @@ def compute_block_loss(
         feed_dict = {k: v.to(executor._device) for k, v in qt_input.items()}

         qt_output = executor.partial_graph_forward(
-            operations=block.rps, feed_dict=feed_dict, 
+            operations=block.rps, feed_dict=feed_dict,
             output_names=output_names)

         for name, quant_output in zip(output_names, qt_output):
@@ -735,7 +737,8 @@ def finetune(

         if op.is_computing_op:
             for var in op.inputs[1:]:
-                trainable_params.append(var.value)
+                if var.is_parameter:
+                    trainable_params.append(var.value)

         # register quant delegator
         for cfg, var in op.config_with_variable:

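The finetune hunk applies the same distinction when collecting trainable tensors: only inputs that are real parameters (constant tensors stored in the graph) are handed to the optimizer. A hedged sketch with toy objects, not PPQ's actual Operation/Variable classes; the failure mode in the comment is inferred from the commit message, not confirmed:

import torch

class Var:
    # toy stand-in for a graph variable; value is None when the tensor is
    # produced at runtime instead of being stored as a constant
    def __init__(self, name: str, is_parameter: bool, value=None):
        self.name, self.is_parameter, self.value = name, is_parameter, value

# hypothetical op.inputs[1:] of a computing op
op_inputs = [
    Var('weight', is_parameter=False),                     # runtime weight, no stored value
    Var('bias',   is_parameter=True, value=torch.zeros(8)),
]

trainable_params = []
for var in op_inputs:
    if var.is_parameter:          # the guard added by this commit
        trainable_params.append(var.value)

# without the guard, a non-parameter input's value (often None for
# runtime-produced tensors) could reach the optimizer, which is presumably
# the training error the commit message describes
print([p.shape for p in trainable_params])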