
Commit 134de54

jerryzh168 authored and Wei Wei committed
[quant][fx][bc-breaking] Add required example_inputs argument to prepare_fx and prepare_qat_fx (#77608)
Summary:
X-link: pytorch/pytorch#77608
Pull Request resolved: pytorch/fx2trt#76
X-link: facebookresearch/d2go#249
X-link: fairinternal/ClassyVision#104
X-link: pytorch/benchmark#916
X-link: facebookresearch/ClassyVision#791
X-link: facebookresearch/mobile-vision#68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors. Currently we support this with hacks such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base.

As we discussed in the design review, we need to ask users to provide sample args and keyword args so that we can infer the types in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx apis to require users to provide example arguments through example_inputs. Note this api doesn't support kwargs. Kwargs would make pytorch/pytorch#76496 (comment) simpler, but they will be rare, and even then we can still work around them with positional arguments (see the sketch below); also, torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) take just a single tuple of positional args, so we'll use a single example_inputs argument for now. If needed, we can extend the api later with an optional example_kwargs, e.g. for cases where forward takes many arguments and it makes more sense to pass them by keyword.

BC-breaking Note:

Before:

```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)
# or
m = prepare_qat_fx(m, qconfig_dict)
```

After:

```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
# or
m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
```

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 706c8df71722c9aa5082a6491734f0144f0dd670
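Since `example_inputs` takes positional arguments only, a `forward` that callers normally invoke with a keyword can usually still be covered by passing the same value positionally. A minimal sketch of that workaround; the module, shapes, and the `scale` parameter are illustrative, not from this PR:

```python
import torch
from torch.ao.quantization.quantize_fx import prepare_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x, scale=1.0):
        # callers may write m(x, scale=2.0), but scale is still a positional slot
        return self.linear(x) * scale

m = M().eval()
qconfig_dict = {"": torch.ao.quantization.default_qconfig}
# pass the keyword's value positionally in example_inputs
example_inputs = (torch.rand(1, 5), 2.0)
mp = prepare_fx(m, qconfig_dict, example_inputs)
```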
1 parent 4086fdc commit 134de54

File tree

1 file changed: +60 −12 lines changed


test/quant/test_quant_trt.py

Lines changed: 60 additions & 12 deletions
@@ -84,10 +84,14 @@ def forward(self, x):
 
         # quantized input, quantized output
         m = M()
-        qconfig_dict = {"": torch.ao.quantization.default_qconfig}
         m.eval()
+        qconfig_dict = {"": torch.ao.quantization.default_qconfig}
+        example_inputs = (torch.rand(1, 1, 3, 3),)
         mp = torch.ao.quantization.quantize_fx.prepare_fx(
-            m, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict
+            m,
+            qconfig_dict,
+            example_inputs,
+            prepare_custom_config_dict=prepare_custom_config_dict,
         )
         self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
         mp(torch.randn(1, 1, 4, 4))
@@ -221,20 +225,29 @@ def forward(self, x):
             original_m.standalone.conv.bias.detach()
         )
 
+        sm_example_inputs = (data,)
         prepare_config = {
             "standalone_module_name": [
-                ("standalone", None, interface_config, backend_config_dict)
+                (
+                    "standalone",
+                    None,
+                    sm_example_inputs,
+                    interface_config,
+                    backend_config_dict,
+                )
             ]
         }
 
         original_m_copy = copy.deepcopy(original_m)
         original_ref_m_copy = copy.deepcopy(original_ref_m)
 
         qconfig_dict = {"": qconfig}
+        example_inputs = (data,)
         # check prepared model
         m = prepare_fx(
             original_m_copy,
             qconfig_dict,
+            example_inputs,
             prepare_custom_config_dict=prepare_config,
             backend_config_dict=backend_config_dict,
         )
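The hunk above also shows that each "standalone_module_name" entry grows a slot for the standalone submodule's own example inputs. A hedged sketch of the resulting tuple layout; the field-name comments are descriptive guesses from this diff rather than the library's documentation, and the `None` entries appear to fall back to the parent module's settings:

```python
import torch

# example inputs for the standalone submodule itself (shapes are illustrative)
sm_example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1))

prepare_custom_config_dict = {
    "standalone_module_name": [
        # (module_name, qconfig_dict, example_inputs, custom_config_dict, backend_config_dict)
        ("standalone", None, sm_example_inputs, {"input_quantized_idxs": [0, 1]}, None),
    ]
}
```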
@@ -255,7 +268,10 @@ def forward(self, x):
 
         # quantize the reference model
         ref_m = prepare_fx(
-            original_ref_m_copy, qconfig_dict, backend_config_dict=backend_config_dict
+            original_ref_m_copy,
+            qconfig_dict,
+            example_inputs,
+            backend_config_dict=backend_config_dict,
         )
         ref_m(data)
         ref_m = convert_fx(
@@ -410,8 +426,12 @@ def _test_module(
         else:
             m = m.eval()
             prepare = prepare_fx
+        example_inputs = tuple(inputs)
         prepared = prepare(
-            m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config_dict=self.trt_backend_config_dict,
         )
         self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare)
         # calibration
@@ -528,8 +548,12 @@ def forward(self, x):
                 return x
 
         m = M().eval()
+        example_inputs = (torch.rand(1, 3, 5, 5),)
         m = prepare_fx(
-            m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config_dict=self.trt_backend_config_dict,
         )
         m = convert_fx(
             m, is_reference=True, backend_config_dict=self.trt_backend_config_dict
@@ -558,9 +582,11 @@ def forward(self, x):
 
         m = LinearModule().eval()
         trt_unsupported_qconfig = default_qconfig
+        example_inputs = (torch.rand(1, 5),)
         prepared = prepare_fx(
             m,
             {"": trt_unsupported_qconfig},
+            example_inputs=example_inputs,
             backend_config_dict=self.trt_backend_config_dict,
         )
         # calibration
@@ -588,8 +614,12 @@ def forward(self, x):
                 return torch.cat([x, x], 1)
 
         m = M().eval()
+        example_inputs = (torch.rand(2, 2),)
         prepared = prepare_fx(
-            m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config_dict=self.trt_backend_config_dict,
         )
         self.assertTrue(len(dict(prepared.named_children())) == 1)
         quantized = convert_fx(
@@ -615,8 +645,12 @@ def forward(self, x):
                 return torch.addmm(self.bias, x, self.weight)
 
         m = M().eval()
+        example_inputs = (torch.rand(1, 5),)
         prepared = prepare_fx(
-            m, {"": self.trt_qconfig}, backend_config_dict=self.trt_backend_config_dict
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config_dict=self.trt_backend_config_dict,
         )
         node_occurrence = {
             # weight
@@ -684,8 +718,12 @@ def conv_add_extra_inputs_getter(pattern):
         m = M().eval()
         modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict)
         modified_backend_config_dict["configs"].insert(0, conv_add_config)
+        example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1))
         m = prepare_fx(
-            m, {"": self.trt_qconfig}, backend_config_dict=modified_backend_config_dict
+            m,
+            {"": self.trt_qconfig},
+            example_inputs,
+            backend_config_dict=modified_backend_config_dict,
         )
         print(m)
         node_occurrence = {
@@ -717,7 +755,7 @@ def __init__(self):
                 self.conv = torch.nn.Conv2d(3, 3, 3)
                 self.standalone = Standalone()
 
-            def forward(self, x, y):
+            def forward(self, x):
                 y = self.conv(x)
                 return self.standalone(x, y)
 
@@ -765,9 +803,16 @@ def forward(self, x, y):
                 conv_config,
             ]
         }
+        sm_example_inputs = (torch.rand(1, 3, 3, 3), torch.rand(1, 3, 1, 1))
         prepare_custom_config_dict = {
             "standalone_module_name": [
-                ("standalone", None, {"input_quantized_idxs": [0, 1]}, None)
+                (
+                    "standalone",
+                    None,
+                    sm_example_inputs,
+                    {"input_quantized_idxs": [0, 1]},
+                    None,
+                )
             ]
         }
         # TODO: use self.trt_qconfig after input_quantized_idxs and output_quantized_idxs
@@ -778,9 +823,11 @@ def forward(self, x, y):
             ),
             weight=torch.ao.quantization.default_weight_observer,
         )
+        example_inputs = (torch.rand(1, 3, 5, 5),)
         m = prepare_fx(
             m,
             {"": qconfig},
+            example_inputs,
             prepare_custom_config_dict=prepare_custom_config_dict,
             backend_config_dict=backend_config_dict,
         )
@@ -829,10 +876,11 @@ def forward(self, x):
 
         model = LinearModule().eval()
         inputs = [torch.rand(8, 5)]
-
+        example_inputs = tuple(inputs)
         prepared = prepare_fx(
             model,
             {"": self.trt_qconfig},
+            example_inputs,
             backend_config_dict=self.trt_backend_config_dict,
         )
         quantized = convert_fx(
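Every hunk above follows the same migration recipe, so a single end-to-end sketch covers the new calling convention; the module, shapes, and qconfig are illustrative, and the TensorRT-specific `backend_config_dict` used by these tests is omitted:

```python
import torch
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class LinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)

model = LinearModule().eval()
qconfig_dict = {"": torch.ao.quantization.default_qconfig}

# example_inputs is now required: a tuple of positional args for forward()
example_inputs = (torch.rand(8, 5),)
prepared = prepare_fx(model, qconfig_dict, example_inputs)

# calibration: run representative data through the observed model
prepared(*example_inputs)

quantized = convert_fx(prepared)
```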
