Commit b2369c4 (2 parents: 0f863a7 + 4d7144d)

Update on "[devtool] make datasink a separate directory"

This diff makes data_sink_base and its children a separate directory for better structure.

Differential Revision: [D69732404](https://our.internmc.facebook.com/intern/diff/D69732404/)

[ghstack-poisoned]

47 files changed: +599 −644 lines (large commit; only a subset of the changed files is shown below)

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 23 additions & 3 deletions

@@ -13,6 +13,7 @@
     AddmmPattern,
     AddPattern,
     BmmPattern,
+    CatPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormPattern,
@@ -246,6 +247,16 @@ def get_args_and_kwargs_matmul(
     return args, kwargs


+def get_args_and_kwargs_cat(
+    inputs_inputs: List[fx.Node], other_inputs: List[fx.Node], op_node: fx.Node
+) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
+    args = tuple([inputs_inputs] + other_inputs)
+    dim = op_node.args[1] if len(op_node.args) > 1 else 0
+    # pyre-fixme[6]: Incompatible parameter type
+    kwargs = {"dim": int(dim)}
+    return args, kwargs
+
+
 def get_args_and_kwargs_conv(
     graph_module: GraphModule,
     inputs_inputs: List[fx.Node],
@@ -390,12 +401,17 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                 self.mark_fused(p.nodes)

                 dequants_inputs = []
-                for node, idx in anchors.inputs:
+                for node, idx, *_spec in anchors.inputs:
+                    arg = (
+                        node.args[idx]
+                        if isinstance(idx, int)
+                        else node.args[idx[0]][idx[1]]
+                    )
                     if (
-                        node.args[idx].target
+                        arg.target
                         == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                     ):
-                        dequants_inputs.append(node.args[idx])
+                        dequants_inputs.append(arg)
                 dequants_weights = []
                 for node, idx in anchors.weights:
                     if (
@@ -434,6 +450,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         dequants_inputs,
                         quant_node,
                     )
+                elif isinstance(pattern, CatPattern):
+                    args, kwargs = get_args_and_kwargs_cat(
+                        inputs_inputs, other_inputs, op_node
+                    )
                 elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                     args, kwargs = get_args_and_kwargs_conv(
                         graph_module,

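The (outer, inner) index form introduced above is what lets anchors.inputs point at tensors nested inside cat's list argument. A minimal sketch of the resolution logic in plain Python (cat_args and resolve_arg are illustrative stand-ins, not names from the diff):

from typing import Tuple, Union

# Illustrative stand-in for an fx.Node's args tuple: aten.cat takes a list
# of tensors plus an optional dim, e.g. cat([t0, t1, t2], 1).
cat_args = (["t0", "t1", "t2"], 1)

def resolve_arg(args: tuple, idx: Union[int, Tuple[int, int]]) -> object:
    # Mirrors the fusion pass: a plain int indexes args directly; an
    # (outer, inner) pair first selects the list-valued argument, then
    # an element inside that list.
    if isinstance(idx, int):
        return args[idx]
    return args[idx[0]][idx[1]]

print(resolve_arg(cat_args, 1))       # 1    -- the dim argument
print(resolve_arg(cat_args, (0, 0)))  # 't0' -- first tensor in the list
print(resolve_arg(cat_args, (0, 2)))  # 't2' -- third tensor in the list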
backends/cadence/aot/quantizer/patterns.py

Lines changed: 57 additions & 1 deletion

@@ -33,7 +33,17 @@ class PartitionAnchors:
     is used for other types of input values as well as handling default parameters.
     """

-    inputs: List[Tuple[fx.Node, int]] = field(default_factory=list)
+    # Inputs can share quantization parameters
+    inputs: List[
+        Union[
+            Tuple[fx.Node, Union[int, Tuple[int, int]]],
+            Tuple[
+                fx.Node,
+                Union[int, Tuple[int, int]],
+                SharedQuantizationSpec,
+            ],
+        ]
+    ] = field(default_factory=list)
     weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
     biases: List[
         Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
@@ -155,6 +165,52 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default


+class CatPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.cat.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        cat_node = fused_partition[0].nodes[-1]
+
+        # Create args. The first argument does not have a quant spec and
+        # will inherit from the overall quant spec. All subsequent args
+        # will share that spec.
+        # Note that outputs also share that spec.
+        args: List[
+            Union[
+                Tuple[fx.Node, Union[int, Tuple[int, int]]],
+                Tuple[
+                    fx.Node,
+                    Union[int, Tuple[int, int]],
+                    SharedQuantizationSpec,
+                ],
+            ]
+        ] = [(cat_node, (0, 0))]
+        for i in range(1, len(cat_node.args[0])):
+            args.append(
+                (
+                    cat_node,
+                    (0, i),
+                    SharedQuantizationSpec((cat_node.args[0][0], cat_node)),
+                )
+            )
+
+        return PartitionAnchors(
+            inputs=args,
+            weights=[],
+            biases=[],
+            output=[
+                (cat_node, SharedQuantizationSpec((cat_node.args[0][0], cat_node)))
+            ],
+        )
+
+    def replacement_op(self) -> OpOverload:
+        return torch.ops.aten.cat.default
+
+
 class Conv1dPattern(QuantizationPattern):
     def partition_types(self) -> List[OpOverload]:
         return [torch.ops.aten.conv1d.default]

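Why every cat input (and the output) is pinned to a SharedQuantizationSpec derived from the first input: concatenating int8 tensors is only a plain copy when all operands use identical scale and zero point. A small numeric illustration of that identity (numpy and the values here are purely for the demo, not part of the diff):

import numpy as np

def quantize(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # Uniform int8 affine quantization: q = clamp(round(x / s) + z, -128, 127).
    return np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)

a = np.array([0.1, 0.5])
b = np.array([0.2, 0.4])
scale, zp = 0.004, 0

qa, qb = quantize(a, scale, zp), quantize(b, scale, zp)

# With one shared spec, quantize-then-concat equals concat-then-quantize,
# so the quantized cat degenerates to a memcpy of its inputs.
assert np.array_equal(
    np.concatenate([qa, qb]),
    quantize(np.concatenate([a, b]), scale, zp),
)

With per-input scales this identity breaks, which is why the pattern anchors all subsequent operands and the output to the first input's spec.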
backends/cadence/aot/quantizer/quantizer.py

Lines changed: 26 additions & 3 deletions

@@ -14,6 +14,7 @@
     AddmmPattern,
     AddPattern,
     BmmPattern,
+    CatPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormPattern,
@@ -144,17 +145,38 @@ def annotate_inputs(
                     "quantization_annotation",
                     QuantizationAnnotation(_annotated=True),
                 )
+                arg = (
+                    # pyre-ignore[16]: no attribute
+                    node.args[idx]
+                    if isinstance(idx, int)
+                    # pyre-ignore[16]: no attribute
+                    else node.args[idx[0]][idx[1]]
+                )
+                annotation.input_qspec_map[arg] = (
+                    custom_spec[0] if custom_spec else spec
+                )
                 # pyre-ignore[16]: no attribute
+                node.meta["quantization_annotation"] = annotation
+
+        def annotate_weights_or_biases(
+            weights_or_biases: List[Tuple[fx.Node, int]],
+            spec: Optional[QuantizationSpec],
+        ) -> None:
+            for node, idx, *custom_spec in weights_or_biases:
+                annotation = node.meta.get(
+                    "quantization_annotation",
+                    QuantizationAnnotation(_annotated=True),
+                )
                 annotation.input_qspec_map[node.args[idx]] = (
                     custom_spec[0] if custom_spec else spec
                 )
-                # pyre-ignore[16]: no attribute
                 node.meta["quantization_annotation"] = annotation

+        # pyre-ignore[6]: incompatible parameter type
         annotate_inputs(anchors.inputs, input_act_qspec)
-        annotate_inputs(anchors.weights, weight_qspec)
+        annotate_weights_or_biases(anchors.weights, weight_qspec)
         # pyre-ignore[6]: incompatible parameter type
-        annotate_inputs(anchors.biases, bias_qspec)
+        annotate_weights_or_biases(anchors.biases, bias_qspec)
         return model

     def validate(self, model: fx.GraphModule) -> None:
@@ -223,4 +245,5 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         if quantizers is None:
             quantizers = get_cadence_default_quantizers()
         quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8uW8u))
         super().__init__(quantizers)

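With CatPattern registered in the default quantizer list, a graph containing aten.cat can be annotated through the standard PT2E flow. A hedged usage sketch; the class name CadenceQuantizer, the import path, and the capture API are assumptions that may differ across ExecuTorch/PyTorch versions:

import torch
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class CatModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.cat([x, y], dim=1)

example_inputs = (torch.randn(1, 4), torch.randn(1, 4))
# Capture API is an assumption; older releases used capture_pre_autograd_graph.
gm = torch.export.export_for_training(CatModel(), example_inputs).module()

quantizer = CadenceQuantizer()  # per this diff, now appends CatPattern with qconfig_A8uW8u
prepared = prepare_pt2e(gm, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)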
backends/qualcomm/_passes/__init__.py

Lines changed: 4 additions & 6 deletions

@@ -1,11 +1,8 @@
-from .annotate_and_quant_scalar import AnnotateAndQuantScalar
 from .annotate_decomposed import AnnotateDecomposed
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .constant_i64_to_i32 import ConstantI64toI32
-from .convert_binary_op_with_scalar import ConvertBinaryOpsWithScalar
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
-from .convert_prelu import ConvertPReLU
 from .convert_to_linear import ConvertToLinear
 from .decompose_any import DecomposeAny
 from .decompose_einsum import DecomposeEinsum
@@ -17,7 +14,9 @@
 from .insert_io_qdq import InsertIOQDQ
 from .insert_requantize import InsertRequantize
 from .layout_transform import LayoutTransform
+from .lift_constant_scalar_operands import LiftConstantScalarOperands
 from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
+from .recompose_prelu import RecomposePReLU
 from .recompose_rms_norm import RecomposeRmsNorm
 from .reduce_dynamic_range import ReduceDynamicRange
 from .remove_redundancy import RemoveRedundancy
@@ -27,14 +26,12 @@


 __all__ = [
-    AnnotateAndQuantScalar,
     AnnotateDecomposed,
     AnnotateQuantAttrs,
     ConstantI64toI32,
     ConvertBmmToMatmul,
-    ConvertBinaryOpsWithScalar,
     ConvertInterpolateWithUpsample2D,
-    ConvertPReLU,
+    RecomposePReLU,
     ConvertToLinear,
     DecomposeAny,
     DecomposeEinsum,
@@ -46,6 +43,7 @@
     InsertIOQDQ,
     InsertRequantize,
     LayoutTransform,
+    LiftConstantScalarOperands,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,
     ReduceDynamicRange,

backends/qualcomm/_passes/annotate_and_quant_scalar.py

Lines changed: 0 additions & 137 deletions
This file was deleted.
