@@ -384,6 +384,184 @@ def test_reverse_ops(self):
384
384
assert 1 @ tester == 123 , "__rmatmul__ failed"
385
385
386
386
387
def test_str_subclass(self):
    """Check that a natively allocated unicode object behaves like a ``str``.

    A C type ``TestStrSubclass`` with ``tp_base = &PyUnicode_Type`` is built
    via ``CPyExtType`` (a project helper; presumably it compiles the given C
    snippets into an extension type -- confirm against its definition).  Its
    ``tp_new`` fills a compact ASCII unicode object by hand, using a copy of
    CPython's ``PyUnicode_New`` layout code, and returns it.  The assertions
    below then exercise comparison, conversion and common ``str`` methods on
    that object.

    NOTE(review): ``nstr_tpnew`` ignores its ``type`` argument and
    initialises the object with ``&PyUnicode_Type``, so the ``marker`` member
    declared in ``cmembers`` is never allocated or touched.
    """
    TestStrSubclass = CPyExtType(
        "TestStrSubclass",
        r"""
        static PyTypeObject* testStrSubclassPtr = NULL;

        #define MAX_UNICODE 0x10ffff

        #define _PyUnicode_UTF8(op)                             \
            (((PyCompactUnicodeObject*)(op))->utf8)
        #define PyUnicode_UTF8(op)                              \
            (assert(_PyUnicode_CHECK(op)),                      \
             assert(PyUnicode_IS_READY(op)),                    \
             PyUnicode_IS_COMPACT_ASCII(op) ?                   \
                 ((char*)((PyASCIIObject*)(op) + 1)) :          \
                 _PyUnicode_UTF8(op))
        #define _PyUnicode_UTF8_LENGTH(op)                      \
            (((PyCompactUnicodeObject*)(op))->utf8_length)
        #define PyUnicode_UTF8_LENGTH(op)                       \
            (assert(_PyUnicode_CHECK(op)),                      \
             assert(PyUnicode_IS_READY(op)),                    \
             PyUnicode_IS_COMPACT_ASCII(op) ?                   \
                 ((PyASCIIObject*)(op))->length :               \
                 _PyUnicode_UTF8_LENGTH(op))
        #define _PyUnicode_WSTR(op)                             \
            (((PyASCIIObject*)(op))->wstr)
        #define _PyUnicode_WSTR_LENGTH(op)                      \
            (((PyCompactUnicodeObject*)(op))->wstr_length)
        #define _PyUnicode_LENGTH(op)                           \
            (((PyASCIIObject *)(op))->length)
        #define _PyUnicode_STATE(op)                            \
            (((PyASCIIObject *)(op))->state)
        #define _PyUnicode_HASH(op)                             \
            (((PyASCIIObject *)(op))->hash)
        #define _PyUnicode_KIND(op)                             \
            (assert(_PyUnicode_CHECK(op)),                      \
             ((PyASCIIObject *)(op))->state.kind)
        #define _PyUnicode_GET_LENGTH(op)                       \
            (assert(_PyUnicode_CHECK(op)),                      \
             ((PyASCIIObject *)(op))->length)
        #define _PyUnicode_DATA_ANY(op)                         \
            (((PyUnicodeObject*)(op))->data.any)

        // that's taken from CPython's 'PyUnicode_New'
        static PyUnicodeObject * new_empty_unicode(Py_ssize_t size, Py_UCS4 maxchar) {
            PyUnicodeObject *obj;
            PyCompactUnicodeObject *unicode;
            void *data;
            enum PyUnicode_Kind kind;
            int is_sharing, is_ascii;
            Py_ssize_t char_size;
            Py_ssize_t struct_size;

            is_ascii = 0;
            is_sharing = 0;
            struct_size = sizeof(PyCompactUnicodeObject);
            if (maxchar < 128) {
                kind = PyUnicode_1BYTE_KIND;
                char_size = 1;
                is_ascii = 1;
                struct_size = sizeof(PyASCIIObject);
            }
            else if (maxchar < 256) {
                kind = PyUnicode_1BYTE_KIND;
                char_size = 1;
            }
            else if (maxchar < 65536) {
                kind = PyUnicode_2BYTE_KIND;
                char_size = 2;
                if (sizeof(wchar_t) == 2)
                    is_sharing = 1;
            }
            else {
                if (maxchar > MAX_UNICODE) {
                    PyErr_SetString(PyExc_SystemError,
                                    "invalid maximum character passed to PyUnicode_New");
                    return NULL;
                }
                kind = PyUnicode_4BYTE_KIND;
                char_size = 4;
                if (sizeof(wchar_t) == 4)
                    is_sharing = 1;
            }

            /* Ensure we won't overflow the size. */
            if (size < 0) {
                PyErr_SetString(PyExc_SystemError,
                                "Negative size passed to PyUnicode_New");
                return NULL;
            }
            /* Every NULL return must carry an exception, otherwise the caller
             * reports "returned NULL without setting an error". */
            if (size > ((PY_SSIZE_T_MAX - struct_size) / char_size - 1))
                return (PyUnicodeObject *) PyErr_NoMemory();

            /* Duplicated allocation code from _PyObject_New() instead of a call to
             * PyObject_New() so we are able to allocate space for the object and
             * its data buffer.
             * NOTE(review): plain malloc() is used here rather than the object
             * allocator; presumably deliberate for this test -- confirm.
             */
            obj = (PyUnicodeObject *) malloc(struct_size + (size + 1) * char_size);
            if (obj == NULL)
                return (PyUnicodeObject *) PyErr_NoMemory();
            obj = (PyUnicodeObject *) PyObject_INIT(obj, &PyUnicode_Type);
            if (obj == NULL)
                return NULL;

            unicode = (PyCompactUnicodeObject *)obj;
            if (is_ascii)
                data = ((PyASCIIObject*)obj) + 1;
            else
                data = unicode + 1;
            _PyUnicode_LENGTH(unicode) = size;
            _PyUnicode_HASH(unicode) = -1;
            _PyUnicode_STATE(unicode).interned = 0;
            _PyUnicode_STATE(unicode).kind = kind;
            _PyUnicode_STATE(unicode).compact = 1;
            _PyUnicode_STATE(unicode).ready = 1;
            _PyUnicode_STATE(unicode).ascii = is_ascii;
            if (is_ascii) {
                ((char*)data)[size] = 0;
                _PyUnicode_WSTR(unicode) = NULL;
            }
            else if (kind == PyUnicode_1BYTE_KIND) {
                ((char*)data)[size] = 0;
                _PyUnicode_WSTR(unicode) = NULL;
                _PyUnicode_WSTR_LENGTH(unicode) = 0;
                unicode->utf8 = NULL;
                unicode->utf8_length = 0;
            }
            else {
                unicode->utf8 = NULL;
                unicode->utf8_length = 0;
                if (kind == PyUnicode_2BYTE_KIND)
                    ((Py_UCS2*)data)[size] = 0;
                else /* kind == PyUnicode_4BYTE_KIND */
                    ((Py_UCS4*)data)[size] = 0;
                if (is_sharing) {
                    _PyUnicode_WSTR_LENGTH(unicode) = size;
                    _PyUnicode_WSTR(unicode) = (wchar_t *)data;
                }
                else {
                    _PyUnicode_WSTR_LENGTH(unicode) = 0;
                    _PyUnicode_WSTR(unicode) = NULL;
                }
            }
            return obj;
        }

        static PyObject* nstr_tpnew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
            char *ascii_data = NULL;
            Py_ssize_t len;
            PyUnicodeObject* strObj;

            /* 'args' is only read during this call and the bytes are copied
             * below, so no extra reference is taken (an unmatched INCREF
             * would leak the tuple). */
            if (!PyArg_ParseTuple(args, "s", &ascii_data)) {
                return NULL;
            }
            /* The result is flagged as ASCII (maxchar 127); callers must pass
             * ASCII-only text. */
            len = (Py_ssize_t) strlen(ascii_data);
            strObj = new_empty_unicode(len, (Py_UCS4) 127);
            if (strObj == NULL) {
                /* exception already set by new_empty_unicode */
                return NULL;
            }
            memcpy(PyUnicode_1BYTE_DATA(strObj), (Py_UCS1*)ascii_data, len);
            return (PyObject*) strObj;
        }
        """,
        cmembers="""PyUnicodeObject base;
                       int marker;""",
        tp_base="&PyUnicode_Type",
        tp_new="nstr_tpnew",
        post_ready_code="testStrSubclassPtr = &TestStrSubclassType; Py_INCREF(testStrSubclassPtr);"
    )
    tester = TestStrSubclass("hello\n world")
    assert tester == "hello\n world"
    assert str(tester) == "hello\n world"
    # The second line keeps the space that follows the newline in the input.
    assert tester.splitlines() == ['hello', ' world']
    assert tester >= "hello"
    assert not (tester >= "helloasdfasdfasdf")
    assert tester <= "helloasdfasdfasdf"
    assert not (tester <= "hello")
    assert tester.startswith("hello")
    assert tester.endswith("rld")
    assert tester.join(["a", "b"]) == "ahello\n worldb"
    assert tester.upper() == "HELLO\n WORLD"
    assert tester.replace("o", "uff") == "helluff\n wuffrld"
    assert tester.replace("o", "uff", 1) == "helluff\n world"
564
+
387
565
class TestObjectFunctions (CPyExtTestCase ):
388
566
def compile_module (self , name ):
389
567
type (self ).mro ()[1 ].__dict__ ["test_%s" % name ].create_module (name )
0 commit comments