13 changes: 13 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -76,6 +76,7 @@ python_library(
"//executorch/exir/dialects:lib",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/backends/transforms:remove_clone_ops"
],
)

@@ -118,3 +119,15 @@ python_unittest(
"//executorch/exir:pass_base",
],
)

python_library(
name = "compiler_utils",
srcs = [
"compiler_utils.py",
],
typing = True,
deps = [
"//caffe2:torch",
"//executorch/exir/dialects:lib",
],
)
302 changes: 302 additions & 0 deletions backends/cadence/aot/compiler_utils.py
@@ -0,0 +1,302 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


# This file contains helper utility functions.

from itertools import zip_longest
from math import frexp, isclose, trunc
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from torch.utils._pytree import tree_flatten


# Return the output node of the graph
def get_output_node(graph: torch.fx.Graph) -> torch.fx.Node:
assert graph is not None, "Cannot get output of an empty graph"
output_node = next(iter(reversed(graph.nodes)))
assert (
output_node and output_node.op == "output" and len(output_node.args) == 1
), "Failed to find output node"
return output_node


# Return true if the node is part of the flattened output
def is_node_in_flattened_output(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
output_node = get_output_node(graph)
return node in tree_flatten(output_node.args[0])[0]


# Returns a list with placeholders/inputs
def get_placeholders(graph: torch.fx.Graph) -> List[torch.fx.Node]:
return list(filter(lambda x: x.op == "placeholder", graph.nodes))


# Return the shape of the incoming node.
def get_shape(
graph_module: torch.fx.GraphModule, node: torch.fx.Node
) -> Union[torch.Size, None]:
"""
Return the shape of the tensor corresponding to node. If the node has a
tensor spec, return the shape from the metadata. If the node is a param,
return its shape. Otherwise return None.
"""
try:
# Case 1. node is a scalar (this pass happens before tensorization)
if isinstance(node, (float, int, bool)):
return torch.Size([1])
# Case 2. node has TensorSpec metadata
fake_tensor = node.meta.get("val")
if fake_tensor is not None:
return fake_tensor.shape
# Case 3. node holds a param
if node.op == "get_attr":
attr_node = getattr(graph_module, node.target)
return attr_node.shape
# Default: return None
return None
except RuntimeError:
return None


# Return true if shape_2 can be broadcast to shape_1
def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
"""
Check if 'shape_2' can be broadcast to 'shape_1'. The broadcast is
feasible if either:
(1) shape_1 or shape_2 is empty; or
(2) shape_2 does not have higher dimensionality than shape_1, and, aligning
the shapes from the trailing dimension, each dimension of shape_2 is either
the same as the corresponding dimension of shape_1 or is 1.
"""
return (
not shape_1
or not shape_2
or all(
x == y or y == 1 or y is None
for x, y in zip_longest(shape_1[::-1], shape_2[::-1])
)
)
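
# Illustrative usage sketch (values are hypothetical, not from this diff):
#   broadcastable([4, 3, 2], [3, 1])  -> True   (trailing dims are equal or 1)
#   broadcastable([4, 3, 2], [4, 3])  -> False  (alignment is from the trailing
#                                                dimension, so 2 vs 3 fails)
#   broadcastable([], [5])            -> True   (an empty shape is always
#                                                broadcastable)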


# Return a chain of nodes with target in op_targets
def get_cascaded_ops(
nodes: List[torch.fx.Node],
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
op_targets: Iterable[Union[Callable[..., Any], str]],
) -> Sequence[torch.fx.Node]:
"""
'nodes' contains a chain of ops with target in 'op_targets'. Extend that chain
by one if nodes[-1] has a single user with its op target in 'op_targets'.
"""
cur = nodes[-1]
users = list(cur.users.keys())
# Check that (a) there is only one user of cur, and (b) that user's
# target is one of the ops in op_targets.
if len(users) == 1 and users[0].target in op_targets:
nodes.append(users[0])
# Recursively find the chain starting at the user
return get_cascaded_ops(nodes, op_targets)

return nodes
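
# Hedged usage sketch (node names are hypothetical): starting from a
# single-element chain [quant_node], a call such as
#   get_cascaded_ops([quant_node], {exir_ops.edge.aten.relu.default})
# keeps extending the chain as long as the last node has exactly one user
# whose target is in op_targets, and returns the accumulated list of nodes.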


# Capture the effect of transpose op on incoming dimension order
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
"""
Given a transpose node, and the incoming dimension ordering of the input
tensor to the transpose node, return the net effect of transpose op on the
dimension order.
"""
assert node.target == exir_ops.edge.aten.transpose_copy.int
# Assert that dims is not None
assert dims is not None
dim_len = len(dims)
# Get dim0 and dim1 from the transpose op args
transpose_dims0 = node.args[1]
transpose_dims1 = node.args[2]
assert isinstance(transpose_dims0, int)
assert isinstance(transpose_dims1, int)
dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
# Perform transpose on the dimension ordering (dims)
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
return dims
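
# Worked example (illustrative, assuming a transpose_copy node created with
# dims 0 and 2): an incoming ordering [0, 1, 2, 3] becomes [2, 1, 0, 3],
# i.e. the entries at the two transposed positions are swapped in place.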


# Capture the effect of permute op on incoming dimension order
def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[int]:
"""
Given a permute node, and the incoming dimension ordering of the input
tensor to the permute node, return the net effect of permute op on the
dimension order.
"""
assert node.target == exir_ops.edge.aten.permute_copy.default
# Permute each index of the dimension ordering (dims)
permute_dims = node.args[1]
assert isinstance(permute_dims, List)
assert all(isinstance(x, int) for x in permute_dims)
# If the dims is empty, we can simply return the permute order
if not dims:
return permute_dims
dims = [dims[x] for x in permute_dims]
return dims
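
# Worked example (illustrative, assuming a permute_copy node with permutation
# [0, 2, 3, 1]): an incoming ordering [3, 2, 1, 0] is reindexed to
# [dims[0], dims[2], dims[3], dims[1]] = [3, 1, 0, 2]; with an empty incoming
# ordering, the permutation [0, 2, 3, 1] itself is returned.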


# Return the tensor of buffer/parameter op
def get_tensor_from_attr(
graph_module: torch.fx.GraphModule, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
"""
For an input node that is a named buffer or parameter, return
the underlying tensor.
"""
if node is None:
return None
assert node.op == "get_attr"
return getattr(graph_module, node.target)


def is_node_with_op(node: torch.fx.Node, op: str) -> bool:
"""
Return true if the incoming node has the given op type
"""
return node.op == op


def count_users_with_target_op_type(
nodes: Iterable[torch.fx.Node],
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
op_target: Union[Callable[..., Any], str],
) -> int:
"""
Given a set of nodes and a node target type `op_target`, iterate over all
the users of those nodes, and return the total number of users with target
op_target.
"""

def contributions_per_node(
node: torch.fx.Node,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
op_target: Union[Callable[..., Any], str],
) -> int:
return [use.target for use in node.users if use.op == "call_function"].count(
op_target
)

return sum([contributions_per_node(node, op_target) for node in nodes])


def contains_node_with_matching_target(
nodes: Iterable[torch.fx.Node],
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
op_target: Union[Callable[..., Any], str],
) -> bool:
"""
Given a list of nodes, return true if any node in the list has target
'op_target'.
"""
return any(node.target == op_target for node in nodes)


def is_quantized_tensor(x: torch.Tensor) -> bool:
"""
Return true if the tensor x is quantized
"""
return x.is_quantized


def get_scale(x: torch.Tensor) -> torch.Tensor:
"""
Return the scale of a quantized tensor as a float32 tensor.
"""
return (
x.q_per_channel_scales().to(torch.float32)
if x.qscheme() == torch.per_channel_affine
else torch.tensor([x.q_scale()], dtype=torch.float32)
)


def get_zero_point(x: torch.Tensor, reduce: bool = True) -> torch.Tensor:
"""
Return the zero point of a quantized tensor as an int32 tensor.
"""
# If x was quantized per-tensor, simply create a tensor out of the scalar
# zero_point, and return it.
if x.qscheme() == torch.per_tensor_affine:
return torch.tensor([x.q_zero_point()], dtype=torch.int32)
# If x was quantized per-channel, check if the zero_point is all zeros. If
# so, then we can compress the zero_point tensor to a scalar.
assert x.qscheme() == torch.per_channel_affine, "Unhandled quantization scheme"
zero_point = x.q_per_channel_zero_points().to(torch.int32)
return (
torch.tensor([zero_point[0]], dtype=torch.int32)
if reduce and all(zero_point == zero_point[0])
else zero_point
)
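
# Hedged usage sketch (illustrative): for a per-channel quantized tensor whose
# zero points are all 0, get_zero_point(x) collapses the result to
# tensor([0], dtype=torch.int32); with reduce=False the full per-channel
# vector is returned unchanged.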


def quantize_tensor_multiplier(
requantize_scale_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Given requantize_scale_tensor with values in the interval (0, 1),
produce a pair of tensors (out_multiplier, right_shift) where out_multiplier
is an int32 tensor representing fixed-point values in the interval [-1, 1),
and right_shift is an amount to shift right by, so that the floating-point
multiplication of some int32 input with each value of requantize_scale_tensor:
result = int32_value * requantize_scale_tensor[i]
is best approximated by the integer-arithmetic-only code:
result = RoundingRightShift(FixedPointMultiplication(int32_value,
out_multiplier[i]), right_shift[i])
"""

# This is identical to C++11 std::round(). Python's built-in round() rounds
# halfway cases to the nearest even value, whereas C++ rounds them away
# from zero.
# pyre-fixme[2]: Parameter must be annotated.
def round_away_zero(f) -> int:
r = -0.5 if (f < 0) else 0.5
return trunc(f + r)

def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]:
significand, exponent = frexp(requantize_scale)
significand_q31 = int(round_away_zero(significand * (1 << 31)))
# Handle the special case when the real multiplier was so close to 1
# that its fixed-point approximation was indistinguishable from 1.
# We handle this by dividing it by two and incrementing the exponent
# (the right shift amount) by 1.
if significand_q31 == (1 << 31):
significand_q31 //= 2
exponent += 1

# Verify that the decomposition of requantize_scale into significand
# and exponent is correct.
reconstructed = significand_q31 / (1 << 31) * pow(2, exponent)
assert isclose(
requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4
), "computation of significand and exponent from requantize_scale is not accurate"

return (significand_q31, exponent)

# Flatten the input scale tensor so that we can operate on individual values
orig_shape = requantize_scale_tensor.shape
flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32)
out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32)
right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32)

# Iterate over the flattened scale tensor and compute the decomposition of
# each value in scale tensor into significand(out_multiplier) and
# exponent(right_shift)
for idx, scale in enumerate(flattened_tensor):
(si, ex) = quantize_scalar_multiplier(scale)
out_multiplier[idx], right_shift[idx] = si, ex

# Reshape the tensors back to the original shape
out_multiplier = out_multiplier.reshape(orig_shape)
right_shift = right_shift.reshape(orig_shape)

return (out_multiplier, right_shift)
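
# Worked example (illustrative values): for a requantize scale of 0.25, frexp
# yields significand 0.5 and exponent -1, so out_multiplier is
# round(0.5 * 2**31) = 1073741824 and the returned shift value is -1; the
# reconstruction check 1073741824 / 2**31 * 2**-1 == 0.25 then passes within
# the given tolerance.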