
Commit 0edba55

khabinov authored and Wei Wei committed
[fx2trt] Handle shapes like [batch_size] and scalars for binary ops properly (#74)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/74

## Root cause

We have code like:

```
x = ...                            # result shape is [batch_size, N]
y = mean(x, dim=1, keepdim=False)  # result shape is [batch_size]
z = y + 0.5                        # result shape is [batch_size]
```

For TRT with implicit batch dimension it should look like:

```
x = ...                            # result shape is [N]
y = mean(x, dim=1, keepdim=False)  # result shape is []
z = y + 0.5                        # result shape is []
```

However, because we convert the scalar to a `TRTTensor` and don't squeeze its dimensions, the resulting tensor `z` would have shape `[1]`, and this breaks the rest of the net.

## Solution

Convert the scalar value to a `torch.Tensor` instead, because we already have dimension-squeeze logic implemented for those.

## P.S.

Also added support for `sqrt` tracing.

Reviewed By: yinghai, houseroad

Differential Revision: D36336816

fbshipit-source-id: 412e44e99f25ab3549df540a87bd005e6b3fe08a
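The eager-mode shapes described in the root cause can be checked with plain PyTorch (a minimal sketch; the `batch_size` and `N` values are arbitrary):

```python
import torch

batch_size, N = 8, 4
x = torch.randn(batch_size, N)           # shape [batch_size, N]
y = torch.mean(x, dim=1, keepdim=False)  # keepdim=False drops dim 1 -> shape [batch_size]
z = y + 0.5                              # scalar broadcast keeps shape [batch_size]

assert y.shape == (batch_size,)
assert z.shape == (batch_size,)
```

Under TRT's implicit batch mode the leading `batch_size` dimension is stripped from every shape above, which is why `y` and `z` become 0-d there.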
1 parent 5d80f41 commit 0edba55

File tree

2 files changed: +16 −0 lines changed

fx/converters/converter_utils.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -452,6 +452,21 @@ def add_binary_elementwise_layer(
         )
         return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)
 
+    # If the following conditions are true:
+    # 1. the network has implicit batch dimension,
+    # 2. one operand has shape [] (real shape is [batch_size]),
+    # 3. another operand is a scalar,
+    # then the result should also have shape [] (real shape is [batch_size]).
+    #
+    # In such case, we need to convert the scalar operand to tensor, because
+    # this way the shape will become [1], and then will be properly squeezed
+    # into [], meaning that the result will have shape [], which is what we
+    # expect.
+    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
+        rhs_val = torch.tensor([rhs_val], dtype=dtype)
+    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
+        lhs_val = torch.tensor([lhs_val], dtype=dtype)
 
     # When lhs is scalar, and rhs has shape [1,], then currently the assert
     # will fail because lhs shape has fewer dimensions than rhs shape. This
     # happens when using implicit batch dimension, when we removed the 1st
```
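The squeeze behavior that the added comment relies on can be illustrated in eager PyTorch (a sketch; in fx2trt the actual squeeze happens when the `torch.Tensor` is lowered to a TRT constant):

```python
import torch

# A Python scalar wrapped as a 1-element tensor has shape [1] ...
rhs_val = torch.tensor([0.5], dtype=torch.float32)
assert rhs_val.shape == (1,)

# ... and squeezing that dimension yields a 0-d tensor, shape [].
# That matches the shape of a [batch_size] operand once the implicit
# batch dimension is stripped, so the elementwise result stays 0-d
# instead of picking up a spurious [1] dimension.
squeezed = rhs_val.squeeze(0)
assert squeezed.shape == ()
```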

tracer/acc_tracer/acc_ops.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1482,6 +1482,7 @@ def log(*, input):
 
 
 @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
 @register_acc_op_mapping(op_and_target=("call_function", torch.sqrt))
+@register_acc_op_mapping(op_and_target=("call_method", "sqrt"))
 @register_acc_op
 def sqrt(*, input):
     return torch.sqrt(input=input)
```
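The new `call_method` mapping lets the tracer handle the method form `x.sqrt()` in addition to the function form `torch.sqrt(x)`. Both compute the same thing in eager mode (a minimal sketch; the module `M` is a hypothetical example, not part of the commit):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # Method-call form of sqrt; with this commit, the tracer maps
        # ("call_method", "sqrt") nodes to acc_ops.sqrt, same as torch.sqrt.
        return x.sqrt()

m = M()
x = torch.tensor([4.0, 9.0])
assert torch.equal(m(x), torch.sqrt(x))  # both give tensor([2., 3.])
```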
