Support for squeeze and select operators in ExecuTorch #12103

@vikasbalaga

🚀 The feature, motivation and pitch

Hi,

I have been working on a PyTorch model based on a state-space representation, and I am trying to export it to ExecuTorch and run inference.

During export(), I ran into issues caused by a Python for loop in my forward(): the loop was unrolled, which resulted in an abnormally large model.

To avoid that, I switched to the PyTorch while_loop operator and, with the help of @angelayi, was able to convert the model to Edge IR.

After the Edge IR conversion, I am hitting the following error:

WARNING:root:Schema: aten::select.t(t[](a) list, int idx) -> t(*) cannot be converted to torch.FunctionSchema
WARNING:root:Schema: aten::select.t(t[](a) list, int idx) -> t(*) cannot be converted to torch.FunctionSchema
RuntimeError: Missing out variants: {'aten::squeeze', 'aten::select'}

I am not sure whether these operators are supported in ExecuTorch. I have tried many workarounds, but all of them lead to the same error.

Could someone please help me identify the problem? Am I missing something, or is support for {'aten::squeeze', 'aten::select'} not present in ExecuTorch?

If support is not present, could someone help me resolve this?

The sample code below reproduces the issue.

import torch
from torch import nn
from torch.export import export
from executorch.exir import to_edge_transform_and_lower

class ExportableLoop(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.hidden_size = hidden_size
        self.B = nn.Parameter(torch.randn(hidden_size, 1))  # (H, in_channels)
        self.C = nn.Parameter(torch.randn(out_channels, hidden_size))  # (C_out, H)
        A = torch.randn(2, hidden_size)
        self.A_real = nn.Parameter(A[0].clone())
        self.A_imag = nn.Parameter(A[1].clone())

    def update_state(self, h, x_t):
        # h: [B, 2, H], x_t: [B, H]
        hr, hi = h[:, 0, :], h[:, 1, :]  # [B, H]
        hrn = hr * self.A_real - hi * self.A_imag + x_t  # [B, H]
        hin = hi * self.A_real + hr * self.A_imag        # [B, H]
        hn = torch.stack([hrn, hin], dim=1)              # [B, 2, H]
        return hn, hrn

    def forward(self, u):
        # u: [B, 1, T]
        x = torch.matmul(self.B, u)  # (B, H, T)
        B, H, T = x.shape

        h = torch.zeros(B, 2, H, device=x.device, dtype=x.dtype)  # [B, 2, H]
        h_accum = torch.zeros(B, H, T, device=x.device, dtype=x.dtype)  # [B, H, T]
        i = torch.tensor(0, device=x.device, dtype=torch.int64)
        one = torch.tensor(1, device=x.device, dtype=torch.int64)

        def cond(i, h, h_accum):
            return i < T

        def body(i, h, h_accum):
            x_t = x.index_select(-1, i.unsqueeze(0)).squeeze(-1)  # ✅ safe for export
            h, hr = self.update_state(h, x_t)  # h: [B, 2, H], hr: [B, H]
            h_accum = h_accum.index_copy(-1, i.unsqueeze(0), hr.unsqueeze(-1))  # [B, H, T]
            i_next = i + one
            return i_next, h, h_accum

        _, h, h_accum = torch._higher_order_ops.while_loop(cond, body, (i, h, h_accum))
        y = torch.matmul(self.C, h_accum).transpose(0, 1)  # matmul: (B, C_out, T); after transpose(0, 1): (C_out, B, T)
        return y

# Instantiate and export
model = ExportableLoop(hidden_size=128, out_channels=10)
inp = torch.randn(1, 1, 32)  # (B, in_channels=1, T=32)
exported = export(model, (inp,))
print("Exporting Done...")
executorch_program = to_edge_transform_and_lower(
    exported
)
print("Edge transform Done...")

executorch_program = executorch_program.to_executorch()  # Added this line

out_path = "loop_model.pte"
with open(out_path, "wb") as file:
    file.write(executorch_program.buffer)

print(f"Successfully saved model as {out_path}")
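
For readers following the repro, the update_state step above reads as a complex multiply-add, h_next = h * A + x_t, with the real and imaginary parts stored separately. The sketch below checks that reading in plain Python for a single hidden unit. The scalar helper names and the sample values are illustrative assumptions, not part of the original code.

```python
def update_state(hr, hi, a_real, a_imag, x_t):
    # Scalar version of the repro's update_state for a single hidden unit.
    hrn = hr * a_real - hi * a_imag + x_t
    hin = hi * a_real + hr * a_imag
    return hrn, hin


def update_state_complex(h, a, x_t):
    # The same step written as one complex multiply-add.
    return h * a + x_t


# Illustrative values only.
hr, hi = 0.5, -1.25
a_real, a_imag = 0.9, 0.1
x_t = 2.0

split = update_state(hr, hi, a_real, a_imag, x_t)
fused = update_state_complex(complex(hr, hi), complex(a_real, a_imag), x_t)
print(split, (fused.real, fused.imag))
```

Only the real part (hrn) is written into h_accum in the repro, which matches taking the real component of the complex state at each step.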

Note:
Support for while_loop was recently added as part of PR #12062.
For more information, see this comment on ticket #8769.

Alternatives

No response

Additional context

No response

RFC (Optional)

No response

cc @JacobSzwejbka @angelayi

Labels

module: exir (Issues related to Export IR and the code under exir/)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
