Skip to content

Commit b00ff13

Browse files
authored
Simplify the Group Norm converter (#3719)
1 parent d8461e6 commit b00ff13

File tree

1 file changed

+5
-9
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/normalization

1 file changed

+5
-9
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -303,17 +303,13 @@ def native_group_norm(
303303
C == input.shape[1]
304304
), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"
305305

306-
weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
307-
bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)
308-
309306
shape = [1, group] + [1] * (rank - 2)
310307

311-
weight_one = impl.slice.expand(
312-
ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
313-
)
314-
bias_zero = impl.slice.expand(
315-
ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
316-
)
308+
weight_torch = torch.ones(shape)
309+
bias_torch = torch.zeros(shape)
310+
311+
weight_one = get_trt_tensor(ctx, weight_torch, f"{name}_weight_one", input.dtype)
312+
bias_zero = get_trt_tensor(ctx, bias_torch, f"{name}_bias_zero", input.dtype)
317313

318314
axes = get_axes_for_reduce_op(list(range(1 if group == 1 else 2, rank)))
319315

0 commit comments

Comments
 (0)