Skip to content

Commit e1bf1d6

Browse files
ENH: Add partition/rpartition ufunc for string dtypes (numpy#26082)
* ENH: Add partition/rpartition ufunc for string dtypes Closes numpy#25993. * Fix doctests * Fix docstrings in ufunc_docstrings.py as well * Return array with the separators // optimize using find ufunc results * Address feedback * Fix chararray __array_finalize__ * ENH: add stringdtype partition/rpartition * BUG: remove unnecessary size_t cast * BUG: fix error handling and resource cleanup * MNT: refactor so stringdtype can combine find and partition * MNT: update signatures to reflect const API changes * MNT: simplfy fastsearch call * MNT: move variable binding out of inner loop * Fix error message about out; fix promoter * Remove unused import in defchararray; add assertion * BUG: don't use a user-provided descriptor to initialize a new stringdtype view * MNT: back out attempted fix for stringdtype view problem * MNT: address code review comments --------- Co-authored-by: Nathan Goldbaum <[email protected]>
1 parent 6d4a0f7 commit e1bf1d6

File tree

9 files changed

+850
-49
lines changed

9 files changed

+850
-49
lines changed

numpy/_core/code_generators/generate_umath.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,26 @@ def english_upper(s):
13001300
docstrings.get('numpy._core.umath._zfill'),
13011301
None,
13021302
),
1303+
'_partition_index':
1304+
Ufunc(3, 3, None,
1305+
docstrings.get('numpy._core.umath._partition_index'),
1306+
None,
1307+
),
1308+
'_rpartition_index':
1309+
Ufunc(3, 3, None,
1310+
docstrings.get('numpy._core.umath._rpartition_index'),
1311+
None,
1312+
),
1313+
'_partition':
1314+
Ufunc(2, 3, None,
1315+
docstrings.get('numpy._core.umath._partition'),
1316+
None,
1317+
),
1318+
'_rpartition':
1319+
Ufunc(2, 3, None,
1320+
docstrings.get('numpy._core.umath._rpartition'),
1321+
None,
1322+
),
13031323
}
13041324

13051325
def indent(st, spaces):

