Skip to content

Commit 41f5019

Browse files
authored
[torchlib] Two fixes (#2762)
- torch.prod: does not preserve dimensions (missing `keepdims=False`)
- aten::normal.float_float: may be called with more arguments (`layout`, `device`, `pin_memory`)

---------

Signed-off-by: Justin Chu <[email protected]>
1 parent c7bb8af commit 41f5019

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7350,7 +7350,13 @@ def aten_normal(


 @torch_op("aten::normal.float_float", trace_only=True)
 def aten_normal_float_float(
-    mean: float, std: float, size: INT64, dtype: int = FLOAT.dtype
+    mean: float,
+    std: float,
+    size: INT64,
+    dtype: int = FLOAT.dtype,
+    layout: str = "",
+    device: str = "",
+    pin_memory: bool = False,
 ) -> TensorType:
     """normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
73567362

@@ -7674,7 +7680,9 @@ def aten_prod(self: TReal, dtype: int = -1) -> TReal:

     if dtype != -1 and dtype is not None:
         self = op.Cast(self, to=dtype)
-    return op.ReduceProd(self)
+    elif self.dtype.is_integer():
+        self = op.Cast(self, to=INT64.dtype)
+    return op.ReduceProd(self, keepdims=False)


 @torch_op("aten::prod.dim_int", trace_only=True)

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("prod", core_ops.aten_prod).skip(
         matcher=lambda sample: sample.kwargs.get("dim") is not None
         or sample.kwargs.get("keepdim") is not None
-        or sample.kwargs.get("dtype") != -1,
+        or len(sample.args) > 0,
         reason="this Aten overload only accept 1 inputs: self",
     ),
     TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(

0 commit comments

Comments
 (0)