@@ -148,14 +148,38 @@ def convert_batch_norm(ctx):
     bias = get_arg(ctx, 'bias', pos=4, default=None)
     eps = get_arg(ctx, 'eps', pos=7, default=10e-6)
 
+    ndim = input.ndim - 2
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
 
     scale = weight.detach().cpu().numpy() / np.sqrt(running_var.detach().cpu().numpy() + eps)
     bias = bias.detach().cpu().numpy() - running_mean.detach().cpu().numpy() * scale
     power = np.ones_like(scale)
 
-    layer = ctx.network.add_scale_nd(input_trt, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
+    if ndim == 1:
+        # expand the 2D (N, C) or 3D (N, C, L) input to 4D so add_scale_nd can scale per channel
+        layer = ctx.network.add_shuffle(input_trt)
+
+        if len(input.shape) == 2:
+            layer.reshape_dims = (0, 0, 1, 1)
+        else:
+            layer.reshape_dims = (0, 0, 0, 1)
+
+        scale_input = layer.get_output(0)
+    else:
+        scale_input = input_trt
+
+    layer = ctx.network.add_scale_nd(scale_input, trt.ScaleMode.CHANNEL, bias, scale, power, 1)
+
+    if ndim == 1:
+        # reshape back to the original 2D or 3D shape
+        layer = ctx.network.add_shuffle(layer.get_output(0))
+        if len(input.shape) == 2:
+            layer.reshape_dims = (0, 0)
+        else:
+            layer.reshape_dims = (0, 0, 0)
+
+
     output._trt = layer.get_output(0)
 
 
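The scale and bias lines above fold batch norm's running statistics and affine parameters into a single per-channel scale layer: with scale = gamma / sqrt(running_var + eps) and bias = beta - running_mean * scale, eval-mode batch norm reduces to y = x * scale + bias. A minimal sketch checking that fold against PyTorch's BatchNorm1d in eval mode (the module name `bn` and the shapes are illustrative, not part of the commit):

import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Fold the normalization into a per-channel affine transform, mirroring the converter:
#   scale = gamma / sqrt(running_var + eps)
#   bias  = beta - running_mean * scale
scale = bn.weight.detach().numpy() / np.sqrt(bn.running_var.numpy() + bn.eps)
bias = bn.bias.detach().numpy() - bn.running_mean.numpy() * scale

x = torch.randn(2, 4, 8)  # (N, C, L) input, the BatchNorm1d case this commit targets
expected = bn(x).detach().numpy()
folded = x.numpy() * scale[None, :, None] + bias[None, :, None]
print(np.abs(expected - folded).max())  # ~1e-7: the fold matches eval-mode batch norm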
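The shuffle round trip appears to exist because TensorRT's scale layer cannot operate directly on the 2D or 3D tensors that BatchNorm1d receives, so the input is temporarily expanded to 4D. A hedged end-to-end check, assuming a CUDA device with TensorRT and torch2trt installed (the model here is illustrative, not from the commit):

import torch
import torch.nn as nn
from torch2trt import torch2trt

# BatchNorm1d sees a 3D (N, C, L) tensor here; the shuffle-to-4D round trip
# added in this commit lets the converter's add_scale_nd call handle it.
model = nn.Sequential(
    nn.Conv1d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm1d(8),
    nn.ReLU(),
).eval().cuda()

x = torch.randn(1, 3, 32).cuda()
model_trt = torch2trt(model, [x])

print(torch.max(torch.abs(model(x) - model_trt(x))))  # should be near zero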