Skip to content

Commit 187449f

Browse files
committed
Add test_init.py
1 parent 4eec115 commit 187449f

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import operator
2+
import os
3+
import shutil
4+
import subprocess
5+
import sys
6+
import tempfile
7+
import unittest
8+
from unittest import mock
9+
10+
import numpy
11+
import pytest
12+
13+
import dpnp as cupy
14+
from tests.third_party.cupy import testing
15+
16+
17+
def _run_script(code):
18+
# subprocess is required not to interfere with cupy module imported in top
19+
# of this file
20+
temp_dir = tempfile.mkdtemp()
21+
try:
22+
script_path = os.path.join(temp_dir, "script.py")
23+
with open(script_path, "w") as f:
24+
f.write(code)
25+
proc = subprocess.Popen(
26+
[sys.executable, script_path],
27+
stdout=subprocess.PIPE,
28+
stderr=subprocess.PIPE,
29+
)
30+
stdoutdata, stderrdata = proc.communicate()
31+
finally:
32+
shutil.rmtree(temp_dir, ignore_errors=True)
33+
return proc.returncode, stdoutdata, stderrdata
34+
35+
36+
def _test_cupy_available(self):
37+
returncode, stdoutdata, stderrdata = _run_script(
38+
"""
39+
import dpnp as cupy
40+
print(cupy.is_available())"""
41+
)
42+
assert returncode == 0, "stderr: {!r}".format(stderrdata)
43+
assert stdoutdata in (b"True\n", b"True\r\n", b"False\n", b"False\r\n")
44+
return stdoutdata == b"True\n" or stdoutdata == b"True\r\n"
45+
46+
47+
class TestImportError(unittest.TestCase):
48+
49+
def test_import_error(self):
50+
returncode, stdoutdata, stderrdata = _run_script(
51+
"""
52+
try:
53+
import dpnp as cupy
54+
except Exception as e:
55+
print(type(e).__name__)
56+
"""
57+
)
58+
assert returncode == 0, "stderr: {!r}".format(stderrdata)
59+
assert stdoutdata in (b"", b"RuntimeError\n")
60+
61+
62+
# if not cupy.cuda.runtime.is_hip:
63+
# visible = "CUDA_VISIBLE_DEVICES"
64+
# else:
65+
# visible = "HIP_VISIBLE_DEVICES"
66+
67+
68+
@pytest.mark.skip("dpnp.is_available() is not implemented")
69+
class TestAvailable(unittest.TestCase):
70+
71+
def test_available(self):
72+
available = _test_cupy_available(self)
73+
assert available
74+
75+
76+
@pytest.mark.skip("dpnp.is_available() is not implemented")
77+
class TestNotAvailable(unittest.TestCase):
78+
79+
def setUp(self):
80+
self.old = os.environ.get(visible)
81+
82+
def tearDown(self):
83+
if self.old is None:
84+
os.environ.pop(visible)
85+
else:
86+
os.environ[visible] = self.old
87+
88+
# @unittest.skipIf(
89+
# cupy.cuda.runtime.is_hip,
90+
# "HIP handles empty HIP_VISIBLE_DEVICES differently",
91+
# )
92+
def test_no_device_1(self):
93+
os.environ["CUDA_VISIBLE_DEVICES"] = " "
94+
available = _test_cupy_available(self)
95+
assert not available
96+
97+
def test_no_device_2(self):
98+
os.environ[visible] = "-1"
99+
available = _test_cupy_available(self)
100+
assert not available
101+
102+
103+
@pytest.mark.skip("No memory pool API is supported")
104+
class TestMemoryPool(unittest.TestCase):
105+
106+
def test_get_default_memory_pool(self):
107+
p = cupy.get_default_memory_pool()
108+
assert isinstance(p, cupy.cuda.memory.MemoryPool)
109+
110+
def test_get_default_pinned_memory_pool(self):
111+
p = cupy.get_default_pinned_memory_pool()
112+
assert isinstance(p, cupy.cuda.pinned_memory.PinnedMemoryPool)
113+
114+
115+
@pytest.mark.skip("dpnp.show_config() is not implemented")
116+
class TestShowConfig(unittest.TestCase):
117+
118+
def test_show_config(self):
119+
with mock.patch("sys.stdout.write") as write_func:
120+
cupy.show_config()
121+
write_func.assert_called_once_with(
122+
str(cupyx.get_runtime_info(full=False))
123+
)
124+
125+
def test_show_config_with_handles(self):
126+
with mock.patch("sys.stdout.write") as write_func:
127+
cupy.show_config(_full=True)
128+
write_func.assert_called_once_with(
129+
str(cupyx.get_runtime_info(full=True))
130+
)
131+
132+
133+
class TestAliases(unittest.TestCase):
134+
135+
def test_abs_is_absolute(self):
136+
for xp in (numpy, cupy):
137+
assert xp.abs is xp.absolute
138+
139+
def test_conj_is_conjugate(self):
140+
for xp in (numpy, cupy):
141+
assert xp.conj is xp.conjugate
142+
143+
def test_bitwise_not_is_invert(self):
144+
for xp in (numpy, cupy):
145+
assert xp.bitwise_not is xp.invert
146+
147+
148+
@testing.with_requires("numpy>=2.0")
149+
@pytest.mark.parametrize(
150+
"name",
151+
[
152+
"exceptions.AxisError",
153+
"exceptions.ComplexWarning",
154+
"exceptions.ModuleDeprecationWarning",
155+
"exceptions.RankWarning",
156+
"exceptions.TooHardError",
157+
"exceptions.VisibleDeprecationWarning",
158+
"linalg.LinAlgError",
159+
],
160+
)
161+
def test_error_classes(name):
162+
get = operator.attrgetter(name)
163+
assert issubclass(get(cupy), get(numpy))
164+
165+
166+
# This is copied from chainer/testing/__init__.py, so should be replaced in
167+
# some way.
168+
if __name__ == "__main__":
169+
import pytest
170+
171+
pytest.main([__file__, "-vvs", "-x", "--pdb"])

0 commit comments

Comments
 (0)