42 changes: 24 additions & 18 deletions backends/nxp/quantizer/patterns.py
@@ -121,6 +121,24 @@ def get_anchors(
)


+class SingleInputBasicPattern(QuantizationPattern):
+    @abstractmethod
+    def partition_types(self) -> list[OpOverload]:
+        pass
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(node, NodeArgsIdx(0))],
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )


def get_anchors_for_fixed_quant_specs(
    fused_partition: list[fx.GraphModule],
    scale: float,
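For illustration only, not part of this diff: a concrete single-input pattern built on the new base class only needs to list the op overloads it matches, since the input/output anchors come from the inherited SingleInputBasicPattern.get_anchors(). The SigmoidPattern name below is hypothetical, and replacement_op() simply mirrors the HardTanhPattern shown further down.

class SigmoidPattern(SingleInputBasicPattern):
    """Hypothetical example; anchors are provided by SingleInputBasicPattern."""

    def partition_types(self) -> list[OpOverload]:
        # Only the matched op overload(s) need to be declared here.
        return [torch.ops.aten.sigmoid.default]

    def replacement_op(self):
        # Mirrors HardTanhPattern in this file.
        raise AssertionError()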
@@ -376,7 +394,7 @@ def partition_types(self):
        return [torch.ops.aten.flatten.using_ints]


-class HardTanhPattern(QuantizationPattern):
+class HardTanhPattern(SingleInputBasicPattern):
"""
Quantizer for HardTanh operator. Shared quantization spec is selected, as activation functions usually follows
computation layer.
@@ -385,23 +403,12 @@ class HardTanhPattern(QuantizationPattern):
    def partition_types(self):
        return [torch.ops.aten.hardtanh.default]

-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors | None:
-        node = fused_partition[0].nodes[-1]
-
-        return PartitionAnchors(
-            inputs=[(node, NodeArgsIdx(0))],
-            weights=[],
-            biases=[],
-            output=[(node,)],
-        )

    def replacement_op(self):
        raise AssertionError()


-class HardTanhInPlacePattern(QuantizationPattern):
+class HardTanhInPlacePattern(SingleInputBasicPattern):
Reviewer comment (Collaborator) on HardTanhInPlacePattern: "Didn't you forget to remove the get_anchors() method?" (A quick way to check this is sketched after the diff.)
"""
Quantizer for HardTanh operator with param inplace=True. Shared quantization spec is selected, as activation
functions usually follows computation layer.
@@ -513,19 +520,18 @@ def partition_types(self):
        return [torch.ops.aten.permute.default]


-class ReluPattern(SharedSpecPattern):
+class ReluPattern(SingleInputBasicPattern):
    """
-    Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer.
+    Quantizer for Relu operator.
    """

    def partition_types(self):
        return [torch.ops.aten.relu.default]


-class ReluInPlacePattern(SharedSpecPattern):
+class ReluInPlacePattern(SingleInputBasicPattern):
    """
-    Quantizer for Relu operator with param inplace=True. Shared quantization spec is selected, as ReLU usually
-    follows computation layer.
+    Quantizer for Relu operator with param inplace=True.
    """

    def partition_types(self):
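A side note on the reviewer's question above: one quick way to see which pattern classes still define their own get_anchors(), rather than relying on the inherited SingleInputBasicPattern.get_anchors(), is to inspect each class's own namespace. This is only an illustrative sketch; it assumes the module is importable as backends.nxp.quantizer.patterns, matching the path in the diff header.

from backends.nxp.quantizer import patterns

# A class that overrides get_anchors() carries it in its own class __dict__;
# one that inherits the base implementation does not.
for cls in (patterns.HardTanhPattern, patterns.HardTanhInPlacePattern):
    overrides = "get_anchors" in vars(cls)
    print(f"{cls.__name__} defines its own get_anchors: {overrides}")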