@@ -125,57 +125,71 @@ defmodule EXLA.MLIR.Value do
125
125
end
126
126
end
127
127
128
- def is_infinity ( % Value { function: func } = operand , typespec ) do
128
+ def is_infinity ( % Value { function: func } = operand , out_typespec ) do
129
129
% { type: type } = get_typespec ( operand )
130
130
131
- typespec = Typespec . to_type ( typespec , { :pred , 8 } )
131
+ typespec = Typespec . to_type ( out_typespec , { :pred , 8 } )
132
132
133
- cond do
134
- Nx.Type . complex? ( type ) ->
135
- float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
136
- real = real ( operand , float_typespec )
137
- imag = imag ( operand , float_typespec )
138
- is_inf_real = is_infinity ( real , typespec )
139
- is_inf_imag = is_infinity ( imag , typespec )
140
- bitwise_or ( is_inf_real , is_inf_imag , typespec )
141
-
142
- Nx.Type . integer? ( type ) ->
143
- # Integers are never infinity. We use inequality to make sure
144
- # the operand is still a part of the computation
145
- not_equal ( operand , operand , typespec )
133
+ result =
134
+ cond do
135
+ Nx.Type . complex? ( type ) ->
136
+ float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
137
+ real = real ( operand , float_typespec )
138
+ imag = imag ( operand , float_typespec )
139
+ is_inf_real = is_infinity ( real , typespec )
140
+ is_inf_imag = is_infinity ( imag , typespec )
141
+ bitwise_or ( is_inf_real , is_inf_imag , typespec )
142
+
143
+ Nx.Type . integer? ( type ) ->
144
+ # Integers are never infinity. We use inequality to make sure
145
+ # the operand is still a part of the computation
146
+ not_equal ( operand , operand , typespec )
147
+
148
+ true ->
149
+ result_types = typespecs_to_mlir_types ( [ typespec ] )
150
+ op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
151
+ end
146
152
147
- true ->
148
- result_types = typespecs_to_mlir_types ( [ typespec ] )
149
- op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
153
+ if out_typespec . type == typespec . type do
154
+ result
155
+ else
156
+ convert ( result , out_typespec )
150
157
end
151
158
end
152
159
153
- def is_nan ( % Value { function: func } = operand , typespec ) do
160
+ def is_nan ( % Value { function: func } = operand , out_typespec ) do
154
161
% { type: type } = get_typespec ( operand )
155
162
156
- typespec = Typespec . to_type ( typespec , { :pred , 8 } )
163
+ typespec = Typespec . to_type ( out_typespec , { :pred , 8 } )
157
164
158
- cond do
159
- Nx.Type . complex? ( type ) ->
160
- float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
161
- real = real ( operand , float_typespec )
162
- imag = imag ( operand , float_typespec )
163
- is_nan_real = is_nan ( real , typespec )
164
- is_nan_imag = is_nan ( imag , typespec )
165
- bitwise_or ( is_nan_real , is_nan_imag , typespec )
166
-
167
- Nx.Type . integer? ( type ) ->
168
- # Integers are never nan. We use inequality to make sure
169
- # the operand is still a part of the computation
170
- not_equal ( operand , operand , typespec )
165
+ result =
166
+ cond do
167
+ Nx.Type . complex? ( type ) ->
168
+ float_typespec = Typespec . to_type ( typespec , complex_part_type ( type ) )
169
+ real = real ( operand , float_typespec )
170
+ imag = imag ( operand , float_typespec )
171
+ is_nan_real = is_nan ( real , typespec )
172
+ is_nan_imag = is_nan ( imag , typespec )
173
+ bitwise_or ( is_nan_real , is_nan_imag , typespec )
174
+
175
+ Nx.Type . integer? ( type ) ->
176
+ # Integers are never nan. We use inequality to make sure
177
+ # the operand is still a part of the computation
178
+ not_equal ( operand , operand , typespec )
179
+
180
+ true ->
181
+ result_types = typespecs_to_mlir_types ( [ typespec ] )
182
+ is_inf = op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
183
+ is_finite = op ( func , "stablehlo.is_finite" , [ operand ] , result_types ) |> one! ( )
184
+ is_not_inf = bitwise_not ( is_inf , typespec )
185
+ is_not_finite = bitwise_not ( is_finite , typespec )
186
+ bitwise_and ( is_not_inf , is_not_finite , typespec )
187
+ end
171
188
172
- true ->
173
- result_types = typespecs_to_mlir_types ( [ typespec ] )
174
- is_inf = op ( func , "chlo.is_inf" , [ operand ] , result_types ) |> one! ( )
175
- is_finite = op ( func , "stablehlo.is_finite" , [ operand ] , result_types ) |> one! ( )
176
- is_not_inf = bitwise_not ( is_inf , typespec )
177
- is_not_finite = bitwise_not ( is_finite , typespec )
178
- bitwise_and ( is_not_inf , is_not_finite , typespec )
189
+ if out_typespec . type == typespec . type do
190
+ result
191
+ else
192
+ convert ( result , out_typespec )
179
193
end
180
194
end
181
195
@@ -706,6 +720,10 @@ defmodule EXLA.MLIR.Value do
706
720
op ( func , "stablehlo.while" , initial , result_types , regions: regions )
707
721
end
708
722
723
+ def func_return ( func , values ) when is_list ( values ) do
724
+ op ( func , "func.return" , values , [ ] )
725
+ end
726
+
709
727
def return ( func , values ) when is_list ( values ) do
710
728
op ( func , "stablehlo.return" , values , [ ] )
711
729
end
0 commit comments