numpy/_core/code_generators/ufunc_docstrings.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5028,3 +5028,184 @@ def add_newdoc(place, name, doc):
50285028
array(['001', '-01', '+01'], dtype='<U3')
50295029
50305030
""")
5031+
5032+
add_newdoc('numpy._core.umath', '_partition_index',
5033+
"""
5034+
Partition each element in ``x1`` around ``x2``, at precomputed
5035+
index ``x3``.
5036+
5037+
For each element in ``x1``, split the element at the first
5038+
occurrence of ``x2`` at location ``x3``, and return a 3-tuple
5039+
containing the part before the separator, the separator itself,
5040+
and the part after the separator. If the separator is not found,
5041+
the first item of the tuple will contain the whole string, and
5042+
the second and third ones will be the empty string.
5043+
5044+
Parameters
5045+
----------
5046+
x1 : array-like, with ``bytes_``, or ``str_`` dtype
5047+
Input array
5048+
x2 : array-like, with ``bytes_``, or ``str_`` dtype
5049+
Separator to split each string element in ``x1``.
5050+
x3 : array-like, with any integer dtype
5051+
The indices of the separator (<0 to indicate the separator is not
5052+
present).
5053+
5054+
Returns
5055+
-------
5056+
out : 3-tuple:
5057+
- array with ``bytes_`` or ``str_`` dtype with the part before the
5058+
separator
5059+
- array with ``bytes_`` or ``str_`` dtype with the separator
5060+
- array with ``bytes_`` or ``str_`` dtype with the part after the
5061+
separator
5062+
5063+
See Also
5064+
--------
5065+
str.partition
5066+
5067+
Examples
5068+
--------
5069+
The ufunc is used most easily via ``np.strings.partition``,
5070+
which calls it after calculating the indices::
5071+
5072+
>>> x = np.array(["Numpy is nice!"])
5073+
>>> np.strings.partition(x, " ")
5074+
(array(['Numpy'], dtype='<U5'),
5075+
array([' '], dtype='<U1'),
5076+
array(['is nice!'], dtype='<U8'))
5077+
5078+
""")
5079+
5080+
add_newdoc('numpy._core.umath', '_rpartition_index',
5081+
"""
5082+
Partition each element in ``x1`` around the right-most separator,
5083+
``x2``, at precomputed index ``x3``.
5084+
5085+
For each element in ``x1``, split the element at the last
5086+
occurrence of ``x2`` at location ``x3``, and return a 3-tuple
5087+
containing the part before the separator, the separator itself,
5088+
and the part after the separator. If the separator is not found,
5089+
the third item of the tuple will contain the whole string, and
5090+
the first and second ones will be the empty string.
5091+
5092+
Parameters
5093+
----------
5094+
x1 : array-like, with ``bytes_``, or ``str_`` dtype
5095+
Input array
5096+
x2 : array-like, with ``bytes_``, or ``str_`` dtype
5097+
Separator to split each string element in ``x1``.
5098+
x3 : array-like, with any integer dtype
5099+
The indices of the separator (<0 to indicate the separator is not
5100+
present).
5101+
5102+
Returns
5103+
-------
5104+
out : 3-tuple:
5105+
- array with ``bytes_`` or ``str_`` dtype with the part before the
5106+
separator
5107+
- array with ``bytes_`` or ``str_`` dtype with the separator
5108+
- array with ``bytes_`` or ``str_`` dtype with the part after the
5109+
separator
5110+
5111+
See Also
5112+
--------
5113+
str.rpartition
5114+
5115+
Examples
5116+
--------
5117+
The ufunc is used most easily via ``np.strings.rpartition``,
5118+
which calls it after calculating the indices::
5119+
5120+
>>> a = np.array(['aAaAaA', ' aA ', 'abBABba'])
5121+
>>> np.strings.rpartition(a, 'A')
5122+
(array(['aAaAa', ' a', 'abB'], dtype='<U5'),
5123+
array(['A', 'A', 'A'], dtype='<U1'),
5124+
array(['', ' ', 'Bba'], dtype='<U3'))
5125+
5126+
""")
5127+
5128+
add_newdoc('numpy._core.umath', '_partition',
5129+
"""
5130+
Partition each element in ``x1`` around ``x2``.
5131+
5132+
For each element in ``x1``, split the element at the first
5133+
occurrence of ``x2`` and return a 3-tuple containing the part before
5134+
the separator, the separator itself, and the part after the
5135+
separator. If the separator is not found, the first item of the
5136+
tuple will contain the whole string, and the second and third ones
5137+
will be the empty string.
5138+
5139+
Parameters
5140+
----------
5141+
x1 : array-like, with ``StringDType`` dtype
5142+
Input array
5143+
x2 : array-like, with ``StringDType`` dtype
5144+
Separator to split each string element in ``x1``.
5145+
5146+
Returns
5147+
-------
5148+
out : 3-tuple:
5149+
- ``StringDType`` array with the part before the separator
5150+
- ``StringDType`` array with the separator
5151+
- ``StringDType`` array with the part after the separator
5152+
5153+
See Also
5154+
--------
5155+
str.partition
5156+
5157+
Examples
5158+
--------
5159+
The ufunc is used most easily via ``np.strings.partition``,
5160+
which calls it under the hood::
5161+
5162+
>>> x = np.array(["Numpy is nice!"], dtype="T")
5163+
>>> np.strings.partition(x, " ")
5164+
(array(['Numpy'], dtype=StringDType()),
5165+
array([' '], dtype=StringDType()),
5166+
array(['is nice!'], dtype=StringDType()))
5167+
5168+
""")
5169+
5170+
add_newdoc('numpy._core.umath', '_rpartition',
5171+
"""
5172+
Partition each element in ``x1`` around the right-most separator,
5173+
``x2``.
5174+
5175+
For each element in ``x1``, split the element at the last
5176+
occurrence of ``x2`` at location ``x3``, and return a 3-tuple
5177+
containing the part before the separator, the separator itself,
5178+
and the part after the separator. If the separator is not found,
5179+
the third item of the tuple will contain the whole string, and
5180+
the first and second ones will be the empty string.
5181+
5182+
Parameters
5183+
----------
5184+
x1 : array-like, with ``StringDType`` dtype
5185+
Input array
5186+
x2 : array-like, with ``StringDType`` dtype
5187+
Separator to split each string element in ``x1``.
5188+
5189+
Returns
5190+
-------
5191+
out : 3-tuple:
5192+
- ``StringDType`` array with the part before the separator
5193+
- ``StringDType`` array with the separator
5194+
- ``StringDType`` array with the part after the separator
5195+
5196+
See Also
5197+
--------
5198+
str.rpartition
5199+
5200+
Examples
5201+
--------
5202+
The ufunc is used most easily via ``np.strings.rpartition``,
5203+
which calls it after calculating the indices::
5204+
5205+
>>> a = np.array(['aAaAaA', ' aA ', 'abBABba'], dtype="T")
5206+
>>> np.strings.rpartition(a, 'A')
5207+
(array(['aAaAa', ' a', 'abB'], dtype=StringDType()),
5208+
array(['A', 'A', 'A'], dtype=StringDType()),
5209+
array(['', ' ', 'Bba'], dtype=StringDType()))
5210+
5211+
""")

numpy/_core/defchararray.py

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,19 @@
1717
"""
1818
import functools
1919

