Skip to content

Commit 4f474dd

Browse files
committed
update tests to reflect user errors
1 parent 303b13c commit 4f474dd

File tree

2 files changed

+94
-39
lines changed

2 files changed

+94
-39
lines changed

Lib/test/test_capi/test_codecs.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import codecs
22
import contextlib
33
import io
4+
import re
45
import sys
56
import unittest
67
import unittest.mock as mock
@@ -10,6 +11,7 @@
1011
_testlimitedcapi = import_helper.import_module('_testlimitedcapi')
1112

1213
NULL = None
14+
BAD_ARGUMENT = re.escape('bad argument type for built-in operation')
1315

1416

1517
class CAPIUnicodeTest(unittest.TestCase):
@@ -635,7 +637,8 @@ def test_codec_encode(self):
635637
self.assertEqual(encode('[é]', 'ascii', 'ignore'), b'[]')
636638

637639
self.assertRaises(TypeError, encode, NULL, 'ascii', 'strict')
638-
# CRASHES encode('a', NULL, 'strict')
640+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
641+
encode('a', NULL, 'strict')
639642

640643
def test_codec_decode(self):
641644
decode = _testcapi.codec_decode
@@ -650,46 +653,90 @@ def test_codec_decode(self):
650653
self.assertRaises(UnicodeDecodeError, decode, b, 'ascii', NULL)
651654
self.assertEqual(decode(b, 'ascii', 'replace'), 'a' + '\ufffd'*9)
652655

653-
# _codecs.decode only reports unknown errors policy when they are
654-
# used (it has a fast path for empty bytes); this is different from
655-
# PyUnicode_Decode which checks that both the encoding and the errors
656-
# policy are recognized.
656+
# _codecs.decode() only reports an unknown errors policy when it is
657+
# used; this is different from PyUnicode_Decode(), which checks that
658+
# both the encoding and the errors policy are recognized before even
659+
# attempting to call the decoder.
657660
self.assertEqual(decode(b'', 'utf-8', 'unknown-errors-policy'), '')
661+
self.assertEqual(decode(b'a', 'utf-8', 'unknown-errors-policy'), 'a')
658662

659663
self.assertRaises(TypeError, decode, NULL, 'ascii', 'strict')
660-
# CRASHES decode(b, NULL, 'strict')
664+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
665+
decode(b, NULL, 'strict')
661666

662667
def test_codec_encoder(self):
668+
codec_encoder = _testcapi.codec_encoder
669+
663670
with self.use_custom_encoder():
664-
encoder = _testcapi.codec_encoder(self.encoding_name)
671+
encoder = codec_encoder(self.encoding_name)
665672
self.assertIs(encoder, self.codec_info.encode)
666673

674+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
675+
codec_encoder(NULL)
676+
667677
def test_codec_decoder(self):
678+
codec_decoder = _testcapi.codec_decoder
679+
668680
with self.use_custom_encoder():
669-
decoder = _testcapi.codec_decoder(self.encoding_name)
681+
decoder = codec_decoder(self.encoding_name)
670682
self.assertIs(decoder, self.codec_info.decode)
671683

684+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
685+
codec_decoder(NULL)
686+
672687
def test_codec_incremental_encoder(self):
688+
codec_incremental_encoder = _testcapi.codec_incremental_encoder
689+
673690
with self.use_custom_encoder():
674-
encoder = _testcapi.codec_incremental_encoder(self.encoding_name, 'strict')
675-
self.assertIsInstance(encoder, self.codec_info.incrementalencoder)
691+
encoding = self.encoding_name
692+
693+
for policy in ['strict', NULL]:
694+
with self.subTest(policy=policy):
695+
encoder = codec_incremental_encoder(encoding, policy)
696+
self.assertIsInstance(encoder, self.codec_info.incrementalencoder)
697+
698+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
699+
codec_incremental_encoder(NULL, 'strict')
676700

677701
def test_codec_incremental_decoder(self):
702+
codec_incremental_decoder = _testcapi.codec_incremental_decoder
703+
678704
with self.use_custom_encoder():
679-
decoder = _testcapi.codec_incremental_decoder(self.encoding_name, 'strict')
680-
self.assertIsInstance(decoder, self.codec_info.incrementaldecoder)
705+
encoding = self.encoding_name
706+
707+
for policy in ['strict', NULL]:
708+
with self.subTest(policy=policy):
709+
decoder = codec_incremental_decoder(encoding, policy)
710+
self.assertIsInstance(decoder, self.codec_info.incrementaldecoder)
711+
712+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
713+
codec_incremental_decoder(NULL, 'strict')
681714

