@@ -287,8 +287,10 @@ def cache_fn(data: torch.Tensor):
         quant_graph.restore_quantize_state()
         for data in dataloader:
             if collate_fn is not None: data = collate_fn(data)
-            qt_input = executor.forward(data, [block.sp.inputs[0].name])
-            qt_input = {block.sp.inputs[0].name: cache_fn(qt_input[0])}
+            # PATCH 20220829, some computing ops have weights that are not constant values
+            non_constant_input = [var for var in block.sp.inputs if not var.is_parameter]
+            qt_input = executor.forward(data, [var.name for var in non_constant_input])
+            qt_input = {var.name: cache_fn(value) for var, value in zip(non_constant_input, qt_input)}
             qt_inputs.append(qt_input)
             cur_iter += 1
             if steps is not None and cur_iter > steps: break
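
The hunk above stops assuming a block has exactly one non-constant input: it now collects every input whose value is produced at runtime, since a computing op's weight may come from an upstream op rather than being a stored parameter. Below is a minimal standalone sketch of that selection-and-caching pattern; the Var dataclass, forward_fn callable, and collect_block_inputs helper are illustrative stand-ins, not ppq's actual API.

from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class Var:  # illustrative stand-in for a graph variable
    name: str
    is_parameter: bool


def collect_block_inputs(block_inputs: List[Var], forward_fn,
                         cache_fn=lambda t: t.detach().cpu()) -> Dict[str, torch.Tensor]:
    # keep only inputs produced at runtime; constant weights need no caching
    non_constant = [v for v in block_inputs if not v.is_parameter]
    values = forward_fn([v.name for v in non_constant])
    return {v.name: cache_fn(t) for v, t in zip(non_constant, values)}


inputs = [Var('act', False), Var('weight', True), Var('dyn_weight', False)]
fake_forward = lambda names: [torch.randn(2, 2) for _ in names]
print(sorted(collect_block_inputs(inputs, fake_forward)))  # ['act', 'dyn_weight']
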
@@ -322,7 +324,7 @@ def compute_block_loss(
         feed_dict = {k: v.to(executor._device) for k, v in qt_input.items()}

         qt_output = executor.partial_graph_forward(
-            operations=block.rps, feed_dict=feed_dict, 
+            operations=block.rps, feed_dict=feed_dict,
             output_names=output_names)

         for name, quant_output in zip(output_names, qt_output):
@@ -735,7 +737,8 @@ def finetune(

         if op.is_computing_op:
             for var in op.inputs[1:]:
-                trainable_params.append(var.value)
+                if var.is_parameter:
+                    trainable_params.append(var.value)

         # register quant delegator
         for cfg, var in op.config_with_variable:
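
The guard added in the hunk above keeps runtime-produced inputs out of the trainable parameter list. The short sketch below (plain PyTorch with made-up tensors, not ppq variables) shows the intent: only genuine parameters should receive requires_grad and be handed to the optimizer, since tensors recomputed by upstream ops each forward pass cannot usefully be trained.

import torch

constant_weight = torch.randn(8, 8)   # a parameter stored in the graph
dynamic_weight = torch.randn(8, 8)    # produced by an upstream op every forward pass

trainable_params = []
for is_parameter, value in [(True, constant_weight), (False, dynamic_weight)]:
    if is_parameter:                   # mirrors the var.is_parameter check above
        value.requires_grad_(True)
        trainable_params.append(value)

optimizer = torch.optim.Adam(trainable_params, lr=1e-3)
print(len(optimizer.param_groups[0]['params']))  # 1: only the constant weight is trained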