Skip to content

Commit a571309

Browse files
authored
[torchlib] Fix aten__native_batch_norm_legit_functional (#2753)
Fix aten__native_batch_norm_legit_functional where the running mean/var were returned without creating a new value, making the graph invalid. Signed-off-by: Justin Chu <[email protected]>
1 parent bfe2cce commit a571309

File tree

1 file changed

+1
-1
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+1
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6893,7 +6893,7 @@ def _aten_native_batch_norm_inference_onnx(
68936893
# https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475
68946894
running_mean_fp32 = op.Cast(running_mean, to=FLOAT.dtype)
68956895
invstd = op.Cast(invstd, to=FLOAT.dtype)
6896-
return norm, running_mean_fp32, invstd, running_mean, running_var
6896+
return norm, running_mean_fp32, invstd, op.Identity(running_mean), op.Identity(running_var)
68976897

68986898

68996899
# TODO: This op is using duplicated code from aten_native_batch_norm,

0 commit comments

Comments
 (0)