Skip to content

Commit 4175f7c

Browse files
committed
Add support for split() and subgroup() to BaseException.
1 parent 834ba5a commit 4175f7c

File tree

2 files changed

+162
-47
lines changed

2 files changed

+162
-47
lines changed

Lib/test/test_exception_group.py

Lines changed: 96 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -308,32 +308,39 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
308308
def setUp(self):
309309
self.eg = create_simple_eg()
310310
self.eg_template = [ValueError(1), TypeError(int), ValueError(2)]
311+
self.exc = ValueError('a')
311312

312313
def test_basics_subgroup_split__bad_arg_type(self):
313-
class C:
314-
pass
315-
316-
bad_args = ["bad arg",
317-
C,
318-
OSError('instance not type'),
319-
[OSError, TypeError],
320-
(OSError, 42),
321-
]
322-
for arg in bad_args:
323-
with self.assertRaises(TypeError):
324-
self.eg.subgroup(arg)
325-
with self.assertRaises(TypeError):
326-
self.eg.split(arg)
314+
for obj in (self.eg, self.exc):
315+
for arg in (
316+
"bad arg",
317+
type('NewClass', (), {}),
318+
OSError('instance not type'),
319+
[OSError, TypeError],
320+
(OSError, 42),
321+
):
322+
with self.subTest(obj_type=type(obj), arg=arg):
323+
with self.assertRaises(TypeError):
324+
obj.subgroup(arg)
325+
326+
with self.assertRaises(TypeError):
327+
obj.split(arg)
327328

328329
def test_basics_subgroup_by_type__passthrough(self):
329-
eg = self.eg
330-
self.assertIs(eg, eg.subgroup(BaseException))
331-
self.assertIs(eg, eg.subgroup(Exception))
332-
self.assertIs(eg, eg.subgroup(BaseExceptionGroup))
333-
self.assertIs(eg, eg.subgroup(ExceptionGroup))
330+
for exc_type in (
331+
BaseException, Exception,
332+
BaseExceptionGroup, ExceptionGroup,
333+
):
334+
with self.subTest(exc_type):
335+
self.assertIs(self.eg, self.eg.subgroup(exc_type))
336+
337+
wrapped = self.exc.subgroup(ValueError)
338+
self.assertEqual(wrapped.message, str(self.exc))
339+
self.assertMatchesTemplate(wrapped, ExceptionGroup, [self.exc])
334340

335341
def test_basics_subgroup_by_type__no_match(self):
336342
self.assertIsNone(self.eg.subgroup(OSError))
343+
self.assertIsNone(self.exc.subgroup(OSError))
337344

338345
def test_basics_subgroup_by_type__match(self):
339346
eg = self.eg
@@ -349,15 +356,24 @@ def test_basics_subgroup_by_type__match(self):
349356
self.assertEqual(subeg.message, eg.message)
350357
self.assertMatchesTemplate(subeg, ExceptionGroup, template)
351358

359+
wrapped = self.exc.subgroup(ValueError)
360+
self.assertEqual(wrapped.message, str(self.exc))
361+
self.assertMatchesTemplate(wrapped, ExceptionGroup, [self.exc])
362+
352363
def test_basics_subgroup_by_predicate__passthrough(self):
353364
f = lambda e: True
354365
for callable in [f, Predicate(f), Predicate(f).method]:
355366
self.assertIs(self.eg, self.eg.subgroup(callable))
356367

368+
wrapped = self.exc.subgroup(callable)
369+
self.assertEqual(wrapped.message, str(self.exc))
370+
self.assertMatchesTemplate(wrapped, ExceptionGroup, [self.exc])
371+
357372
def test_basics_subgroup_by_predicate__no_match(self):
358373
f = lambda e: False
359374
for callable in [f, Predicate(f), Predicate(f).method]:
360375
self.assertIsNone(self.eg.subgroup(callable))
376+
self.assertIsNone(self.exc.subgroup(callable))
361377

362378
def test_basics_subgroup_by_predicate__match(self):
363379
eg = self.eg
@@ -371,40 +387,61 @@ def test_basics_subgroup_by_predicate__match(self):
371387
f = lambda e: isinstance(e, match_type)
372388
for callable in [f, Predicate(f), Predicate(f).method]:
373389
with self.subTest(callable=callable):
374-
subeg = eg.subgroup(f)
390+
subeg = eg.subgroup(callable)
375391
self.assertEqual(subeg.message, eg.message)
376392
self.assertMatchesTemplate(subeg, ExceptionGroup, template)
377393

