Skip to content

Commit 62a15be

Browse files
committed
Add _PyBytesWriter_CreateByteArray()
Convert _PyBytes_FromHex() from the private _PyBytesWriter API to the PyBytesWriter API.
1 parent bf60f7f commit 62a15be

File tree

4 files changed

+93
-41
lines changed

4 files changed

+93
-41
lines changed

Include/cpython/bytesobject.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,8 @@ typedef struct PyBytesWriter PyBytesWriter;
4848

4949
PyAPI_FUNC(PyBytesWriter *) PyBytesWriter_Create(
5050
Py_ssize_t size);
51+
PyAPI_FUNC(PyBytesWriter*) _PyBytesWriter_CreateByteArray(
52+
Py_ssize_t size);
5153
PyAPI_FUNC(void) PyBytesWriter_Discard(
5254
PyBytesWriter *writer);
5355
PyAPI_FUNC(PyObject*) PyBytesWriter_Finish(

Lib/test/test_capi/test_bytes.py

Lines changed: 21 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -291,82 +291,89 @@ def test_join(self):
291291
bytes_join(b'', NULL)
292292

293293

294-
class PyBytesWriterTest(unittest.TestCase):
294+
class BytesWriterTest(unittest.TestCase):
295295
SMALL_BUFFER = 256 # bytes
296+
result_type = bytes
296297

297298
def create_writer(self, alloc=0, string=b''):
298-
return _testcapi.PyBytesWriter(alloc, string)
299+
return _testcapi.PyBytesWriter(alloc, string, 0)
299300

300301
def test_create(self):
301302
# Test PyBytesWriter_Create()
302303
writer = self.create_writer()
303304
self.assertEqual(writer.get_size(), 0)
304305
self.assertEqual(writer.get_allocated(), self.SMALL_BUFFER)
305-
self.assertEqual(writer.finish(), b'')
306+
self.assertEqual(writer.finish(), self.result_type(b''))
306307

307308
writer = self.create_writer(3, b'abc')
308309
self.assertEqual(writer.get_size(), 3)
309310
self.assertEqual(writer.get_allocated(), self.SMALL_BUFFER)
310-
self.assertEqual(writer.finish(), b'abc')
311+
self.assertEqual(writer.finish(), self.result_type(b'abc'))
311312

312313
writer = self.create_writer(10, b'abc')
313314
self.assertEqual(writer.get_size(), 10)
314315
self.assertEqual(writer.get_allocated(), self.SMALL_BUFFER)
315-
self.assertEqual(writer.finish_with_size(3), b'abc')
316+
self.assertEqual(writer.finish_with_size(3), self.result_type(b'abc'))
316317

317318
def test_write_bytes(self):
318319
# Test PyBytesWriter_WriteBytes()
319320
writer = self.create_writer()
320321
writer.write_bytes(b'Hello World!', -1)
321-
self.assertEqual(writer.finish(), b'Hello World!')
322+
self.assertEqual(writer.finish(), self.result_type(b'Hello World!'))
322323

323324
writer = self.create_writer()
324325
writer.write_bytes(b'Hello ', -1)
325326
writer.write_bytes(b'World! <truncated>', 6)
326-
self.assertEqual(writer.finish(), b'Hello World!')
327+
self.assertEqual(writer.finish(), self.result_type(b'Hello World!'))
327328

328329
def test_resize(self):
329330
# Test PyBytesWriter_Resize()
330331
writer = self.create_writer()
331332
writer.resize(len(b'number=123456'), b'number=123456')
332333
writer.resize(len(b'number=123456'), b'')
333334
self.assertEqual(writer.get_size(), len(b'number=123456'))
334-
self.assertEqual(writer.finish(), b'number=123456')
335+
self.assertEqual(writer.finish(), self.result_type(b'number=123456'))
335336

336337
writer = self.create_writer()
337338
writer.resize(0, b'')
338339
writer.resize(len(b'number=123456'), b'number=123456')
339-
self.assertEqual(writer.finish(), b'number=123456')
340+
self.assertEqual(writer.finish(), self.result_type(b'number=123456'))
340341

341342
writer = self.create_writer()
342343
writer.resize(len(b'number='), b'number=')
343344
writer.resize(len(b'number=123456'), b'123456')
344-
self.assertEqual(writer.finish(), b'number=123456')
345+
self.assertEqual(writer.finish(), self.result_type(b'number=123456'))
345346

346347
writer = self.create_writer()
347348
writer.resize(len(b'number='), b'number=')
348349
writer.resize(len(b'number='), b'')
349350
writer.resize(len(b'number=123456'), b'123456')
350-
self.assertEqual(writer.finish(), b'number=123456')
351+
self.assertEqual(writer.finish(), self.result_type(b'number=123456'))
351352

352353
writer = self.create_writer()
353354
writer.resize(len(b'number'), b'number')
354355
writer.resize(len(b'number='), b'=')
355356
writer.resize(len(b'number=123'), b'123')
356357
writer.resize(len(b'number=123456'), b'456')
357-
self.assertEqual(writer.finish(), b'number=123456')
358+
self.assertEqual(writer.finish(), self.result_type(b'number=123456'))
358359

359360
def test_format_i(self):
360361
# Test PyBytesWriter_Format()
361362
writer = self.create_writer()
362363
writer.format_i(b'x=%i', 123456)
363-
self.assertEqual(writer.finish(), b'x=123456')
364+
self.assertEqual(writer.finish(), self.result_type(b'x=123456'))
364365

365366
writer = self.create_writer()
366367
writer.format_i(b'x=%i, ', 123)
367368
writer.format_i(b'y=%i', 456)
368-
self.assertEqual(writer.finish(), b'x=123, y=456')
369+
self.assertEqual(writer.finish(), self.result_type(b'x=123, y=456'))
369370

370371

372+
class ByteArrayWriterTest(BytesWriterTest):
373+
result_type = bytearray
374+
375+
def create_writer(self, alloc=0, string=b''):
376+
return _testcapi.PyBytesWriter(alloc, string, 1)
377+
371378
if __name__ == "__main__":
372379
unittest.main()

Modules/_testcapi/bytes.c

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -88,11 +88,18 @@ writer_init(PyObject *self_raw, PyObject *args, PyObject *kwargs)
8888
Py_ssize_t alloc;
8989
char *str;
9090
Py_ssize_t str_size;
91-
if (!PyArg_ParseTuple(args, "ny#", &alloc, &str, &str_size)) {
91+
int use_bytearray;
92+
if (!PyArg_ParseTuple(args, "ny#i",
93+
&alloc, &str, &str_size, &use_bytearray)) {
9294
return -1;
9395
}
9496

95-
self->writer = PyBytesWriter_Create(alloc);
97+
if (use_bytearray) {
98+
self->writer = _PyBytesWriter_CreateByteArray(alloc);
99+
}
100+
else {
101+
self->writer = PyBytesWriter_Create(alloc);
102+
}
96103
if (self->writer == NULL) {
97104
return -1;
98105
}

Objects/bytesobject.c

Lines changed: 61 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -2515,17 +2515,13 @@ bytes_fromhex_impl(PyTypeObject *type, PyObject *string)
25152515
PyObject*
25162516
_PyBytes_FromHex(PyObject *string, int use_bytearray)
25172517
{
2518-
char *buf;
25192518
Py_ssize_t hexlen, invalid_char;
25202519
unsigned int top, bot;
25212520
const Py_UCS1 *str, *start, *end;
2522-
_PyBytesWriter writer;
2521+
PyBytesWriter *writer = NULL;
25232522
Py_buffer view;
25242523
view.obj = NULL;
25252524

2526-
_PyBytesWriter_Init(&writer);
2527-
writer.use_bytearray = use_bytearray;
2528-
25292525
if (PyUnicode_Check(string)) {
25302526
hexlen = PyUnicode_GET_LENGTH(string);
25312527

@@ -2561,10 +2557,11 @@ _PyBytes_FromHex(PyObject *string, int use_bytearray)
25612557
}
25622558

25632559
/* This overestimates if there are spaces */
2564-
buf = _PyBytesWriter_Alloc(&writer, hexlen / 2);
2565-
if (buf == NULL) {
2560+
writer = _PyBytesWriter_CreateByteArray(hexlen / 2);
2561+
if (writer == NULL) {
25662562
goto release_buffer;
25672563
}
2564+
char *buf = PyBytesWriter_GetData(writer);
25682565

25692566
start = str;
25702567
end = str + hexlen;
@@ -2603,7 +2600,7 @@ _PyBytes_FromHex(PyObject *string, int use_bytearray)
26032600
if (view.obj != NULL) {
26042601
PyBuffer_Release(&view);
26052602
}
2606-
return _PyBytesWriter_Finish(&writer, buf);
2603+
return PyBytesWriter_FinishWithEndPointer(writer, buf);
26072604

26082605
error:
26092606
if (invalid_char == -1) {
@@ -2614,7 +2611,7 @@ _PyBytes_FromHex(PyObject *string, int use_bytearray)
26142611
"non-hexadecimal number found in "
26152612
"fromhex() arg at position %zd", invalid_char);
26162613
}
2617-
_PyBytesWriter_Dealloc(&writer);
2614+
PyBytesWriter_Discard(writer);
26182615

26192616
release_buffer:
26202617
if (view.obj != NULL) {
@@ -3737,6 +3734,7 @@ struct PyBytesWriter {
37373734
char small_buffer[256];
37383735
PyObject *obj;
37393736
Py_ssize_t size;
3737+
int use_bytearray;
37403738
};
37413739

37423740

@@ -3758,6 +3756,9 @@ byteswriter_allocated(PyBytesWriter *writer)
37583756
if (writer->obj == NULL) {
37593757
return sizeof(writer->small_buffer);
37603758
}
3759+
else if (writer->use_bytearray) {
3760+
return PyByteArray_GET_SIZE(writer->obj);
3761+
}
37613762
else {
37623763
return PyBytes_GET_SIZE(writer->obj);
37633764
}
@@ -3778,15 +3779,8 @@ byteswriter_resize(PyBytesWriter *writer, Py_ssize_t size, int overallocate)
37783779
{
37793780
assert(size >= 0);
37803781

3781-
if (writer->obj == NULL) {
3782-
if ((size_t)size <= sizeof(writer->small_buffer)) {
3783-
return 0;
3784-
}
3785-
}
3786-
else {
3787-
if (size <= PyBytes_GET_SIZE(writer->obj)) {
3788-
return 0;
3789-
}
3782+
if (size <= byteswriter_allocated(writer)) {
3783+
return 0;
37903784
}
37913785

37923786
if (overallocate) {
@@ -3796,11 +3790,28 @@ byteswriter_resize(PyBytesWriter *writer, Py_ssize_t size, int overallocate)
37963790
}
37973791

37983792
if (writer->obj != NULL) {
3799-
if (_PyBytes_Resize(&writer->obj, size)) {
3800-
return -1;
3793+
if (writer->use_bytearray) {
3794+
if (PyByteArray_Resize(writer->obj, size)) {
3795+
return -1;
3796+
}
3797+
}
3798+
else {
3799+
if (_PyBytes_Resize(&writer->obj, size)) {
3800+
return -1;
3801+
}
38013802
}
38023803
assert(writer->obj != NULL);
38033804
}
3805+
else if (writer->use_bytearray) {
3806+
writer->obj = PyByteArray_FromStringAndSize(NULL, size);
3807+
if (writer->obj == NULL) {
3808+
return -1;
3809+
}
3810+
assert((size_t)size > sizeof(writer->small_buffer));
3811+
memcpy(PyByteArray_AS_STRING(writer->obj),
3812+
writer->small_buffer,
3813+
sizeof(writer->small_buffer));
3814+
}
38043815
else {
38053816
writer->obj = PyBytes_FromStringAndSize(NULL, size);
38063817
if (writer->obj == NULL) {
@@ -3815,8 +3826,8 @@ byteswriter_resize(PyBytesWriter *writer, Py_ssize_t size, int overallocate)
38153826
}
38163827

38173828

3818-
PyBytesWriter*
3819-
PyBytesWriter_Create(Py_ssize_t size)
3829+
static PyBytesWriter*
3830+
byteswriter_create(Py_ssize_t size, int use_bytearray)
38203831
{
38213832
if (size < 0) {
38223833
PyErr_SetString(PyExc_ValueError, "size must be >= 0");
@@ -3833,6 +3844,7 @@ PyBytesWriter_Create(Py_ssize_t size)
38333844
}
38343845
writer->obj = NULL;
38353846
writer->size = 0;
3847+
writer->use_bytearray = use_bytearray;
38363848

38373849
if (size >= 1) {
38383850
if (byteswriter_resize(writer, size, 0) < 0) {
@@ -3844,6 +3856,18 @@ PyBytesWriter_Create(Py_ssize_t size)
38443856
return writer;
38453857
}
38463858

3859+
PyBytesWriter*
3860+
PyBytesWriter_Create(Py_ssize_t size)
3861+
{
3862+
return byteswriter_create(size, 0);
3863+
}
3864+
3865+
PyBytesWriter*
3866+
_PyBytesWriter_CreateByteArray(Py_ssize_t size)
3867+
{
3868+
return byteswriter_create(size, 1);
3869+
}
3870+
38473871

38483872
void
38493873
PyBytesWriter_Discard(PyBytesWriter *writer)
@@ -3865,14 +3889,26 @@ PyBytesWriter_FinishWithSize(PyBytesWriter *writer, Py_ssize_t size)
38653889
result = bytes_get_empty();
38663890
}
38673891
else if (writer->obj != NULL) {
3868-
if (size != PyBytes_GET_SIZE(writer->obj)) {
3869-
if (_PyBytes_Resize(&writer->obj, size)) {
3870-
goto error;
3892+
if (writer->use_bytearray) {
3893+
if (size != PyByteArray_GET_SIZE(writer->obj)) {
3894+
if (PyByteArray_Resize(writer->obj, size)) {
3895+
goto error;
3896+
}
3897+
}
3898+
}
3899+
else {
3900+
if (size != PyBytes_GET_SIZE(writer->obj)) {
3901+
if (_PyBytes_Resize(&writer->obj, size)) {
3902+
goto error;
3903+
}
38713904
}
38723905
}
38733906
result = writer->obj;
38743907
writer->obj = NULL;
38753908
}
3909+
else if (writer->use_bytearray) {
3910+
result = PyByteArray_FromStringAndSize(writer->small_buffer, size);
3911+
}
38763912
else {
38773913
result = PyBytes_FromStringAndSize(writer->small_buffer, size);
38783914
}

0 commit comments

Comments
 (0)