support qnn mean (dim=None) #14675
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/14675
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
        self.lower_module_and_test_output(module, sample_input)

    def test_qnn_backend_mean(self):
        modules = [Mean(), Mean()]  # noqa: F405
Do you need more configurations here, like Mean(dim=0) or Mean(dim=0, keepdim=True)?
updated
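The suggested configurations could look like the following. This is a hypothetical sketch of the Mean test wrapper; the actual module lives in the ExecuTorch QNN test models and may differ:

```python
import torch


# Hypothetical Mean wrapper, assumed to forward to torch.mean.
# dim=None exercises the "reduce across all axes" path this PR fixes.
class Mean(torch.nn.Module):
    def __init__(self, dim=None, keepdim=False):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        if self.dim is None:
            # No dim: full reduction to a 0-d tensor.
            return torch.mean(x)
        return torch.mean(x, dim=self.dim, keepdim=self.keepdim)


modules = [Mean(), Mean(dim=0), Mean(dim=0, keepdim=True)]
```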
Summary: Address the mean op lowering failure. When dim is not specified, the mean is taken across all axes. For QNN, we need to derive the axes from the input shape.
Differential Revision: D83520776
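The axis derivation described in the summary can be sketched as follows. This is a minimal illustration, not the actual ExecuTorch op builder code:

```python
def mean_axes(input_shape, dim=None):
    """Resolve the axes a mean reduces over.

    aten.mean with dim=None reduces across all axes, but QNN's
    ReduceMean expects the axes listed explicitly, so we derive
    them from the input rank.
    """
    if dim is None:
        # Reduce over every axis of the input.
        return list(range(len(input_shape)))
    # Normalize a single int or a sequence of ints to a list.
    return [dim] if isinstance(dim, int) else list(dim)
```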
Force-pushed from 4e25030 to 6948323 (Compare)
Force-pushed from 6948323 to c50ec69 (Compare)
Force-pushed from c50ec69 to 2b6dc7f (Compare)
Does it look good? I think it can fix 5 failing op tests.
Force-pushed from 2b6dc7f to b39d9fb (Compare)
        # Scalar case
        {
            QCOM_MODULE: Mean(),
            QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
Should it be torch.tensor(5.0) if you want to test a scalar?
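For reference, the two spellings produce tensors of different rank (standard PyTorch semantics):

```python
import torch

# torch.tensor(5.0) is a rank-0 (scalar) tensor; torch.tensor([5.0])
# is a rank-1 tensor of shape (1,). Only the former exercises the
# 0-d input path.
scalar = torch.tensor(5.0)
vector = torch.tensor([5.0])

assert scalar.dim() == 0 and scalar.shape == torch.Size([])
assert vector.dim() == 1 and vector.shape == (1,)
```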
Thanks, good catch
Actually, this is a good test. ReduceMean doesn't support 0-d tensors, so I think we need a pass that converts a 0-d input to a 1-d tensor when the input of mean is 0-d. I have some sketch code below. What do you think? For now, I've commented out the test case; we can follow up on this.
import torch
from executorch.exir.pass_base import ExportPass, PassResult

class Rank0ToRank1(ExportPass):
    """
    For selected ops and selected input positions, if the input is rank-0
    (a scalar tensor), insert a reshape to [1] before the op.
    """
    def __init__(self, op_input_map=None) -> None:
        super().__init__()
        # Key is the op; value is the list of input indices to be reshaped.
        self.op_input_map = op_input_map or {
            torch.ops.aten.mean.dim: [0],
        }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        changed = False
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.op_input_map:
                continue
            input_indices = self.op_input_map[node.target]
            new_args = list(node.args)
            node_changed = False
            for idx in input_indices:
                if idx >= len(new_args):
                    continue
                inp = new_args[idx]
                if not hasattr(inp, "meta"):
                    continue
                val = inp.meta.get("val", None)
                if val is not None and hasattr(val, "shape") and val.shape == ():
                    # Insert a reshape to rank 1 right before the op.
                    with graph.inserting_before(node):
                        reshape_node = graph.call_function(
                            torch.ops.aten.reshape.default,
                            args=(inp, (1,)),
                        )
                        reshape_node.meta["val"] = val.reshape(1)
                    # Point the op at the reshaped input instead.
                    new_args[idx] = reshape_node
                    node_changed = True
            # Update node args only if we actually modified them.
            if node_changed:
                node.args = tuple(new_args)
                changed = True
        if changed:
            graph_module.recompile()
        return PassResult(graph_module, changed)
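The effect of the inserted reshape can be sanity-checked in eager mode. This is a minimal sketch using plain torch ops, not the exported graph the pass actually operates on:

```python
import torch

# A 0-d tensor, the case ReduceMean cannot handle directly.
x = torch.tensor(5.0)
assert x.shape == torch.Size([])

# The reshape the pass would insert: 0-d -> 1-d of shape (1,).
y = torch.ops.aten.reshape.default(x, (1,))
assert y.shape == (1,)

# With rank 1, mean over an explicit axis is now well-defined.
result = torch.mean(y, dim=0)
assert result.item() == 5.0
```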
Force-pushed from b39d9fb to cd343ee (Compare)
Force-pushed from cd343ee to 70f1009 (Compare)
Force-pushed from 70f1009 to 255c5c9 (Compare)
Force-pushed from 255c5c9 to 1ffcbd9 (Compare)
Force-pushed from 1ffcbd9 to 0a200df (Compare)
@pytorchbot cherry-pick --onto release/1.0 -c regression
Summary: Address the mean op lowering failure. When dim is not specified, the mean is taken across all axes. For QNN, we need to derive the axes from the input shape.
Differential Revision: D83520776
(cherry picked from commit 9ab5592)
Cherry picking #14675: the cherry pick PR is at #14755. It is recommended to link a regression cherry pick PR with an issue. Raised by workflow job.