682715
def test_codec_stream_reader(self):
716+
codec_stream_reader = _testcapi.codec_stream_reader
717+
683718
with self.use_custom_encoder():
684719
encoding, stream = self.encoding_name, io.StringIO()
685-
reader = _testcapi.codec_stream_reader(encoding, stream, 'strict')
686-
self.assertIsInstance(reader, self.codec_info.streamreader)
720+
for policy in ['strict', NULL]:
721+
with self.subTest(policy=policy):
722+
writer = codec_stream_reader(encoding, stream, policy)
723+
self.assertIsInstance(writer, self.codec_info.streamreader)
724+
725+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
726+
codec_stream_reader(NULL, stream, 'strict')
687727

688728
def test_codec_stream_writer(self):
729+
codec_stream_writer = _testcapi.codec_stream_writer
730+
689731
with self.use_custom_encoder():
690732
encoding, stream = self.encoding_name, io.StringIO()
691-
writer = _testcapi.codec_stream_writer(encoding, stream, 'strict')
692-
self.assertIsInstance(writer, self.codec_info.streamwriter)
733+
for policy in ['strict', NULL]:
734+
with self.subTest(policy=policy):
735+
writer = codec_stream_writer(encoding, stream, policy)
736+
self.assertIsInstance(writer, self.codec_info.streamwriter)
737+
738+
with self.assertRaisesRegex(TypeError, BAD_ARGUMENT):
739+
codec_stream_writer(NULL, stream, 'strict')
693740

694741

695742
class CAPICodecErrors(unittest.TestCase):

Modules/_testcapi/codec.c

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
#include "parts.h"
22

3+
/*
4+
* The Codecs C API assumes that 'encoding' is not NULL; if it is
5+
* NULL, PyErr_BadArgument() is used to set a TypeError exception.
6+
*
7+
* In this file, we allow the functions to be called with None
8+
* standing in for NULL in order to explicitly check this behaviour.
9+
*/
10+
311
// === Codecs registration and un-registration ================================
412

