Skip to content

Commit bfe5bc5

Browse files
committed
handle 1d case batch_norm
1 parent c77c90e commit bfe5bc5

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

torch2trt/converters/native_converters.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,38 @@ def convert_batch_norm(ctx):
148148
bias = get_arg(ctx, 'bias', pos=4, default=None)
149149
eps = get_arg(ctx, 'eps', pos=7, default=10e-6)
150150

151+
ndim = input.ndim - 2
151152
input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
152153
output = ctx.method_return
153154

154155
scale = weight.detach().cpu().numpy() / np.sqrt(running_var.detach().cpu().numpy() + eps)
155156
bias = bias.detach().cpu().numpy() - running_mean.detach().cpu().numpy() * scale
156157
power = np.ones_like(scale)
157158

158-
layer = ctx.network.add_scale_nd(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
159+
if ndim == 1:
160+
# reshape to 2D
161+
layer = ctx.network.add_shuffle(input_trt)
162+
163+
if len(input.shape) == 2:
164+
layer.reshape_dims = (0, 0, 1, 1)
165+
else:
166+
layer.reshape_dims = (0, 0, 0, 1)
167+
168+
scale_input = layer.get_output(0)
169+
else:
170+
scale_input = input_trt
171+
172+
layer = ctx.network.add_scale_nd(scale_input, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
173+
174+
if ndim == 1:
175+
# reshape back to 1D
176+
layer = ctx.network.add_shuffle(layer.get_output(0))
177+
if len(input.shape) == 2:
178+
layer.reshape_dims = (0, 0)
179+
else:
180+
layer.reshape_dims = (0, 0, 0)
181+
182+
159183
output._trt = layer.get_output(0)
160184

161185

0 commit comments

Comments
 (0)