36 changes: 17 additions & 19 deletions exir/passes/replace_broken_ops_with_function_ops_pass.py
@@ -5,26 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-from typing import Dict
-
 import torch
 
 from executorch.exir.pass_base import ExportPass
 
-from torch._ops import OpOverload
-
-
-_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
-    torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
-    torch.ops.aten.t.default: torch.ops.aten.t_copy.default,
-    torch.ops.aten.view.default: torch.ops.aten.view_copy.default,
-    torch.ops.aten.expand.default: torch.ops.aten.expand_copy.default,
-    torch.ops.aten.permute.default: torch.ops.aten.permute_copy.default,
-    torch.ops.aten.squeeze.default: torch.ops.aten.squeeze_copy.default,
-    torch.ops.aten.unsqueeze.default: torch.ops.aten.unsqueeze_copy.default,
-    torch.ops.aten.slice.Tensor: torch.ops.aten.slice_copy.Tensor,
-}
-
 
 class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
     """
@@ -37,8 +21,22 @@ class ReplaceBrokenOpsWithFunctionalOpsPass(ExportPass):
 
     # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
-        if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
-            return super().call_operator(
-                _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op], args, kwargs, meta
+        if op.is_view:
+            namespace, op_full_name = op.name().split("::")
+            split = op_full_name.split(".")
+            if len(split) == 2:
+                op_name, overload_name = split[0], split[1]
+            elif len(split) == 1:
+                # Add default overload if no overload listed
+                op_name = op_full_name
+                overload_name = "default"
+            else:
+                raise RuntimeError(
+                    f"Invalid op name expected only one '.' to be present: {op_full_name}"
+                )
+
+            view_copy_op = getattr(
+                getattr(getattr(torch.ops, namespace), f"{op_name}_copy"), overload_name
             )
+            return super().call_operator(view_copy_op, args, kwargs, meta)
         return super().call_operator(op, args, kwargs, meta)
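Note (editor's aside): instead of consulting a hardcoded table, the pass now derives the functional variant from the op's qualified name whenever the op's schema marks it as a view. Below is a minimal, runnable sketch of that lookup; the choice of aten.permute as the example op is an illustrative assumption, not something the PR singles out.

import torch

# Pick any view op; OpOverload.is_view is derived from the op's schema.
op = torch.ops.aten.permute.default
assert op.is_view
namespace, full_name = op.name().split("::")  # -> "aten", "permute"
parts = full_name.split(".")
op_name = parts[0]
# The default overload carries no ".overload" suffix in the qualified name.
overload_name = parts[1] if len(parts) == 2 else "default"
# Resolve the functional counterpart via the aten "<name>_copy" convention.
view_copy_op = getattr(getattr(torch.ops, namespace), f"{op_name}_copy")
view_copy_op = getattr(view_copy_op, overload_name)
print(view_copy_op)  # aten.permute_copy.default

This works because every aten view op has a functional counterpart that follows the <name>_copy naming convention, so a string-based getattr lookup can resolve the matching overload without enumerating ops one by one.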
79 changes: 62 additions & 17 deletions exir/tests/test_passes.py
@@ -595,33 +595,78 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
         self.assertEqual(counter, 1)
 
     def test_compile_fix_broken_ops(self) -> None:
-        # When pass an input of more than 4 dimensions to Linear
-        # aten._unsafe_view is used under the hood
-        x = torch.randn([2, 3, 4, 5])
-        model: torch.nn.Linear = torch.nn.Linear(5, 5)
-
-        class Foo(torch.nn.Module):
-            def __init__(self):
+        class ExportableLoop(nn.Module):
+            def __init__(self, hidden_size, out_channels):
                 super().__init__()
-                self.model = model
-
-            def forward(self, inp: torch.Tensor) -> torch.Tensor:
-                return self.model(inp)
-
-        f = Foo()
+                self.hidden_size = hidden_size
+                self.B = nn.Parameter(torch.randn(hidden_size, 1))  # (H, in_channels)
+                self.C = nn.Parameter(
+                    torch.randn(out_channels, hidden_size)
+                )  # (C_out, H)
+                A = torch.randn(2, hidden_size)
+                self.A_real = nn.Parameter(A[0].clone())
+                self.A_imag = nn.Parameter(A[1].clone())
+
+            def update_state(self, h, x_t):
+                # h: [B, 2, H], x_t: [B, H]
+                hr, hi = h[:, 0, :], h[:, 1, :]  # [B, H]
+                hrn = hr * self.A_real - hi * self.A_imag + x_t  # [B, H]
+                hin = hi * self.A_real + hr * self.A_imag  # [B, H]
+                hn = torch.stack([hrn, hin], dim=1)  # [B, 2, H]
+                return hn, hrn
+
+            def forward(self, u):
+                # u: [B, 1, T]
+                x = torch.matmul(self.B, u)  # (B, H, T)
+                B, H, T = x.shape
+
+                h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype)  # [B, 2, H]
+                h_accum = torch.zeros(
+                    B, H, T, device=x.device, dtype=x.dtype
+                )  # [B, H, T]
+                i = torch.tensor(0, device=x.device, dtype=torch.int64)
+                one = torch.tensor(1, device=x.device, dtype=torch.int64)
+
+                def cond(i, h, h_accum):
+                    return i < T
+
+                def body(i, h, h_accum):
+                    x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(
+                        -1
+                    )  # ✅ safe for export
+                    h, hr = self.update_state(h, x_t)  # h: [B, 2, H], hr: [B, H]
+                    h_accum = h_accum.index_copy(
+                        -1, i.unsqueeze(0), hr.unsqueeze(-1)
+                    )  # [B, H, T]
+                    i_next = i + one
+                    return i_next, h, h_accum
+
+                _, h, h_accum = torch._higher_order_ops.while_loop(
+                    cond, body, (i, h, h_accum)
+                )
+                y = torch.matmul(self.C, h_accum).transpose(0, 1)  # (B, C_out, T)
+                return y
 
-        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
+        # Instantiate and export
+        model = ExportableLoop(hidden_size=128, out_channels=10)
+        inp = torch.randn(1, 1, 32)  # (B, in_channels=1, T=32)
+        ep = export(model, (inp,))
         prog = to_edge(
-            export(f, (x,), strict=True),
+            ep,
             compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
         )
         gm = prog.exported_program().graph_module
         count_after = 0
         for node in gm.graph.nodes:
-            if node.target == torch.ops.aten._unsafe_view.default:
+            if (
+                node.target == torch.ops.aten.squeeze.dims
+                or node.target == torch.ops.aten.select.int
+            ):
                 count_after += 1
         self.assertEqual(count_after, 0)
-        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))
+        self.assertTrue(
+            torch.allclose(prog.exported_program().module()(inp), model(inp))
+        )
 
     def test_convert_symb_ops(self) -> None:
         class Foo(torch.nn.Module):
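Note (editor's aside): the rewritten test drives the pass through torch._higher_order_ops.while_loop, which torch.export preserves as a genuine loop rather than unrolling, so the view ops created inside the loop body (squeeze, select) must be functionalized by the pass. A minimal standalone sketch of while_loop's semantics, separate from the PR's test:

import torch
from torch._higher_order_ops import while_loop

def cond(i, acc):
    # Must return a scalar boolean tensor.
    return i < 5

def body(i, acc):
    # Must return carried values with unchanged shapes and dtypes.
    return i + 1, acc + i

i0 = torch.tensor(0)
acc0 = torch.tensor(0)
i_final, acc_final = while_loop(cond, body, (i0, acc0))
print(acc_final)  # tensor(10) == 0 + 1 + 2 + 3 + 4

The carried state is a tuple of tensors threaded through every iteration, which is why the test passes i, h, and h_accum through cond and body instead of closing over mutable state.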