Skip to content

Commit f05d9a0

Browse files
committed
Add numpy-backed ArrayValue tests
1 parent 52db81d commit f05d9a0

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

CondaPkg.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ version = ">=3.10,<4"
1414

1515
[dev.deps]
1616
matplotlib = ""
17+
numpy = ""
1718
pyside6 = ""
1819
python = "<3.14"

src/JlWrap/array.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -355,22 +355,20 @@ class ArrayValue(AnyValue):
355355
@property
356356
def __array_interface__(self):
357357
return self._jl_callmethod($(pyjl_methodnum(pyjlarray_array_interface)))
358-
def __array__(self, dtype=None):
358+
def __array__(self, dtype=None, copy=None):
359+
import numpy
359360
# convert to an array-like object
360361
arr = self
361362
if not (hasattr(arr, "__array_interface__") or hasattr(arr, "__array_struct__")):
363+
if copy is False:
364+
raise ValueError("copy=False is not supported when collecting ArrayValue data")
362365
# the first attempt collects into an Array
363366
arr = self._jl_callmethod($(pyjl_methodnum(pyjlarray_array__array)))
364367
if not (hasattr(arr, "__array_interface__") or hasattr(arr, "__array_struct__")):
365368
# the second attempt collects into a PyObjectArray
366369
arr = self._jl_callmethod($(pyjl_methodnum(pyjlarray_array__pyobjectarray)))
367370
# convert to a numpy array if numpy is available
368-
try:
369-
import numpy
370-
arr = numpy.array(arr, dtype=dtype)
371-
except ImportError:
372-
pass
373-
return arr
371+
return numpy.array(arr, dtype=dtype, copy=copy)
374372
def to_numpy(self, dtype=None, copy=True, order="K"):
375373
import numpy
376374
return numpy.array(self, dtype=dtype, copy=copy, order=order)

test/JlWrap.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,38 @@ end
313313
@test pyjlvalue(x) == [0 2; 3 4]
314314
@test pyjlvalue(y) == [1 2; 3 4]
315315
end
316+
@testset "__array__" begin
317+
np = pyimport("numpy")
318+
319+
numeric = pyjl(Float64[1, 2, 3])
320+
numeric_array = numeric.__array__()
321+
@test pyisinstance(numeric_array, np.ndarray)
322+
@test pyconvert(Vector{Float64}, numeric_array) == [1.0, 2.0, 3.0]
323+
324+
numeric_no_copy = numeric.__array__(copy=false)
325+
numeric_data = pyjlvalue(numeric)
326+
numeric_data[1] = 42.0
327+
@test pyconvert(Vector{Float64}, numeric_no_copy) == [42.0, 2.0, 3.0]
328+
329+
string_array = pyjl(["a", "b"])
330+
string_result = string_array.__array__()
331+
@test pyisinstance(string_result, np.ndarray)
332+
@test pyconvert(Vector{String}, pybuiltins.list(string_result)) == ["a", "b"]
333+
334+
err = try
335+
string_array.__array__(copy=false)
336+
nothing
337+
catch err
338+
err
339+
end
340+
@test err !== nothing
341+
@test err isa PythonCall.PyException
342+
@test pyis(err._t, pybuiltins.ValueError)
343+
@test occursin(
344+
"copy=False is not supported when collecting ArrayValue data",
345+
sprint(showerror, err),
346+
)
347+
end
316348
@testset "array_interface" begin
317349
x = pyjl(Float32[1 2 3; 4 5 6]).__array_interface__
318350
@test pyisinstance(x, pybuiltins.dict)

0 commit comments

Comments
 (0)