You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[fx2trt] Handle shapes like [batch_size] and scalars for binary ops properly (#74)
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/74
## Root cause
We have code like:
```
x = ... # result shape is [batch_size, N]
y = mean(y, dim=1, keepdim=False) # result shape is [batch_size]
z = y + 0.5 # result shape is [batch_size]
```
For TRT with implicit batch dimension it should look like:
```
x = ... # result shape is [N]
y = mean(y, dim=1, keepdim=False) # result shape is []
z = y + 0.5 # result shape is []
```
However, because we convert scalar to `TRTTensor` and don't do dimensions squeeze for it, the resulting tensor `z` would have shape `[1]`, and this is gonna break the rest of the net.
## Solution
Convert the scalar value to `torch.Tensor`, because we have dimensions squeeze logic implemented for them.
## P.S.:
Also added support for `sqrt` tracing.
Reviewed By: yinghai, houseroad
Differential Revision: D36336816
fbshipit-source-id: 412e44e99f25ab3549df540a87bd005e6b3fe08a
0 commit comments