394+
f = lambda e: isinstance(e, ValueError)
395+
for callable in [f, Predicate(f), Predicate(f).method]:
396+
group = self.exc.subgroup(callable)
397+
self.assertEqual(group.message, str(self.exc))
398+
self.assertMatchesTemplate(group, ExceptionGroup, [self.exc])
399+
378400

379401
class ExceptionGroupSplitTests(ExceptionGroupTestBase):
380402
def setUp(self):
381403
self.eg = create_simple_eg()
382404
self.eg_template = [ValueError(1), TypeError(int), ValueError(2)]
383405

406+
self.exc = ValueError('a')
407+
384408
def test_basics_split_by_type__passthrough(self):
385-
for E in [BaseException, Exception,
386-
BaseExceptionGroup, ExceptionGroup]:
387-
match, rest = self.eg.split(E)
388-
self.assertMatchesTemplate(
389-
match, ExceptionGroup, self.eg_template)
390-
self.assertIsNone(rest)
409+
for exc_type in (
410+
BaseException, Exception,
411+
BaseExceptionGroup, ExceptionGroup,
412+
):
413+
with self.subTest(exc_type):
414+
match, rest = self.eg.split(exc_type)
415+
self.assertMatchesTemplate(match, ExceptionGroup,
416+
self.eg_template)
417+
self.assertIsNone(rest)
418+
419+
match, rest = self.exc.split(exc_type)
420+
self.assertEqual(match.message, str(self.exc))
421+
self.assertMatchesTemplate(match, ExceptionGroup, [self.exc])
422+
self.assertIsNone(rest)
391423

392424
def test_basics_split_by_type__no_match(self):
393425
match, rest = self.eg.split(OSError)
394426
self.assertIsNone(match)
395-
self.assertMatchesTemplate(
396-
rest, ExceptionGroup, self.eg_template)
427+
self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template)
428+
429+
match, rest = self.exc.split(OSError)
430+
self.assertIsNone(match)
431+
self.assertMatchesTemplate(rest, ExceptionGroup, [self.exc])
397432

398433
def test_basics_split_by_type__match(self):
399434
eg = self.eg
400-
VE = ValueError
401-
TE = TypeError
402435
testcases = [
403-
# (matcher, match_template, rest_template)
404-
(VE, [VE(1), VE(2)], [TE(int)]),
405-
(TE, [TE(int)], [VE(1), VE(2)]),
406-
((VE, TE), self.eg_template, None),
407-
((OSError, VE), [VE(1), VE(2)], [TE(int)]),
436+
# (exc_or_eg, matcher, match_template, rest_template)
437+
(ValueError, [ValueError(1), ValueError(2)], [TypeError(int)]),
438+
(TypeError, [TypeError(int)], [ValueError(1), ValueError(2)]),
439+
((ValueError, TypeError), self.eg_template, None),
440+
(
441+
(OSError, ValueError),
442+
[ValueError(1), ValueError(2)],
443+
[TypeError(int)],
444+
),
408445
]
409446

410447
for match_type, match_template, rest_template in testcases:
@@ -419,29 +456,41 @@ def test_basics_split_by_type__match(self):
419456
else:
420457
self.assertIsNone(rest)
421458

459+
match, rest = self.exc.split(ValueError)
460+
self.assertEqual(match.message, str(self.exc))
461+
self.assertMatchesTemplate(match, ExceptionGroup, [self.exc])
462+
self.assertIsNone(rest)
463+
422464
def test_basics_split_by_predicate__passthrough(self):
423465
f = lambda e: True
424466
for callable in [f, Predicate(f), Predicate(f).method]:
425467
match, rest = self.eg.split(callable)
426468
self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template)
427469
self.assertIsNone(rest)
428470

471+
match, rest = self.exc.split(callable)
472+
self.assertEqual(match.message, str(self.exc))
473+
self.assertMatchesTemplate(match, ExceptionGroup, [self.exc])
474+
self.assertIsNone(rest)
475+
429476
def test_basics_split_by_predicate__no_match(self):
430477
f = lambda e: False
431478
for callable in [f, Predicate(f), Predicate(f).method]:
432479
match, rest = self.eg.split(callable)
433480
self.assertIsNone(match)
434481
self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template)
435482

483+
match, rest = self.exc.split(callable)
484+
self.assertIsNone(match)
485+
self.assertMatchesTemplate(rest, ExceptionGroup, [self.exc])
486+
436487
def test_basics_split_by_predicate__match(self):
437488
eg = self.eg
438-
VE = ValueError
439-
TE = TypeError
440489
testcases = [
441490
# (matcher, match_template, rest_template)
442-
(VE, [VE(1), VE(2)], [TE(int)]),
443-
(TE, [TE(int)], [VE(1), VE(2)]),
444-
((VE, TE), self.eg_template, None),
491+
(ValueError, [ValueError(1), ValueError(2)], [TypeError(int)]),
492+
(TypeError, [TypeError(int)], [ValueError(1), ValueError(2)]),
493+
((ValueError, TypeError), self.eg_template, None),
445494
]
446495

