@@ -40,7 +40,9 @@ defmodule Nx.Defn.TreeTest do
40
40
41
41
test "ignores constants" do
42
42
a = Expr . parameter ( :root , { :u , 64 } , { } , 0 )
43
- assert [ { _ , :parameter } , { _ , :add } ] = plus_constant ( a ) |> Tree . scope_ids ( ) |> Enum . sort ( )
43
+
44
+ assert [ { _ , :add } , { _ , :parameter } ] =
45
+ plus_constant ( a ) |> Tree . scope_ids ( ) |> Enum . sort_by ( & elem ( & 1 , 1 ) )
44
46
end
45
47
46
48
defn inside_cond ( bool , a , b ) do
@@ -54,8 +56,8 @@ defmodule Nx.Defn.TreeTest do
54
56
test "ignores expressions inside cond" do
55
57
{ bool , cond } = Nx.Defn . jit ( & { & 1 , inside_cond ( & 1 , & 2 , & 3 ) } ) . ( 0 , 1 , 2 )
56
58
57
- assert cond |> Tree . scope_ids ( ) |> Enum . sort ( ) ==
58
- [ { bool . data . id , :parameter } , { cond . data . id , :cond } ]
59
+ assert cond |> Tree . scope_ids ( ) |> Enum . sort_by ( & elem ( & 1 , 1 ) ) ==
60
+ [ { cond . data . id , :cond } , { bool . data . id , :parameter } ]
59
61
end
60
62
61
63
defn inside_both_cond ( bool , a , b ) do
@@ -84,14 +86,14 @@ defmodule Nx.Defn.TreeTest do
84
86
b = Expr . parameter ( :root , { :u , 64 } , { } , 2 )
85
87
86
88
assert [
87
- { _ , :parameter } ,
88
- { _ , :parameter } ,
89
- { _ , :parameter } ,
90
89
{ _ , :add } ,
91
90
{ _ , :cond } ,
92
91
{ _ , :cond } ,
93
- { _ , :multiply }
94
- ] = inside_both_cond ( bool , a , b ) |> Tree . scope_ids ( ) |> Enum . sort ( )
92
+ { _ , :multiply } ,
93
+ { _ , :parameter } ,
94
+ { _ , :parameter } ,
95
+ { _ , :parameter }
96
+ ] = inside_both_cond ( bool , a , b ) |> Tree . scope_ids ( ) |> Enum . sort_by ( & elem ( & 1 , 1 ) )
95
97
end
96
98
end
97
99
0 commit comments