-
Notifications
You must be signed in to change notification settings - Fork 748
Add fuse batchnorm pass #8028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add fuse batchnorm pass #8028
Conversation
ExportPasses that need to be initated with an ExportedProgram can currently not be tested in a convenient way. This patch subclasses RunPasses, adding a parameter, `passes_with_exported_program`, that can be used just like `pass_list` but are initiated with an exported program before they are run. The functionality is tested in new tests for the CastInt64Pass and InsertTableOpsPass Signed-off-by: Erik Lundell <[email protected]> Change-Id: I1712d86abe7cc3672c343db568df1264c0b9133e
The pass differs from existing fuse passes since they use the get_attr node which is not supported by ArmBackend. Instead, we update the existing parameters. Also adds tests. Change-Id: Iad6d70e632191d74d96df62b1837d37fe60e7d3a
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8028
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 Cancelled Jobs, 1 PendingAs of commit d3c6a60 with merge base c5fea7e ( NEW FAILURE - The following job has failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
macos unrealted |
digantdesai
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dang missed the train on this one :p
| self.add_pass(ConvertMeanDimToAveragePoolPass()) | ||
| self.add_pass(DecomposeDivPass()) | ||
| self.add_pass(DecomposeSoftmaxesPass()) | ||
| self.add_pass(FuseBatchnorm2DPass(exported_program)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only for MI?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is already done in the BI case, in prepare_pt2e I believe
| from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses | ||
|
|
||
|
|
||
| class Int64Model(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what makes it int64?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A scalar always becomes int64
| z = self.conv2d2(x) | ||
| a = self.batch_norm2d( | ||
| y | ||
| ) # Can't be fused since paramters of conv2d2 have multiple users. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this a real constraint? i.e.
y = conv2d2(x1)
a = bn2d(y)
return a
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we update the weights of the convolution, it will produce different output. The new output will only be correct for the user that had the bn.
The pass differs from existing fuse passes since they
use the get_attr node which is not supported by ArmBackend.
Instead, we update the existing parameters.
Also adds tests.
To test the batchnorm pass, I extended RunPasses, adding a parameter,
passes_with_exported_program, that can be used justlike
pass_listbut are initiated with anexported program before they are run.
The functionality is tested in new tests for the
CastInt64Pass and InsertTableOpsPass
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218