Commit 6bef925

fix: torch.std and torch.var support multi-dimensional reductions (#1395)
* Augment current tests to reproduce the error-generating bug in standard deviation / variance lowering
* fix: Ensure torch.std (and torch.var) support multiple dimensions
  - Refactor IR code to avoid the use of select, for which only single-dimension support currently exists
  - Update the formula in the Bessel's correction (unbiased) case
  - Include regression tests to catch multi-dimensional indexing errors
* Improve comments on the IR for the UnpackVar function
1 parent: 1011ac1
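For context, the failure mode addressed here arises when the reduction-dimensions list holds more than one entry. Below is a minimal sketch in Python (PyTorch) of the kind of call the fix targets; the tensor shape and dims are illustrative, not taken from the commit:

import torch

# A standard deviation reduced over multiple dimensions. Before this
# commit, the lowering pass unpacked %dims with prim::ListUnpack and
# indexed the shape with aten::select, which only handles a single
# reduction dimension.
x = torch.randn(3, 4, 5)
y = torch.std(x, dim=[0, 1], unbiased=True, keepdim=True)
print(y.shape)  # torch.Size([1, 1, 5])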

3 files changed: 18 additions and 8 deletions


core/lowering/passes/unpack_var.cpp

Lines changed: 12 additions & 5 deletions
@@ -26,11 +26,18 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
         %var: Tensor = aten::sub(%sqrdmean, %meansqrd, %1)
         %varout : Tensor = prim::If(%unbiased)
           block0():
-            %shape: int[] = aten::size(%input)
-            %shapet: Tensor = aten::tensor(%shape, %f32_dtype, %none, %false)
-            %dim: int = prim::ListUnpack(%dims)
-            %reduceddims: Tensor = aten::select(%shapet, %0, %dim)
-            %numel: Tensor = aten::prod(%reduceddims, %dim, %keepdim, %none)
+            # Compute number of elements in original input tensor
+            %originalshape: int[] = aten::size(%input)
+            %originalshapet: Tensor = aten::tensor(%originalshape, %f32_dtype, %none, %false)
+            %originalnumel: Tensor = aten::prod(%originalshapet, %0, %false, %none)
+            # Compute number of elements in resulting output tensor
+            %resultingshape: int[] = aten::size(%var)
+            %resultingshapet: Tensor = aten::tensor(%resultingshape, %f32_dtype, %none, %false)
+            %resultingnumel: Tensor = aten::prod(%resultingshapet, %0, %false, %none)
+            # Quotient of original number of elements and resulting number of elements
+            # is equal to the number of elements used per variance calculation
+            %numel: Tensor = aten::div(%originalnumel, %resultingnumel)
+            # Perform Bessel's correction on computed variance
             %mul: Tensor = aten::mul(%var, %numel)
             %sub: Tensor = aten::sub(%numel, %1, %1)
             %v: Tensor = aten::div(%mul, %sub)
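The rewritten IR derives the per-reduction element count as the quotient of the input's element count and the output's element count, which holds for any set of reduced dimensions. A quick sketch of the same arithmetic in Python (PyTorch), for illustration only; the tensor shape and dims are assumptions:

import torch

x = torch.randn(3, 4, 5)
dims = [0, 1]

# Biased variance, analogous to %var computed before the prim::If block
biased = torch.var(x, dim=dims, unbiased=False, keepdim=True)

# Elements per reduction = input numel / output numel, mirroring the
# aten::div of the two aten::prod results in the new IR
n = x.numel() / biased.numel()  # 60 / 5 = 12

# Bessel's correction: var_unbiased = var_biased * n / (n - 1)
corrected = biased * n / (n - 1)

expected = torch.var(x, dim=dims, unbiased=True, keepdim=True)
assert torch.allclose(corrected, expected)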

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 2 additions & 1 deletion
@@ -465,7 +465,8 @@ TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
         %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
         %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
         %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
+        %one : int = prim::Constant[value=1]()
+        %6 : int[] = prim::ListConstruct(%3, %one)
         %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26
         return (%7))IR";

tests/core/lowering/test_unpack_reduce_ops.cpp

Lines changed: 4 additions & 2 deletions
@@ -134,7 +134,8 @@ TEST(LoweringPasses, UnpackStdKeepDimsLowersCorrectly) {
         %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
         %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
         %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
+        %one : int = prim::Constant[value=1]()
+        %6 : int[] = prim::ListConstruct(%3, %one)
         %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26
         return (%7))IR";

@@ -184,7 +185,8 @@ TEST(LoweringPasses, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
         %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65
         %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50
         %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39
-        %6 : int[] = prim::ListConstruct(%3)
+        %one : int = prim::Constant[value=1]()
+        %6 : int[] = prim::ListConstruct(%3, %one)
         %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26
         return (%7))IR";
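The regression tests embed TorchScript IR as raw strings and now reduce over two dimensions (%3 and %one). A comparable graph can be parsed and inspected from Python; the graph header below is an assumption, since the diff shows only the body of the IR string:

import torch

ir = """
graph(%x.1 : Tensor):
  %5 : bool = prim::Constant[value=0]()
  %4 : bool = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=0]()
  %one : int = prim::Constant[value=1]()
  %6 : int[] = prim::ListConstruct(%3, %one)
  %7 : Tensor = aten::std(%x.1, %6, %4, %5)
  return (%7)
"""
# torch._C.parse_ir parses a TorchScript IR string into a Graph object
graph = torch._C.parse_ir(ir)
print(graph)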
