Commit 844cfd4

fix: batch norm issue encountered in RAFT (#3758)
1 parent 926d72c commit 844cfd4

4 files changed (+87 additions, -44 deletions)


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 27 additions & 0 deletions
@@ -127,6 +127,33 @@ def aten_ops_batch_norm_legit_no_training(
     )
 
 
+@dynamo_tensorrt_converter(
+    torch.ops.aten._native_batch_norm_legit.no_stats,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
+)
+def aten_ops_batch_norm_legit_no_stats(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.batch_norm(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        training=False,
+        momentum=args[4],
+        eps=args[5],
+        return_mean_rstd=True,
+    )
+
+
 @dynamo_tensorrt_converter(
     torch.ops.aten.native_layer_norm.default,
     supports_dynamic_shapes=True,
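
For context, the new converter registers aten._native_batch_norm_legit.no_stats, the overload PyTorch uses when batch norm runs without running statistics. The sketch below is not part of the commit; it shows one way such an op is assumed to reach the converter: a BatchNorm2d built with track_running_stats=False carries no running_mean/running_var buffers, so its normalization is expected to trace to the no_stats overload under the dynamo path. Module, channel count, and input shape are illustrative.

import torch
import torch_tensorrt


class NoStatsNorm(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # No running statistics are tracked, so normalization always uses batch
        # statistics; this is assumed to lower to aten._native_batch_norm_legit.no_stats.
        self.bn = torch.nn.BatchNorm2d(32, track_running_stats=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(x)


model = NoStatsNorm().eval().cuda()
inputs = [torch.randn(1, 32, 64, 64, device="cuda")]

# Compile through the dynamo frontend, which routes the op to the new converter.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)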

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 1 addition & 1 deletion
@@ -344,7 +344,7 @@ def to_trt_weights(
     count: Optional[int] = None,
 ) -> trt.Weights:
     """
-    Convert a PyTorch tensor or NumPy array to TensorRT weights.
+    Convert a PyTorch tensor to TensorRT weights.
 
     Args:
         value (Union[torch.Tensor, np.ndarray]): The tensor or array to convert to TRT weights

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 59 additions & 42 deletions
@@ -4,6 +4,7 @@
 import numpy as np
 import tensorrt as trt
 import torch
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -32,21 +33,22 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: trt.ITensor,
-    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    training: bool,
     momentum: float,
     eps: float,
-    cudnn_enabled: bool,
     return_mean_rstd: bool,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    training: bool = False,
+    cudnn_enabled: bool = False,
 ) -> Union[trt.ITensor, Tuple[trt.ITensor, torch.Tensor, torch.Tensor]]:
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
 
     # Save the original output shape for later use
     output_shape = input.shape
+    feature_num = output_shape[1]
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
@@ -59,26 +61,41 @@ def batch_norm(
         ]
     ):
         # We name the weight here according to the state_dict name
-        weight = (
-            get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
-            if weight is None
-            else get_trt_tensor(ctx, weight, f"{name}_weight")
-        )
-        bias = (
-            get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
-            if bias is None
-            else get_trt_tensor(ctx, bias, f"{name}_bias")
-        )
-        running_mean = (
-            get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
-            if running_mean is None
-            else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
-        )
-        running_var = (
-            get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
-            if running_var is None
-            else get_trt_tensor(ctx, running_var, f"{name}_running_var")
-        )
+        with unset_fake_temporarily():
+            weight = (
+                get_trt_tensor(
+                    ctx, torch.ones((feature_num,)), f"{name}_weight", dtype=input.dtype
+                )
+                if weight is None
+                else get_trt_tensor(ctx, weight, f"{name}_weight")
+            )
+            bias = (
+                get_trt_tensor(
+                    ctx, torch.zeros((feature_num,)), f"{name}_bias", dtype=input.dtype
+                )
+                if bias is None
+                else get_trt_tensor(ctx, bias, f"{name}_bias")
+            )
+            running_mean = (
+                get_trt_tensor(
+                    ctx,
+                    torch.zeros((feature_num,)),
+                    f"{name}_running_mean",
+                    dtype=input.dtype,
+                )
+                if running_mean is None
+                else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+            )
+            running_var = (
+                get_trt_tensor(
+                    ctx,
+                    torch.ones((feature_num,)),
+                    f"{name}_running_var",
+                    dtype=input.dtype,
+                )
+                if running_var is None
+                else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+            )
 
         # eps_tensor for numerical stability
         eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)
@@ -110,8 +127,7 @@ def batch_norm(
 
         # Reshape scale and bias_adjusted to match input shape for broadcasting
         expanded_shape = [1] * len(output_shape)
-        expanded_shape[1] = output_shape[1]  # Set channel dimension
-
+        expanded_shape[1] = feature_num  # Set channel dimension
         scale_reshape = impl.shuffle.reshape(
             ctx,
             target,
@@ -143,21 +159,24 @@ def batch_norm(
         )
 
     else:
-        if weight is None:
-            weight = 1.0
+        with unset_fake_temporarily():
+            if weight is None:
+                weight = torch.ones((feature_num,))
 
-        if bias is None:
-            bias = 0.0
+            if bias is None:
+                bias = torch.zeros((feature_num,))
 
-        if running_mean is None:
-            running_mean = 0.0
+            if running_mean is None:
+                running_mean = torch.zeros((feature_num,))
 
-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
-            weight, bias, running_mean, running_var, eps
-        )
-        power = torch.ones_like(adjusted_scale)
+            if running_var is None:
+                running_var = torch.ones((feature_num,))
+
+            power = torch.ones_like(weight)
+
+            adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+                weight, bias, running_mean, running_var, eps
+            )
 
         adjusted_scale = to_trt_weights(
             ctx,
@@ -188,9 +207,7 @@ def batch_norm(
             source_ir=source_ir,
        )
 
-        output_shape = input.shape
     if len(input.shape) < 4:
-
         new_shape = (
             (input.shape[0], input.shape[1], 1, 1)
             if len(input.shape) == 2
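
In this diff the scalar defaults (1.0 / 0.0) become per-channel tensors of length feature_num, built under unset_fake_temporarily so real tensors are materialized even while tracing with fake tensors, and power now mirrors weight so every SCALE weight shares the channel shape. For reference, the standard per-channel folding identity that batch_norm_constant_folding relies on is sketched below; this is the textbook formula in a self-contained form, not a copy of the repo's helper.

from typing import Tuple

import torch


def fold_batch_norm(
    weight: torch.Tensor,
    bias: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # y = (x - mean) / sqrt(var + eps) * weight + bias
    #   = x * scale + shift, with per-channel constants:
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift


# Per-channel tensors of length feature_num, matching the defaults in the diff above.
feature_num = 32
scale, shift = fold_batch_norm(
    torch.ones(feature_num),
    torch.zeros(feature_num),
    torch.zeros(feature_num),
    torch.ones(feature_num),
    eps=1e-5,
)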

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 1 deletion
@@ -91,7 +91,6 @@
     aten.narrow,
     # TODO: Disable the below operators once freezing is done
     aten.native_batch_norm_backward,
-    aten._native_batch_norm_legit,
     aten._native_batch_norm_legit_functional,
     aten.native_dropout_backward,
     aten.native_group_norm_backward,
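
Removing aten._native_batch_norm_legit from this decomposition group keeps the op intact in the lowered graph so the batch-norm converters above can claim it directly rather than converting its decomposition. A quick, illustrative way to inspect which aten batch-norm overload a given model actually traces to (module and input here are placeholders, not from the commit):

import torch

# Placeholder module; substitute the model under test (e.g. RAFT).
m = torch.nn.BatchNorm2d(8, track_running_stats=False).eval()
x = torch.randn(2, 8, 16, 16)

ep = torch.export.export(m, (x,))
print(
    [
        node.target
        for node in ep.graph.nodes
        if node.op == "call_function" and "batch_norm" in str(node.target)
    ]
)  # typically shows a _native_batch_norm_legit overload for this setup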
