Skip to content

Commit a790e5e

Browse files
committed
Remove dtype as a required argument for StringScalar
Since StringDType doesn't take any parameters there's no need to save the dtype along with the scalar instance. We can always create a new StringDType instance on-the-fly.
1 parent 98873db commit a790e5e

File tree

3 files changed

+7
-13
lines changed

3 files changed

+7
-13
lines changed

stringdtype/stringdtype/scalar.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@
22

33

44
class StringScalar(str):
5-
def __new__(cls, value, dtype):
6-
instance = super().__new__(cls, value)
7-
instance.dtype = dtype
8-
return instance
9-
105
def partition(self, sep):
116
ret = super().partition(sep)
127
return (str(ret[0]), str(ret[1]), str(ret[2]))

stringdtype/stringdtype/src/dtype.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
6868
return NULL;
6969
}
7070

71-
PyArray_Descr *ret = (PyArray_Descr *)PyObject_GetAttrString(obj, "dtype");
71+
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance();
7272
if (ret == NULL) {
7373
return NULL;
7474
}
@@ -143,7 +143,7 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
143143
}
144144

145145
PyObject *res = PyObject_CallFunctionObjArgs((PyObject *)StringScalar_Type,
146-
val_obj, descr, NULL);
146+
val_obj, NULL);
147147

148148
if (res == NULL) {
149149
return NULL;

stringdtype/tests/test_stringdtype.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def string_list():
1414

1515

1616
def test_scalar_creation():
17-
assert str(StringScalar("abc", StringDType())) == "abc"
17+
assert str(StringScalar("abc")) == "abc"
1818

1919

2020
def test_dtype_creation():
@@ -42,12 +42,11 @@ def test_array_creation_utf8(data):
4242

4343

4444
def test_array_creation_scalars(string_list):
45-
dtype = StringDType()
4645
arr = np.array(
4746
[
48-
StringScalar("abc", dtype=dtype),
49-
StringScalar("def", dtype=dtype),
50-
StringScalar("ghi", dtype=dtype),
47+
StringScalar("abc"),
48+
StringScalar("def"),
49+
StringScalar("ghi"),
5150
]
5251
)
5352
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
@@ -98,7 +97,7 @@ def test_unicode_casts(string_list):
9897
def test_insert_scalar(string_list):
9998
dtype = StringDType()
10099
arr = np.array(string_list, dtype=dtype)
101-
arr[1] = StringScalar("what", dtype=dtype)
100+
arr[1] = StringScalar("what")
102101
assert repr(arr) == repr(np.array(["abc", "what", "ghi"], dtype=dtype))
103102

104103

0 commit comments

Comments
 (0)