Skip to content

Commit d55faa1

Browse files
committed
Make tests more resilient to signs when rounding
1 parent 44313f6 commit d55faa1

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

nx/test/nx/lin_alg_test.exs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -586,17 +586,17 @@ defmodule Nx.LinAlgTest do
586586
Nx.tensor([
587587
[
588588
Complex.new(-0.408, 0.0),
589-
Complex.new(0.0, 0.707),
589+
Complex.new(-0.0, 0.707),
590590
Complex.new(0.577, 0.0)
591591
],
592592
[
593-
Complex.new(0.0, -0.816),
593+
Complex.new(-0.0, -0.816),
594594
Complex.new(0.0, 0.0),
595595
Complex.new(0.0, -0.577)
596596
],
597597
[
598598
Complex.new(0.408, 0.0),
599-
Complex.new(0.0, 0.707),
599+
Complex.new(-0.0, 0.707),
600600
Complex.new(-0.577, 0.0)
601601
]
602602
])
@@ -731,7 +731,8 @@ defmodule Nx.LinAlgTest do
731731

732732
assert {u, s, vt} = Nx.LinAlg.svd(t)
733733

734-
assert round(Nx.as_type(t, :f32), 2) == u |> Nx.multiply(s) |> Nx.dot(vt) |> round(2)
734+
assert round(Nx.as_type(t, :f32), 2) ==
735+
u |> Nx.multiply(s) |> Nx.dot(vt) |> Nx.abs() |> round(2)
735736
end
736737

737738
test "finds the singular values of wide matrices" do
@@ -755,7 +756,7 @@ defmodule Nx.LinAlgTest do
755756
|> Nx.broadcast({3, 3})
756757
|> Nx.put_diagonal(s)
757758

758-
assert round(t, 1) == u |> Nx.dot(s_matrix) |> Nx.dot(v) |> round(1)
759+
assert round(t, 1) == u |> Nx.dot(s_matrix) |> Nx.dot(v) |> Nx.abs() |> round(1)
759760

760761
assert round(u, 3) ==
761762
Nx.tensor([

0 commit comments

Comments
 (0)