Support higher-order while_loop and scan torch ops #8769

@GregoryComer

🚀 The feature, motivation and pitch

PyTorch provides several higher-order control flow ops, including while_loop and scan. These ops have to be used explicitly in the PyTorch model (Dynamo can't infer them from Python control flow), but they are useful for expressing more complex models. As a first step toward handling complex patterns without chopping up the model and porting logic to C++, we should support some of the higher-order ops in ExecuTorch (ET).
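For readers unfamiliar with these ops, here is a minimal pure-Python sketch of the semantics I understand while_loop and scan to have. It is not the PyTorch implementation. The names `reference_while_loop` and `reference_scan` are hypothetical, and the argument shapes `(cond_fn, body_fn, carried_inputs)` and `(combine_fn, init, xs)` are assumptions modeled on the usual functional definitions of these ops.

```python
from typing import Any, Callable, Iterable, List, Tuple


def reference_while_loop(
    cond_fn: Callable[..., bool],
    body_fn: Callable[..., Tuple[Any, ...]],
    carried_inputs: Tuple[Any, ...],
) -> Tuple[Any, ...]:
    """Hypothetical reference: apply body_fn to the carried state until cond_fn is false."""
    state = tuple(carried_inputs)
    while cond_fn(*state):
        # body_fn must return a new state with the same structure as its inputs.
        state = tuple(body_fn(*state))
    return state


def reference_scan(
    combine_fn: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Iterable[Any],
) -> Tuple[Any, List[Any]]:
    """Hypothetical reference: thread a carry through xs, collecting one output per step."""
    carry = init
    ys: List[Any] = []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, ys


# while_loop: iterate i from 0 to 5 while accumulating the sum of i.
final_i, final_total = reference_while_loop(
    lambda i, total: i < 5,
    lambda i, total: (i + 1, total + i),
    (0, 0),
)

# scan: running (cumulative) sum over a sequence.
last, running = reference_scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
```

The point of the sketch is that both ops take the loop body as a function and make the loop-carried state explicit, which is what lets the loop be captured as a single graph node instead of being unrolled or traced away.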

I believe while_loop and scan should be a good starting point, as they are sufficient to express many data-dependent patterns. This may allow for writing some generator loops in pure PyTorch, especially for non-multi-turn generation.
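To illustrate the kind of generator loop I have in mind, here is a rough pure-Python sketch of a greedy decode loop written in while_loop style. Everything here is hypothetical: `reference_while_loop` is a stand-in for the real op, `toy_next_token` is a stand-in for a model forward pass, and `EOS`, `PAD`, and `MAX_LEN` are made-up constants. I am assuming the carried state would need a fixed shape, so the sketch uses a preallocated token buffer plus a position index rather than a growing list.

```python
EOS = 0       # hypothetical end-of-sequence token id
PAD = -1      # hypothetical filler for unused buffer slots
MAX_LEN = 8   # hypothetical fixed buffer size


def reference_while_loop(cond_fn, body_fn, carried_inputs):
    """Hypothetical stand-in for while_loop: run body_fn until cond_fn is false."""
    state = tuple(carried_inputs)
    while cond_fn(*state):
        state = tuple(body_fn(*state))
    return state


def toy_next_token(prev_token):
    """Stand-in for a model forward pass: count down toward EOS."""
    return max(prev_token - 1, EOS)


def not_done(buf, pos, done):
    # Keep generating until EOS was emitted or the buffer is full.
    return (not done) and pos < MAX_LEN


def step(buf, pos, done):
    nxt = toy_next_token(buf[pos - 1])
    # Write into the fixed-size buffer without changing its length.
    new_buf = buf[:pos] + [nxt] + buf[pos + 1:]
    return new_buf, pos + 1, nxt == EOS


prompt = [3]
init_buf = prompt + [PAD] * (MAX_LEN - len(prompt))
buf, pos, done = reference_while_loop(not_done, step, (init_buf, len(prompt), False))
generated = buf[:pos]
```

The data-dependent part is the stop condition: the number of iterations depends on when the model emits `EOS`, which is exactly what cannot be expressed with a statically unrolled loop.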

Currently, these ops are not supported on ET. It's likely that we will need both AOT and runtime work to do this. Exporting a simple cond model gives the following error. Simple scan and while_loop models also fail with various errors.

Edit: From Tarun's comment, cond should be supported, but it seems like there are some issues to resolve, as I haven't been able to use cond when I've tried it. I will create specific tasks for cond. I've removed it from this issue for now.

Alternatives

No response

Additional context

No response

RFC (Optional)

No response

cc @JacobSzwejbka @angelayi

Metadata

Assignees

No one assigned

    Labels

    module: exir - Issues related to Export IR and the code under exir/
    module: runtime - Issues related to the core runtime and code under runtime/
    triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

    Type

    No type

    Projects

    Status

    Backlog

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
