Skip to content

Commit 6366fe7

Browse files
committed
Skip unneeded copy
1 parent 72f91cc commit 6366fe7

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

array_api_compat/common/_helpers.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def is_writeable_array(x):
812812
return False
813813
return True
814814

815-
def _parse_copy_param(x, copy: bool | None) -> bool:
815+
def _parse_copy_param(x, copy: bool | None | Literal["_force_false"]) -> bool:
816816
"""Preprocess and validate a copy parameter, in line with the same
817817
parameter in np.asarray(), np.astype(), etc.
818818
"""
@@ -821,6 +821,8 @@ def _parse_copy_param(x, copy: bool | None) -> bool:
821821
if copy is False:
822822
if not is_writeable_array(x):
823823
raise ValueError("Cannot avoid modifying parameter in place")
824+
elif copy == "_force_false":
825+
return False
824826
elif copy is not True:
825827
raise ValueError(f"Invalid value for copy: {copy!r}")
826828
return copy
@@ -909,7 +911,7 @@ def _common(
909911
self,
910912
at_op: str,
911913
y=_undef,
912-
copy: bool | None = True,
914+
copy: bool | None | Literal["_force_false"] = True,
913915
**kwargs,
914916
):
915917
"""Validate kwargs and perform common prepocessing.
@@ -956,13 +958,13 @@ def get(self, copy: bool | None = True, **kwargs):
956958
or (is_numpy_array(self.idx) and self.idx.dtype.kind in "biu")
957959
)
958960
):
959-
if copy is True:
960-
copy = None
961-
elif copy is False:
961+
if copy is False:
962962
raise ValueError(
963963
"Indexing a numpy array with a fancy index always "
964964
"results in a copy"
965965
)
966+
# Skip copy inside _common, even if array is not writeable
967+
copy = "_force_false" # type: ignore
966968

967969
res, x = self._common("get", copy=copy, **kwargs)
968970
if res is not None:

0 commit comments

Comments
 (0)