513
static PyObject *
@@ -23,8 +31,8 @@ codec_unregister(PyObject *Py_UNUSED(module), PyObject *search_function)
2331
static PyObject *
2432
codec_known_encoding(PyObject *Py_UNUSED(module), PyObject *args)
2533
{
26-
const char *encoding; // should not be NULL
27-
if (!PyArg_ParseTuple(args, "s", &encoding)) {
34+
const char *encoding; // should not be NULL (see top-file comment)
35+
if (!PyArg_ParseTuple(args, "z", &encoding)) {
2836
return NULL;
2937
}
3038
return PyCodec_KnownEncoding(encoding) ? Py_True : Py_False;
@@ -36,9 +44,9 @@ static PyObject *
3644
codec_encode(PyObject *Py_UNUSED(module), PyObject *args)
3745
{
3846
PyObject *input;
39-
const char *encoding; // should not be NULL
47+
const char *encoding; // should not be NULL (see top-file comment)
4048
const char *errors; // can be NULL
41-
if (!PyArg_ParseTuple(args, "O|sz", &input, &encoding, &errors)) {
49+
if (!PyArg_ParseTuple(args, "O|zz", &input, &encoding, &errors)) {
4250
return NULL;
4351
}
4452
return PyCodec_Encode(input, encoding, errors);
@@ -48,9 +56,9 @@ static PyObject *
4856
codec_decode(PyObject *Py_UNUSED(module), PyObject *args)
4957
{
5058
PyObject *input;
51-
const char *encoding; // should not be NULL
59+
const char *encoding; // should not be NULL (see top-file comment)
5260
const char *errors; // can be NULL
53-
if (!PyArg_ParseTuple(args, "O|sz", &input, &encoding, &errors)) {
61+
if (!PyArg_ParseTuple(args, "O|zz", &input, &encoding, &errors)) {
5462
return NULL;
5563
}
5664
return PyCodec_Decode(input, encoding, errors);
@@ -59,8 +67,8 @@ codec_decode(PyObject *Py_UNUSED(module), PyObject *args)
5967
static PyObject *
6068
codec_encoder(PyObject *Py_UNUSED(module), PyObject *args)
6169
{
62-
const char *encoding; // should not be NULL
63-
if (!PyArg_ParseTuple(args, "s", &encoding)) {
70+
const char *encoding; // should not be NULL (see top-file comment)
71+
if (!PyArg_ParseTuple(args, "z", &encoding)) {
6472
return NULL;
6573
}
6674
return PyCodec_Encoder(encoding);
@@ -69,8 +77,8 @@ codec_encoder(PyObject *Py_UNUSED(module), PyObject *args)
6977
static PyObject *
7078
codec_decoder(PyObject *Py_UNUSED(module), PyObject *args)
7179
{
72-
const char *encoding; // should not be NULL
73-
if (!PyArg_ParseTuple(args, "s", &encoding)) {
80+
const char *encoding; // should not be NULL (see top-file comment)
81+
if (!PyArg_ParseTuple(args, "z", &encoding)) {
7482
return NULL;
7583
}
7684
return PyCodec_Decoder(encoding);
@@ -79,9 +87,9 @@ codec_decoder(PyObject *Py_UNUSED(module), PyObject *args)
7987
static PyObject *
8088
codec_incremental_encoder(PyObject *Py_UNUSED(module), PyObject *args)
8189
{
82-
const char *encoding; // should not be NULL
83-
const char *errors; // should not be NULL
84-
if (!PyArg_ParseTuple(args, "ss", &encoding, &errors)) {
90+
const char *encoding; // should not be NULL (see top-file comment)
91+
const char *errors; // can be NULL
92+
if (!PyArg_ParseTuple(args, "zz", &encoding, &errors)) {
8593
return NULL;
8694
}
8795
return PyCodec_IncrementalEncoder(encoding, errors);
@@ -90,9 +98,9 @@ codec_incremental_encoder(PyObject *Py_UNUSED(module), PyObject *args)
9098
static PyObject *
9199
codec_incremental_decoder(PyObject *Py_UNUSED(module), PyObject *args)
92100
{
93-
const char *encoding; // should not be NULL
94-
const char *errors; // should not be NULL
95-
if (!PyArg_ParseTuple(args, "ss", &encoding, &errors)) {
101+
const char *encoding; // should not be NULL (see top-file comment)
102+
const char *errors; // can be NULL
103+
if (!PyArg_ParseTuple(args, "zz", &encoding, &errors)) {
96104
return NULL;
97105
}
98106
return PyCodec_IncrementalDecoder(encoding, errors);
@@ -101,10 +109,10 @@ codec_incremental_decoder(PyObject *Py_UNUSED(module), PyObject *args)
101109
static PyObject *
102110
codec_stream_reader(PyObject *Py_UNUSED(module), PyObject *args)
103111
{
104-
const char *encoding; // should not be NULL
112+
const char *encoding; // should not be NULL (see top-file comment)
105113
PyObject *stream;
106-
const char *errors; // should not be NULL
107-
if (!PyArg_ParseTuple(args, "sOs", &encoding, &stream, &errors)) {
114+
const char *errors; // can be NULL
115+
if (!PyArg_ParseTuple(args, "zOz", &encoding, &stream, &errors)) {
108116
return NULL;
109117
}
110118
return PyCodec_StreamReader(encoding, stream, errors);
@@ -113,10 +121,10 @@ codec_stream_reader(PyObject *Py_UNUSED(module), PyObject *args)
113121
static PyObject *
114122
codec_stream_writer(PyObject *Py_UNUSED(module), PyObject *args)
115123
{
116-
const char *encoding; // should not be NULL
124+
const char *encoding; // should not be NULL (see top-file comment)
117125
PyObject *stream;
118-
const char *errors; // should not be NULL
119-
if (!PyArg_ParseTuple(args, "sOs", &encoding, &stream, &errors)) {
126+
const char *errors; // can be NULL
127+
if (!PyArg_ParseTuple(args, "zOz", &encoding, &stream, &errors)) {
120128
return NULL;
121129
}
122130
return PyCodec_StreamWriter(encoding, stream, errors);
@@ -127,7 +135,7 @@ codec_stream_writer(PyObject *Py_UNUSED(module), PyObject *args)
127135
static PyObject *
128136
codec_register_error(PyObject *Py_UNUSED(module), PyObject *args)
129137
{
130-
const char *encoding; // should not be NULL
138+
const char *encoding; // must not be NULL
131139
PyObject *error;
132140
if (!PyArg_ParseTuple(args, "sO", &encoding, &error)) {
133141
return NULL;

0 commit comments

Comments
 (0)