Skip to content

Commit 39e8192

Browse files
serhiy-storchakamiss-islington
authored andcommitted
pythongh-130230: Fix crash in pow() with only Decimal third argument (pythonGH-130237)
(cherry picked from commit b93b7e5) Co-authored-by: Serhiy Storchaka <[email protected]>
1 parent fc1c9f8 commit 39e8192

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

Lib/test/test_decimal.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4458,6 +4458,15 @@ def test_implicit_context(self):
44584458
self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True)
44594459
# three arg power
44604460
self.assertEqual(pow(Decimal(10), 2, 7), 2)
4461+
if self.decimal == C:
4462+
self.assertEqual(pow(10, Decimal(2), 7), 2)
4463+
self.assertEqual(pow(10, 2, Decimal(7)), 2)
4464+
else:
4465+
# XXX: Three-arg power doesn't use __rpow__.
4466+
self.assertRaises(TypeError, pow, 10, Decimal(2), 7)
4467+
# XXX: There is no special method to dispatch on the
4468+
# third arg of three-arg power.
4469+
self.assertRaises(TypeError, pow, 10, 2, Decimal(7))
44614470
# exp
44624471
self.assertEqual(Decimal("1.01").exp(), 3)
44634472
# is_normal
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix crash in :func:`pow` with only :class:`~decimal.Decimal` third argument.

Modules/_decimal/_decimal.c

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,24 @@ find_state_left_or_right(PyObject *left, PyObject *right)
140140
return get_module_state(mod);
141141
}
142142

143+
static inline decimal_state *
144+
find_state_ternary(PyObject *left, PyObject *right, PyObject *modulus)
145+
{
146+
PyTypeObject *base;
147+
if (PyType_GetBaseByToken(Py_TYPE(left), &dec_spec, &base) != 1) {
148+
assert(!PyErr_Occurred());
149+
if (PyType_GetBaseByToken(Py_TYPE(right), &dec_spec, &base) != 1) {
150+
assert(!PyErr_Occurred());
151+
PyType_GetBaseByToken(Py_TYPE(modulus), &dec_spec, &base);
152+
}
153+
}
154+
assert(base != NULL);
155+
void *state = _PyType_GetModuleState(base);
156+
assert(state != NULL);
157+
Py_DECREF(base);
158+
return (decimal_state *)state;
159+
}
160+
143161

144162
#if !defined(MPD_VERSION_HEX) || MPD_VERSION_HEX < 0x02050000
145163
#error "libmpdec version >= 2.5.0 required"
@@ -4305,7 +4323,7 @@ nm_mpd_qpow(PyObject *base, PyObject *exp, PyObject *mod)
43054323
PyObject *context;
43064324
uint32_t status = 0;
43074325

4308-
decimal_state *state = find_state_left_or_right(base, exp);
4326+
decimal_state *state = find_state_ternary(base, exp, mod);
43094327
CURRENT_CONTEXT(state, context);
43104328
CONVERT_BINOP(&a, &b, base, exp, context);
43114329

0 commit comments

Comments
 (0)