@@ -272,7 +272,8 @@ class QuantizationTransformPass(object):
272
272
the quantized ops' inputs.
273
273
"""
274
274
_supported_quantizable_op_type = [
275
- 'conv2d' , 'depthwise_conv2d' , 'conv2d_transpose' , 'mul' , 'matmul'
275
+ 'conv2d' , 'depthwise_conv2d' , 'conv2d_transpose' , 'mul' , 'matmul' ,
276
+ 'matmul_v2'
276
277
]
277
278
278
279
def __init__ (self ,
@@ -520,6 +521,16 @@ def _transform_backward(graph, op):
520
521
dequant_var_node = dequantized_vars [var_node .name ()]
521
522
graph .update_input_link (var_node , dequant_var_node , op )
522
523
524
def _has_weight(op):
    """Return True if any input of *op* is a persistable variable (a weight).

    An input var node is only considered when its name appears in
    ``op.input_arg_names()``; other attached nodes are skipped.

    NOTE(review): reads ``persistable_vars`` from the enclosing scope —
    presumably the set of persistable (parameter) variable names collected
    earlier in the pass; confirm against the enclosing method.

    Args:
        op: an IR op node exposing ``inputs`` and ``input_arg_names()``.

    Returns:
        bool: True iff at least one real input is persistable.
    """
    for var_node in op.inputs:
        # Hoist the repeated .name() call out of the comparisons.
        name = var_node.name()
        if name not in op.input_arg_names():
            continue
        # One persistable input is enough — return early instead of
        # accumulating a flag over the remaining inputs.
        if name in persistable_vars:
            return True
    return False
533
+
523
534
if not self ._is_test :
524
535
self ._create_global_step (graph )
525
536
ops = graph .all_op_nodes ()
@@ -535,11 +546,11 @@ def _transform_backward(graph, op):
535
546
# The loop for transforming the forward graph:
536
547
for op in ops :
537
548
if op .name () in self ._quantizable_ops :
538
- if not self ._is_skip_quant (graph , op ):
549
+ if not self ._is_skip_quant (graph , op ) and _has_weight ( op ) :
539
550
_transform_forward (graph , op )
540
551
# The loop for renaming the inputs of backward op.
541
552
for op in ops :
542
- if op .name () in self ._quantizable_grad_ops :
553
+ if op .name () in self ._quantizable_grad_ops and _has_weight ( op ) :
543
554
_transform_backward (graph , op )
544
555
graph .resolve_hazard ()
545
556
return graph
0 commit comments