20+
import numpy as np
2021
from .._utils import set_module
2122
from .numerictypes import bytes_, str_, character
2223
from .numeric import ndarray, array as narray, asarray as asnarray
2324
from numpy._core.multiarray import compare_chararrays
2425
from numpy._core import overrides
2526
from numpy.strings import *
26-
from numpy.strings import multiply as strings_multiply
27+
from numpy.strings import (
28+
multiply as strings_multiply,
29+
partition as strings_partition,
30+
rpartition as strings_rpartition,
31+
)
2732
from numpy._core.strings import (
28-
_partition as partition,
29-
_rpartition as rpartition,
3033
_split as split,
3134
_rsplit as rsplit,
3235
_splitlines as splitlines,
@@ -303,6 +306,88 @@ def multiply(a, i):
303306
raise ValueError("Can only multiply by integers")
304307

305308

309+
def partition(a, sep):
310+
"""
311+
Partition each element in `a` around `sep`.
312+
313+
Calls :meth:`str.partition` element-wise.
314+
315+
For each element in `a`, split the element as the first
316+
occurrence of `sep`, and return 3 strings containing the part
317+
before the separator, the separator itself, and the part after
318+
the separator. If the separator is not found, return 3 strings
319+
containing the string itself, followed by two empty strings.
320+
321+
Parameters
322+
----------
323+
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
324+
Input array
325+
sep : {str, unicode}
326+
Separator to split each string element in `a`.
327+
328+
Returns
329+
-------
330+
out : ndarray
331+
Output array of ``StringDType``, ``bytes_`` or ``str_`` dtype,
332+
depending on input types. The output array will have an extra
333+
dimension with 3 elements per input element.
334+
335+
Examples
336+
--------
337+
>>> x = np.array(["Numpy is nice!"])
338+
>>> np.char.partition(x, " ")
339+
array([['Numpy', ' ', 'is nice!']], dtype='<U8')
340+
341+
See Also
342+
--------
343+
str.partition
344+
345+
"""
346+
return np.stack(strings_partition(a, sep), axis=-1)
347+
348+
349+
def rpartition(a, sep):
350+
"""
351+
Partition (split) each element around the right-most separator.
352+
353+
Calls :meth:`str.rpartition` element-wise.
354+
355+
For each element in `a`, split the element as the last
356+
occurrence of `sep`, and return 3 strings containing the part
357+
before the separator, the separator itself, and the part after
358+
the separator. If the separator is not found, return 3 strings
359+
containing the string itself, followed by two empty strings.
360+
361+
Parameters
362+
----------
363+
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
364+
Input array
365+
sep : str or unicode
366+
Right-most separator to split each element in array.
367+
368+
Returns
369+
-------
370+
out : ndarray
371+
Output array of ``StringDType``, ``bytes_`` or ``str_`` dtype,
372+
depending on input types. The output array will have an extra
373+
dimension with 3 elements per input element.
374+
375+
See Also
376+
--------
377+
str.rpartition
378+
379+
Examples
380+
--------
381+
>>> a = np.array(['aAaAaA', ' aA ', 'abBABba'])
382+
>>> np.char.rpartition(a, 'A')
383+
array([['aAaAa', 'A', ''],
384+
[' a', 'A', ' '],
385+
['abB', 'A', 'Bba']], dtype='<U5')
386+
387+
"""
388+
return np.stack(strings_rpartition(a, sep), axis=-1)
389+
390+
306391
@set_module("numpy.char")
307392
class chararray(ndarray):
308393
"""
@@ -487,7 +572,7 @@ def __array_wrap__(self, arr, context=None, return_scalar=False):
487572

488573
def __array_finalize__(self, obj):
489574
# The b is a special case because it is used for reconstructing.
490-
if self.dtype.char not in 'SUbc':
575+
if self.dtype.char not in 'VSUbc':
491576
raise ValueError("Can only create a chararray from string data.")
492577

493578
def __getitem__(self, obj):

numpy/_core/src/umath/string_buffer.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,4 +1593,46 @@ string_zfill(Buffer<enc> buf, npy_int64 width, Buffer<enc> out)
15931593
}
15941594

15951595

1596+
template <ENCODING enc>
1597+
static inline void
1598+
string_partition(Buffer<enc> buf1, Buffer<enc> buf2, npy_int64 idx,
1599+
Buffer<enc> out1, Buffer<enc> out2, Buffer<enc> out3,
1600+
npy_intp *final_len1, npy_intp *final_len2, npy_intp *final_len3,
1601+
STARTPOSITION pos)
1602+
{
1603+
// StringDType uses a ufunc that implements the find-part as well
1604+
assert(enc != ENCODING::UTF8);
1605+
1606+
size_t len1 = buf1.num_codepoints();
1607+
size_t len2 = buf2.num_codepoints();
1608+
1609+
if (len2 == 0) {
1610+
npy_gil_error(PyExc_ValueError, "empty separator");
1611+
*final_len1 = *final_len2 = *final_len3 = -1;
1612+
return;
1613+
}
1614+
1615+
if (idx < 0) {
1616+
if (pos == STARTPOSITION::FRONT) {
1617+
buf1.buffer_memcpy(out1, len1);
1618+
*final_len1 = len1;
1619+
*final_len2 = *final_len3 = 0;
1620+
}
1621+
else {
1622+
buf1.buffer_memcpy(out3, len1);
1623+
*final_len1 = *final_len2 = 0;
1624+
*final_len3 = len1;
1625+
}
1626+
return;
1627+
}
1628+
1629+
buf1.buffer_memcpy(out1, idx);
1630+
*final_len1 = idx;
1631+
buf2.buffer_memcpy(out2, len2);
1632+
*final_len2 = len2;
1633+
(buf1 + idx + len2).buffer_memcpy(out3, len1 - idx - len2);
1634+
*final_len3 = len1 - idx - len2;
1635+
}
1636+
1637+
15961638
#endif /* _NPY_CORE_SRC_UMATH_STRING_BUFFER_H_ */

0 commit comments

Comments
 (0)