@@ -586,17 +586,17 @@ defmodule Nx.LinAlgTest do
586
586
Nx . tensor ( [
587
587
[
588
588
Complex . new ( - 0.408 , 0.0 ) ,
589
- Complex . new ( 0.0 , 0.707 ) ,
589
+ Complex . new ( - 0.0 , 0.707 ) ,
590
590
Complex . new ( 0.577 , 0.0 )
591
591
] ,
592
592
[
593
- Complex . new ( 0.0 , - 0.816 ) ,
593
+ Complex . new ( - 0.0 , - 0.816 ) ,
594
594
Complex . new ( 0.0 , 0.0 ) ,
595
595
Complex . new ( 0.0 , - 0.577 )
596
596
] ,
597
597
[
598
598
Complex . new ( 0.408 , 0.0 ) ,
599
- Complex . new ( 0.0 , 0.707 ) ,
599
+ Complex . new ( - 0.0 , 0.707 ) ,
600
600
Complex . new ( - 0.577 , 0.0 )
601
601
]
602
602
] )
@@ -731,7 +731,8 @@ defmodule Nx.LinAlgTest do
731
731
732
732
assert { u , s , vt } = Nx.LinAlg . svd ( t )
733
733
734
- assert round ( Nx . as_type ( t , :f32 ) , 2 ) == u |> Nx . multiply ( s ) |> Nx . dot ( vt ) |> round ( 2 )
734
+ assert round ( Nx . as_type ( t , :f32 ) , 2 ) ==
735
+ u |> Nx . multiply ( s ) |> Nx . dot ( vt ) |> Nx . abs ( ) |> round ( 2 )
735
736
end
736
737
737
738
test "finds the singular values of wide matrices" do
@@ -755,7 +756,7 @@ defmodule Nx.LinAlgTest do
755
756
|> Nx . broadcast ( { 3 , 3 } )
756
757
|> Nx . put_diagonal ( s )
757
758
758
- assert round ( t , 1 ) == u |> Nx . dot ( s_matrix ) |> Nx . dot ( v ) |> round ( 1 )
759
+ assert round ( t , 1 ) == u |> Nx . dot ( s_matrix ) |> Nx . dot ( v ) |> Nx . abs ( ) |> round ( 1 )
759
760
760
761
assert round ( u , 3 ) ==
761
762
Nx . tensor ( [
0 commit comments