@@ -167,24 +167,81 @@ dummy_func(void) {
167167 }
168168
169169 op (_BINARY_OP , (left , right -- res )) {
170- PyTypeObject * ltype = sym_get_type (left );
171- PyTypeObject * rtype = sym_get_type (right );
172- if (ltype != NULL && (ltype == & PyLong_Type || ltype == & PyFloat_Type ) &&
173- rtype != NULL && (rtype == & PyLong_Type || rtype == & PyFloat_Type ))
174- {
175- if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
176- ltype == & PyLong_Type && rtype == & PyLong_Type ) {
177- /* If both inputs are ints and the op is not division the result is an int */
178- res = sym_new_type (ctx , & PyLong_Type );
170+ bool lhs_int = sym_matches_type (left , & PyLong_Type );
171+ bool rhs_int = sym_matches_type (right , & PyLong_Type );
172+ bool lhs_float = sym_matches_type (left , & PyFloat_Type );
173+ bool rhs_float = sym_matches_type (right , & PyFloat_Type );
174+ if ((!lhs_int && !lhs_float ) || (!rhs_int && !rhs_float )) {
175+ res = sym_new_unknown (ctx );
176+ goto binary_op_done ;
177+ }
178+ if (oparg == NB_POWER || oparg == NB_INPLACE_POWER ) {
179+ // This one's fun: the *type* of the result depends on the *values*
180+ // being exponentiated. But exponents with one constant part are
181+ // reasonably common, so it's probably worth trying to be precise:
182+ PyObject * lhs_const = sym_get_const (left );
183+ PyObject * rhs_const = sym_get_const (right );
184+ if (lhs_int && rhs_int ) {
185+ if (rhs_const == NULL ) {
186+ // Unknown RHS means either int or float:
187+ res = sym_new_unknown (ctx );
188+ goto binary_op_done ;
189+ }
190+ if (!_PyLong_IsNegative ((PyLongObject * )rhs_const )) {
191+ // Non-negative RHS means int:
192+ res = sym_new_type (ctx , & PyLong_Type );
193+ goto binary_op_done ;
194+ }
195+ // Negative RHS uses float_pow...
179196 }
180- else {
181- /* For any other op combining ints/floats the result is a float */
197+ // Negative LHS *and* non-integral RHS means complex. So we need to
198+ // disprove at least one to prove a float result:
199+ if (rhs_int ) {
200+ // Integral RHS means float:
182201 res = sym_new_type (ctx , & PyFloat_Type );
202+ goto binary_op_done ;
203+ }
204+ if (rhs_const ) {
205+ double rhs_double = PyFloat_AS_DOUBLE (rhs_const );
206+ if (rhs_double == floor (rhs_double )) {
207+ // Integral RHS means float:
208+ res = sym_new_type (ctx , & PyFloat_Type );
209+ goto binary_op_done ;
210+ }
211+ }
212+ if (lhs_const ) {
213+ if (lhs_int ) {
214+ if (!_PyLong_IsNegative ((PyLongObject * )lhs_const )) {
215+ // Non-negative LHS means float:
216+ res = sym_new_type (ctx , & PyFloat_Type );
217+ goto binary_op_done ;
218+ }
219+ }
220+ else if (0.0 <= PyFloat_AS_DOUBLE (lhs_const )) {
221+ // Non-negative LHS means float:
222+ res = sym_new_type (ctx , & PyFloat_Type );
223+ goto binary_op_done ;
224+ }
225+ if (rhs_const ) {
226+ // If we have two constants and failed to disprove that it's
227+ // complex, then it's complex:
228+ res = sym_new_type (ctx , & PyComplex_Type );
229+ goto binary_op_done ;
230+ }
183231 }
232+ // Couldn't prove anything. It's either float or complex:
233+ res = sym_new_unknown (ctx );
234+ }
235+ else if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE ) {
236+ res = sym_new_type (ctx , & PyFloat_Type );
237+ }
238+ else if (lhs_int && rhs_int ) {
239+ res = sym_new_type (ctx , & PyLong_Type );
184240 }
185241 else {
186- res = sym_new_unknown (ctx );
242+ res = sym_new_type (ctx , & PyFloat_Type );
187243 }
244+ binary_op_done :
188245 }
189246
190247 op (_BINARY_OP_ADD_INT , (left , right -- res )) {
0 commit comments