49
49
)
50
50
from torch .ao .quantization .fx .utils import (
51
51
_get_module ,
52
- assert_and_get_unique_device ,
53
52
collect_producer_nodes ,
54
- create_getattr_from_value ,
55
53
graph_module_from_producer_nodes ,
56
54
node_arg_is_weight ,
57
55
)
73
71
74
72
from torchao .quantization .pt2e import FROM_NODE_KEY
75
73
from torchao .quantization .pt2e .observer import _is_activation_post_process
76
- from torchao .utils import TORCH_VERSION_AT_LEAST_2_6
74
+ from torchao .quantization .pt2e .utils import create_getattr_from_value
75
+ from torchao .utils import (
76
+ TORCH_VERSION_AT_LEAST_2_6 ,
77
+ _assert_and_get_unique_device ,
78
+ )
77
79
78
80
if TORCH_VERSION_AT_LEAST_2_6 :
79
81
from torch .fx .traceback import NodeSource , NodeSourceAction
@@ -132,6 +134,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
132
134
modules : dict [str , torch .nn .Module ],
133
135
node_name_to_scope : dict [str , tuple [str , type ]],
134
136
node_name_to_qconfig : dict [str , QConfigAny ],
137
+ model_device : Optional [torch .device ] = None ,
135
138
) -> None :
136
139
"""Replace activation_post_process module call node with quantize and
137
140
dequantize node working with decomposed Tensor
@@ -260,7 +263,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
260
263
# sure that the default overload can be used.
261
264
# TODO: maybe need more complex attr name here
262
265
qparam_node = create_getattr_from_value (
263
- model , graph , module_path + prefix + key , value_or_node
266
+ model ,
267
+ graph ,
268
+ module_path + prefix + key ,
269
+ value_or_node ,
270
+ model_device ,
264
271
)
265
272
quantize_op_inputs .append (qparam_node )
266
273
else :
@@ -407,6 +414,7 @@ def _replace_observer_with_quantize_dequantize_node(
407
414
modules : dict [str , torch .nn .Module ],
408
415
node_name_to_scope : dict [str , tuple [str , type ]],
409
416
node_name_to_qconfig : dict [str , QConfigAny ],
417
+ model_device : Optional [torch .device ] = None ,
410
418
) -> None :
411
419
"""Replace activation_post_process module call node with quantize and
412
420
dequantize node
@@ -487,7 +495,11 @@ def _replace_observer_with_quantize_dequantize_node(
487
495
# For scale and zero_point values we register them as buffers in the root module.
488
496
# TODO: maybe need more complex attr name here
489
497
qparam_node = create_getattr_from_value (
490
- model , graph , module_path + prefix + key , value_or_node
498
+ model ,
499
+ graph ,
500
+ module_path + prefix + key ,
501
+ value_or_node ,
502
+ model_device ,
491
503
)
492
504
quantize_op_inputs .append (qparam_node )
493
505
else :
@@ -785,6 +797,7 @@ def convert_weighted_module(
785
797
backend_config : BackendConfig ,
786
798
is_decomposed : bool = False ,
787
799
is_reference : bool = False ,
800
+ model_device : Optional [torch .device ] = None ,
788
801
) -> None :
789
802
"""Convert a weighted module to reference quantized module in the model
790
803
If the QConfig of a QAT module is not set, the module will still be converted to
@@ -873,7 +886,10 @@ def convert_weighted_module(
873
886
is_ptq = weight_post_process is None
874
887
if is_ptq :
875
888
weight_post_process = qconfig .weight () # type: ignore[union-attr, operator]
876
- device = assert_and_get_unique_device (float_module )
889
+ if model_device is not None :
890
+ device = model_device
891
+ else :
892
+ device = _assert_and_get_unique_device (float_module )
877
893
if device :
878
894
weight_post_process .to (device )
879
895
@@ -1076,6 +1092,7 @@ def convert(
1076
1092
root_module_classes = tuple (root_module_to_quantized_reference_module .keys ())
1077
1093
qat_module_classes = get_qat_module_classes (backend_config )
1078
1094
fused_module_classes = get_fused_module_classes (backend_config )
1095
+ model_device = _assert_and_get_unique_device (model )
1079
1096
1080
1097
for node in list (model .graph .nodes ):
1081
1098
if node .op == "placeholder" :
@@ -1123,6 +1140,7 @@ def convert(
1123
1140
modules ,
1124
1141
node_name_to_scope ,
1125
1142
node_name_to_qconfig ,
1143
+ model_device ,
1126
1144
)
1127
1145
else :
1128
1146
_replace_observer_with_quantize_dequantize_node (
@@ -1131,6 +1149,7 @@ def convert(
1131
1149
modules ,
1132
1150
node_name_to_scope ,
1133
1151
node_name_to_qconfig ,
1152
+ model_device ,
1134
1153
)
1135
1154
elif isinstance (mod , DeQuantStub ):
1136
1155
_replace_observer_or_dequant_stub_with_dequantize_node (
@@ -1160,6 +1179,7 @@ def convert(
1160
1179
backend_config ,
1161
1180
is_decomposed ,
1162
1181
is_reference ,
1182
+ model_device ,
1163
1183
)
1164
1184
1165
1185
# remove deadcode after converting observers to quant/dequant ops
0 commit comments