Skip to content

Commit 4187b27

Browse files
committed
Add tests for arithmetic operators
1 parent 5a29ffa commit 4187b27

File tree

2 files changed

+111
-6
lines changed

2 files changed

+111
-6
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,35 +159,35 @@ def __add__(self, other: int | float | Array, /) -> Array:
159159
"""
160160
return _process_c_function(self, other, backend.get().af_add)
161161

162-
def __sub__(self, other: int | float | bool | complex | Array, /) -> Array:
162+
def __sub__(self, other: int | float | Array, /) -> Array:
163163
"""
164164
Return self - other.
165165
"""
166166
return _process_c_function(self, other, backend.get().af_sub)
167167

168-
def __mul__(self, other: int | float | bool | complex | Array, /) -> Array:
168+
def __mul__(self, other: int | float | Array, /) -> Array:
169169
"""
170170
Return self * other.
171171
"""
172172
return _process_c_function(self, other, backend.get().af_mul)
173173

174-
def __truediv__(self, other: int | float | bool | complex | Array, /) -> Array:
174+
def __truediv__(self, other: int | float | Array, /) -> Array:
175175
"""
176176
Return self / other.
177177
"""
178178
return _process_c_function(self, other, backend.get().af_div)
179179

180-
def __floordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
180+
def __floordiv__(self, other: int | float | Array, /) -> Array:
181181
# TODO
182182
return NotImplemented
183183

184-
def __mod__(self, other: int | float | bool | complex | Array, /) -> Array:
184+
def __mod__(self, other: int | float | Array, /) -> Array:
185185
"""
186186
Return self % other.
187187
"""
188188
return _process_c_function(self, other, backend.get().af_mod)
189189

190-
def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
190+
def __pow__(self, other: int | float | Array, /) -> Array:
191191
"""
192192
Return self ** other.
193193
"""

arrayfire/array_api/tests/test_array_object.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,108 @@ def test_array_sum() -> None:
111111
assert res[0].scalar() == 2
112112
assert res[1].scalar() == 3
113113
assert res[2].scalar() == 4
114+
115+
res = array + 1.5
116+
assert res[0].scalar() == 2.5
117+
assert res[1].scalar() == 3.5
118+
assert res[2].scalar() == 4.5
119+
120+
res = array + Array([9, 9, 9])
121+
assert res[0].scalar() == 10
122+
assert res[1].scalar() == 11
123+
assert res[2].scalar() == 12
124+
125+
126+
def test_array_sub() -> None:
127+
array = Array([1, 2, 3])
128+
res = array - 1
129+
assert res[0].scalar() == 0
130+
assert res[1].scalar() == 1
131+
assert res[2].scalar() == 2
132+
133+
res = array - 1.5
134+
assert res[0].scalar() == -0.5
135+
assert res[1].scalar() == 0.5
136+
assert res[2].scalar() == 1.5
137+
138+
res = array - Array([9, 9, 9])
139+
assert res[0].scalar() == -8
140+
assert res[1].scalar() == -7
141+
assert res[2].scalar() == -6
142+
143+
144+
def test_array_mul() -> None:
145+
array = Array([1, 2, 3])
146+
res = array * 2
147+
assert res[0].scalar() == 2
148+
assert res[1].scalar() == 4
149+
assert res[2].scalar() == 6
150+
151+
res = array * 1.5
152+
assert res[0].scalar() == 1.5
153+
assert res[1].scalar() == 3
154+
assert res[2].scalar() == 4.5
155+
156+
res = array * Array([9, 9, 9])
157+
assert res[0].scalar() == 9
158+
assert res[1].scalar() == 18
159+
assert res[2].scalar() == 27
160+
161+
162+
def test_array_truediv() -> None:
163+
array = Array([1, 2, 3])
164+
res = array / 2
165+
assert res[0].scalar() == 0.5
166+
assert res[1].scalar() == 1
167+
assert res[2].scalar() == 1.5
168+
169+
res = array / 1.5
170+
assert round(res[0].scalar(), 5) == 0.66667 # type: ignore[arg-type]
171+
assert round(res[1].scalar(), 5) == 1.33333 # type: ignore[arg-type]
172+
assert res[2].scalar() == 2
173+
174+
res = array / Array([2, 2, 2])
175+
assert res[0].scalar() == 0.5
176+
assert res[1].scalar() == 1
177+
assert res[2].scalar() == 1.5
178+
179+
180+
def test_array_floordiv() -> None:
181+
# TODO add test after implementation of __floordiv__
182+
pass
183+
184+
185+
def test_array_mod() -> None:
186+
array = Array([1, 2, 3])
187+
res = array % 2
188+
assert res[0].scalar() == 1
189+
assert res[1].scalar() == 0
190+
assert res[2].scalar() == 1
191+
192+
res = array % 1.5
193+
assert res[0].scalar() == 1.0
194+
assert res[1].scalar() == 0.5
195+
assert res[2].scalar() == 0.0
196+
197+
res = array % Array([9, 9, 9])
198+
assert res[0].scalar() == 1.0
199+
assert res[1].scalar() == 2.0
200+
assert res[2].scalar() == 3.0
201+
202+
203+
def test_array_pow() -> None:
204+
array = Array([1, 2, 3])
205+
res = array ** 2
206+
assert res[0].scalar() == 1
207+
assert res[1].scalar() == 4
208+
assert res[2].scalar() == 9
209+
210+
res = array ** 1.5
211+
assert res[0].scalar() == 1
212+
assert round(res[1].scalar(), 5) == 2.82843 # type: ignore[arg-type]
213+
assert round(res[2].scalar(), 5) == 5.19615 # type: ignore[arg-type]
214+
215+
res = array ** Array([9, 9, 9])
216+
assert res[0].scalar() == 1
217+
assert res[1].scalar() == 512
218+
assert res[2].scalar() == 19683

0 commit comments

Comments
 (0)