-
Notifications
You must be signed in to change notification settings - Fork 742
Fix scalar arithemetic and add test cases #6224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6224
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 55fd7ab with merge base 8673567 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
||
| # If the argument is call_function, match shape by inserting view node. | ||
| if arg.op == "call_function": | ||
| self._match_op_shape(graph_module, node, arg, rank, max_rank, shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can't add/sub handle broadcasting in TOSA? This view will become a copy and I am not sure if we can do faster broadcasting than this. scalr to 1D makes sense I am talking about matching shape. Ok with following up on this later if you want to change that is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TOSA handles broadcasting but still requires that the input tensors share the same rank (see https://www.mlplatform.org/tosa/tosa_spec.html#_add). So the pass is actually badly named, should be match_rank. will update.
| super().__init__() | ||
| self.exported_program = exported_program | ||
|
|
||
| targeted_ops = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
other ops can't take scalar? curious why only these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pass might be relevant for other ops but the reasoning is to only target what we know we need and what we test. I.e. to limit the scope of the change.
|
@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
|
Internal errors executorch/backends/arm/_passes/match_shapes_pass.py:80:4 Inconsistent override [15]: executorch/backends/arm/_passes/match_shapes_pass.py:113:8 Incompatible return type [7]: Expected |
Add UnsquezeScalarPlaceholders pass to make scalars rank 1 Add MatchArgRanksPass to guarantee same rank for all inputs for ops that require it. Additional fixes to make Scalar tests pass Map which cases work and which don't. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I4ea5e189e26cf7aff391ec153d525b2fb61aa16f Fix shape issues Change-Id: I0b8588cd5f8b284c25e806bb83bc788067d5b649
9513de0 to
55fd7ab
Compare
|
@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
|
@digantdesai merged this pull request in 6669e18. |
Add UnsquezeScalarPlaceholders pass to make scalars rank 1 Add MatchShapesPass to guarantee same rank for all inputs for ops that require it.
Additional fixes to make Scalar tests pass
Map which cases work and which don't.
Signed-off-by: Erik Lundell [email protected]
Change-Id: I4ea5e189e26cf7aff391ec153d525b2fb61aa16f
Fix shape issues
Change-Id: I0b8588cd5f8b284c25e806bb83bc788067d5b649