447496
for match_type, match_template, rest_template in testcases:
@@ -456,6 +505,13 @@ def test_basics_split_by_predicate__match(self):
456505
self.assertMatchesTemplate(
457506
rest, ExceptionGroup, rest_template)
458507

508+
f = lambda e: isinstance(e, ValueError)
509+
for callable in [f, Predicate(f), Predicate(f).method]:
510+
match, rest = self.exc.split(callable)
511+
self.assertEqual(match.message, str(self.exc))
512+
self.assertMatchesTemplate(match, ExceptionGroup, [self.exc])
513+
self.assertIsNone(rest)
514+
459515

460516
class DeepRecursionInSplitAndSubgroup(unittest.TestCase):
461517
def make_deep_eg(self):

Objects/exceptions.c

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,18 +269,77 @@ BaseException_add_note(PyObject *self, PyObject *note)
269269
Py_RETURN_NONE;
270270
}
271271

272+
/*
273+
* Return an exception group wrapping 'self'.
274+
*
275+
* If 'self' is an exception group, a strong
276+
* reference to 'self' is returned instead.
277+
*/
278+
static PyObject *
279+
base_exception_as_group(PyObject *self)
280+
{
281+
if (!PyExceptionInstance_Check(self)) {
282+
PyErr_BadInternalCall();
283+
return NULL;
284+
}
285+
286+
if (_PyBaseExceptionGroup_Check(self)) {
287+
return Py_NewRef(self);
288+
}
289+
290+
PyObject *message = PyObject_Str(self);
291+
if (message == NULL) {
292+
return NULL;
293+
}
294+
PyObject *wrapped = PyTuple_New(1);
295+
if (wrapped == NULL) {
296+
Py_DECREF(message);
297+
return NULL;
298+
}
299+
PyTuple_SET_ITEM(wrapped, 0, Py_NewRef(self));
300+
PyObject *group = PyObject_CallFunctionObjArgs(PyExc_BaseExceptionGroup,
301+
message, wrapped, NULL);
302+
Py_DECREF(wrapped);
303+
Py_DECREF(message);
304+
return group;
305+
}
306+
307+
static PyObject *
308+
BaseException_subgroup(PyObject *self, PyObject *matcher)
309+
{
310+
PyObject *group = base_exception_as_group(self);
311+
if (group == NULL) {
312+
return NULL;
313+
}
314+
PyObject *res = PyObject_CallMethodOneArg(group, &_Py_ID(subgroup), matcher);
315+
Py_DECREF(group);
316+
return res;
317+
}
318+
319+
static PyObject *
320+
BaseException_split(PyObject *self, PyObject *matcher)
321+
{
322+
PyObject *group = base_exception_as_group(self);
323+
if (group == NULL) {
324+
return NULL;
325+
}
326+
PyObject *res = PyObject_CallMethodOneArg(group, &_Py_ID(split), matcher);
327+
Py_DECREF(group);
328+
return res;
329+
}
330+
272331
PyDoc_STRVAR(add_note_doc,
273332
"Exception.add_note(note) --\n\
274333
add a note to the exception");
275334

276335
static PyMethodDef BaseException_methods[] = {
277-
{"__reduce__", (PyCFunction)BaseException_reduce, METH_NOARGS },
278-
{"__setstate__", (PyCFunction)BaseException_setstate, METH_O },
279-
{"with_traceback", (PyCFunction)BaseException_with_traceback, METH_O,
280-
with_traceback_doc},
281-
{"add_note", (PyCFunction)BaseException_add_note, METH_O,
282-
add_note_doc},
283-
{NULL, NULL, 0, NULL},
336+
{"__reduce__", (PyCFunction)BaseException_reduce, METH_NOARGS, NULL},
337+
{"__setstate__", BaseException_setstate, METH_O, NULL},
338+
{"with_traceback", BaseException_with_traceback, METH_O, with_traceback_doc},
339+
{"add_note", BaseException_add_note, METH_O, add_note_doc},
340+
{"split", BaseException_split, METH_O, NULL},
341+
{"subgroup", BaseException_subgroup, METH_O, NULL},
342+
{NULL, NULL, 0, NULL},
284343
};
285344

286345
static PyObject *

0 commit comments

Comments
 (0)