Skip to content

Commit 5172a52

Browse files
committed
Modify tests for tri Op
1 parent 21a158d commit 5172a52

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

pytensor/tensor/basic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
3535
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
3636
from pytensor.raise_op import CheckAndRaise
37-
from pytensor.scalar import int32
37+
from pytensor.scalar import int32, upcast
3838
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
3939
from pytensor.tensor import (
4040
_as_tensor_variable,
@@ -1186,8 +1186,9 @@ def tri(N, M=None, k=0, dtype=None):
11861186
dtype = config.floatX
11871187
if M is None:
11881188
M = N
1189+
11891190
output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype(
1190-
int
1191+
dtype
11911192
)
11921193
N = as_tensor_variable(N)
11931194
return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N)
@@ -1244,7 +1245,9 @@ def tril(m, k=0):
12441245
[55, 56, 57, 58, 0]]])
12451246
12461247
"""
1247-
return m * tri(*m.shape[-2:], k=k, dtype=m.dtype)
1248+
N, M = m.shape[-2:]
1249+
dtype = upcast(m.dtype)
1250+
return m * tri(N, M=M, k=k, dtype=dtype) # M is symbolic, while it shouldnt be
12481251

12491252

12501253
def triu(m, k=0):

tests/tensor/test_basic.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -991,16 +991,19 @@ def check(dtype, N, M_=None, k=0):
991991
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
992992
M = N
993993
N_symb = iscalar()
994-
M_symb = iscalar()
995-
k_symb = iscalar()
996-
f = function(
997-
[N_symb, M_symb, k_symb], tri(N_symb, M_symb, k_symb, dtype=dtype)
998-
)
999-
result = f(N, M, k)
994+
f = function([N_symb], tri(N_symb, M=M, k=k, dtype=dtype))
995+
# kwargs = {}
996+
result = f(N)
1000997
assert np.allclose(result, np.tri(N, M_, k, dtype=dtype))
1001998
assert result.dtype == np.dtype(dtype)
1002999

1003-
for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
1000+
for dtype in [
1001+
"int32",
1002+
"int64",
1003+
"float32",
1004+
"float64",
1005+
"uint16",
1006+
]: # Handle "complex64" ?
10041007
check(dtype, 3)
10051008
# M != N, k = 0
10061009
check(dtype, 3, 5)
@@ -1022,15 +1025,15 @@ def test_tril_triu(self):
10221025

10231026
def check_l(m, k=0):
10241027
m_symb = matrix(dtype=m.dtype)
1025-
k_symb = iscalar()
1026-
f = function([m_symb, k_symb], tril(m_symb, k_symb))
1028+
# k_symb = iscalar()
1029+
f = function([m_symb], tril(m_symb, k=k))
10271030
f_indx = function(
1028-
[m_symb, k_symb], tril_indices(m_symb.shape[0], k_symb, m_symb.shape[1])
1031+
[m_symb], tril_indices(m_symb.shape[0], k=k, m=m_symb.shape[1])
10291032
)
1030-
f_indx_from = function([m_symb, k_symb], tril_indices_from(m_symb, k_symb))
1031-
result = f(m, k)
1032-
result_indx = f_indx(m, k)
1033-
result_from = f_indx_from(m, k)
1033+
f_indx_from = function([m_symb], tril_indices_from(m_symb))
1034+
result = f(m)
1035+
result_indx = f_indx(m, k=k)
1036+
result_from = f_indx_from(m, k=k)
10341037
assert np.allclose(result, np.tril(m, k))
10351038
assert np.allclose(result_indx, np.tril_indices(m.shape[0], k, m.shape[1]))
10361039
assert np.allclose(result_from, np.tril_indices_from(m, k))
@@ -1040,7 +1043,7 @@ def check_l(m, k=0):
10401043
def check_u(m, k=0):
10411044
m_symb = matrix(dtype=m.dtype)
10421045
k_symb = iscalar()
1043-
f = function([m_symb, k_symb], triu(m_symb, k_symb))
1046+
f = function([m_symb, k_symb], triu(m_symb, k=k))
10441047
f_indx = function(
10451048
[m_symb, k_symb], triu_indices(m_symb.shape[0], k_symb, m_symb.shape[1])
10461049
)

0 commit comments

Comments
 (0)