@@ -54,31 +54,40 @@ defmodule EXLA.MLIR.Value do
54
54
}
55
55
56
56
for { op , direction } <- @ bin_comparison_ops do
57
- def unquote ( op ) ( % Value { function: func } = lhs , % Value { function: func } = rhs , typespec ) do
58
- compare_and_return_bool ( func , lhs , rhs , typespec , unquote ( direction ) )
57
+ def unquote ( op ) (
58
+ % Value { function: func } = lhs ,
59
+ % Value { function: func } = rhs ,
60
+ typespec ,
61
+ opts \\ [ ]
62
+ ) do
63
+ compare_and_return_bool ( func , lhs , rhs , typespec , unquote ( direction ) , opts [ :total_order ] )
59
64
end
60
65
end
61
66
62
- defp compare_and_return_bool ( func , lhs , rhs , typespec , direction ) do
67
+ defp compare_and_return_bool ( func , lhs , rhs , typespec , direction , total_order? \\ false ) do
63
68
% { type: lhs_type } = get_typespec ( lhs )
64
69
% { type: rhs_type } = get_typespec ( rhs )
65
70
66
71
comparison_type =
67
72
cond do
68
73
Nx.Type . complex? ( lhs_type ) or Nx.Type . complex? ( rhs_type ) ->
69
- attr_comparison_type ( :float )
74
+ [ compare_type: attr_comparison_type ( :float ) ]
70
75
71
76
Nx.Type . float? ( lhs_type ) or Nx.Type . float? ( rhs_type ) ->
72
- attr_comparison_type ( :float )
77
+ attr =
78
+ if total_order? do
79
+ attr_comparison_type ( :totalorder )
80
+ else
81
+ attr_comparison_type ( :float )
82
+ end
83
+
84
+ [ compare_type: attr ]
73
85
74
86
true ->
75
- attr_comparison_type ( :notype )
87
+ [ ]
76
88
end
77
89
78
- attributes = [
79
- comparison_direction: attr_comparison_direction ( direction ) ,
80
- compare_type: comparison_type
81
- ]
90
+ attributes = [ comparison_direction: attr_comparison_direction ( direction ) ] ++ comparison_type
82
91
83
92
result_types = typespecs_to_mlir_types ( [ Typespec . to_type ( typespec , { :pred , 8 } ) ] )
84
93
@@ -1072,7 +1081,7 @@ defmodule EXLA.MLIR.Value do
1072
1081
defp attr_comparison_direction ( value ) when value in [ :eq , :lt , :le , :gt , :ge , :ne ] ,
1073
1082
do: attr_enum ( "stablehlo" , "comparison_direction" , value )
1074
1083
1075
- defp attr_comparison_type ( value ) when value in [ :float , :totalorder , :notype ] ,
1084
+ defp attr_comparison_type ( value ) when value in [ :float , :totalorder ] ,
1076
1085
do: attr_enum ( "stablehlo" , "comparison_type" , value )
1077
1086
1078
1087
defp attr_precision ( value ) when value in [ :default , :high , :highest ] ,
0 commit comments