@@ -34,7 +34,7 @@ defmodule EXLA.Lib do
34
34
def argmax ( builder , op , type , opts \\ [ ] )
35
35
36
36
def argmax ( % Function { } = builder , % Value { } = op , type , opts ) do
37
- argmin_or_max ( builder , op , false , type , opts )
37
+ argmin_or_max ( builder , op , :max , type , opts )
38
38
end
39
39
40
40
@ doc """
@@ -49,37 +49,43 @@ defmodule EXLA.Lib do
49
49
def argmin ( builder , op , type , opts \\ [ ] )
50
50
51
51
def argmin ( % Function { } = builder , % Value { } = op , type , opts ) do
52
- argmin_or_max ( builder , op , true , type , opts )
52
+ argmin_or_max ( builder , op , :min , type , opts )
53
53
end
54
54
55
- defp argmin_or_max ( builder , % Value { } = op , is_min? , type , opts ) do
55
+ defp argmin_or_max ( builder , % Value { } = op , variant , type , opts ) do
56
56
tie_break = opts [ :tie_break ] || :low
57
57
keep_axis = opts [ :keep_axis ] || false
58
+ axis = opts [ :axis ]
58
59
59
60
op_typespec = Value . get_typespec ( op )
60
61
62
+ { op , op_typespec } =
63
+ if axis == nil and Nx . rank ( op_typespec . shape ) != 1 do
64
+ # When no axis is given, we flatten the tensor and reduce over
65
+ # the first axis
66
+ typespec = Typespec . to_shape ( op_typespec , { Nx . size ( op_typespec . shape ) } )
67
+ { Value . reshape ( op , typespec ) , typespec }
68
+ else
69
+ { op , op_typespec }
70
+ end
71
+
72
+ axis = axis || 0
73
+
61
74
init_value =
62
- if is_min? ,
63
- do: max_number ( builder , op_typespec . type ) ,
64
- else: min_number ( builder , op_typespec . type )
75
+ case variant do
76
+ :min -> max_number ( builder , op_typespec . type )
77
+ :max -> min_number ( builder , op_typespec . type )
78
+ end
65
79
66
- axis = opts [ :axis ]
67
80
index_init_value = Value . constant ( builder , [ 0 ] , Typespec . tensor ( type , { } ) )
68
81
iota = iota ( builder , axis , Typespec . to_type ( op_typespec , type ) )
69
- reduction = create_min_max_computation ( builder , op_typespec . type , type , is_min? , tie_break )
82
+ reduction = create_min_max_computation ( builder , op_typespec . type , type , variant , tie_break )
70
83
71
- dims =
72
- if axis do
73
- [ axis ]
74
- else
75
- Nx . axes ( op_typespec . shape )
76
- end
77
-
78
- shape = remove_axes ( op_typespec . shape , dims )
84
+ shape = Tuple . delete_at ( op_typespec . shape , axis )
79
85
typespecs = [ Typespec . tensor ( op_typespec . type , shape ) , Typespec . tensor ( type , shape ) ]
80
86
81
87
[ _ , result ] =
82
- Value . reduce ( reduction , [ init_value , index_init_value ] , [ op , iota ] , dims , typespecs )
88
+ Value . reduce ( reduction , [ init_value , index_init_value ] , [ op , iota ] , [ axis ] , typespecs )
83
89
84
90
if keep_axis do
85
91
Value . reshape ( result , Typespec . tensor ( type , put_elem ( op_typespec . shape , axis , 1 ) ) )
@@ -88,13 +94,7 @@ defmodule EXLA.Lib do
88
94
end
89
95
end
90
96
91
- defp remove_axes ( shape , axes ) do
92
- axes
93
- |> Enum . reverse ( )
94
- |> Enum . reduce ( shape , & Tuple . delete_at ( & 2 , & 1 ) )
95
- end
96
-
97
- defp create_min_max_computation ( % Function { } = function , type , index_type , is_min? , tie_break ) do
97
+ defp create_min_max_computation ( % Function { } = function , type , index_type , variant , tie_break ) do
98
98
arg_typespecs = [
99
99
Typespec . tensor ( type , { } ) ,
100
100
Typespec . tensor ( index_type , { } ) ,
@@ -109,27 +109,42 @@ defmodule EXLA.Lib do
109
109
value_typespec = Typespec . tensor ( type , { } )
110
110
idx_typespec = Typespec . tensor ( index_type , { } )
111
111
112
- cmp =
113
- if is_min? ,
114
- do: Value . less_equal ( lhs_value , rhs_value , pred_typespec ) ,
115
- else: Value . greater_equal ( lhs_value , rhs_value , pred_typespec )
112
+ comparator =
113
+ case variant do
114
+ :min -> & Value . less / 3
115
+ :max -> & Value . greater / 3
116
+ end
117
+
118
+ # Pick lhs if strictly before or if it is NaN
119
+ pick_lhs_value =
120
+ Value . bitwise_or (
121
+ comparator . ( lhs_value , rhs_value , pred_typespec ) ,
122
+ Value . is_nan ( lhs_value , pred_typespec ) ,
123
+ pred_typespec
124
+ )
116
125
117
- max = Value . select ( cmp , lhs_value , rhs_value , value_typespec )
118
- arg_max = Value . select ( cmp , lhs_index , rhs_index , idx_typespec )
126
+ max = Value . select ( pick_lhs_value , lhs_value , rhs_value , value_typespec )
119
127
120
- arg_max =
128
+ idx_comparator =
121
129
case tie_break do
122
- :low ->
123
- eq? = Value . equal ( lhs_value , rhs_value , pred_typespec )
124
- id = Value . min ( lhs_index , rhs_index , idx_typespec )
125
- Value . select ( eq? , id , arg_max , idx_typespec )
126
-
127
- :high ->
128
- eq? = Value . equal ( lhs_value , rhs_value , pred_typespec )
129
- id = Value . max ( lhs_index , rhs_index , idx_typespec )
130
- Value . select ( eq? , id , arg_max , idx_typespec )
130
+ :low -> & Value . less / 3
131
+ :high -> & Value . greater / 3
131
132
end
132
133
134
+ # If lhs and rhs are equal (and not NaN), then pick index based on tie_break
135
+ pick_lhs_idx =
136
+ Value . bitwise_or (
137
+ pick_lhs_value ,
138
+ Value . bitwise_and (
139
+ Value . equal ( lhs_value , rhs_value , pred_typespec ) ,
140
+ idx_comparator . ( lhs_index , rhs_index , pred_typespec ) ,
141
+ pred_typespec
142
+ ) ,
143
+ pred_typespec
144
+ )
145
+
146
+ arg_max = Value . select ( pick_lhs_idx , lhs_index , rhs_index , idx_typespec )
147
+
133
148
Value . return ( function , [ max , arg_max ] )
134
149
Function . pop_region ( function )
135
150
region
0 commit comments