Skip to content

Commit d28d774

Browse files
committed
Modify Tri class, revert tests
1 parent 5172a52 commit d28d774

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

pytensor/tensor/basic.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
3535
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
3636
from pytensor.raise_op import CheckAndRaise
37-
from pytensor.scalar import int32, upcast
37+
from pytensor.scalar import int32
3838
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
3939
from pytensor.tensor import (
4040
_as_tensor_variable,
@@ -1148,14 +1148,6 @@ class Tri(OpFromGraph):
11481148
Wrapper Op for np.tri graphs
11491149
"""
11501150

1151-
__props__ = ("dtype",)
1152-
1153-
def __init__(self, *args, M, k, dtype, **kwargs):
1154-
self.M = M
1155-
self.k = k
1156-
self.dtype = dtype
1157-
super().__init__(*args, **kwargs, strict=True)
1158-
11591151

11601152
def tri(N, M=None, k=0, dtype=None):
11611153
"""
@@ -1184,14 +1176,19 @@ def tri(N, M=None, k=0, dtype=None):
11841176
"""
11851177
if dtype is None:
11861178
dtype = config.floatX
1179+
dtype = np.dtype(dtype)
1180+
11871181
if M is None:
11881182
M = N
11891183

1184+
N = as_tensor_variable(N)
1185+
M = as_tensor_variable(M)
1186+
k = as_tensor_variable(k)
1187+
11901188
output = ((iota(as_tensor((N, 1)), 0) + k + 1) > iota(as_tensor((1, M)), 1)).astype(
11911189
dtype
11921190
)
1193-
N = as_tensor_variable(N)
1194-
return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N)
1191+
return Tri(inputs=[N, M, k], outputs=[output])(N, M, k)
11951192

11961193

11971194
def tril(m, k=0):
@@ -1245,9 +1242,7 @@ def tril(m, k=0):
12451242
[55, 56, 57, 58, 0]]])
12461243
12471244
"""
1248-
N, M = m.shape[-2:]
1249-
dtype = upcast(m.dtype)
1250-
return m * tri(N, M=M, k=k, dtype=dtype) # M is symbolic, while it shouldnt be
1245+
return m * tri(*m.shape[-2:], k=k, dtype=m.dtype)
12511246

12521247

12531248
def triu(m, k=0):

tests/tensor/test_basic.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -988,12 +988,15 @@ def check(dtype, N, M_=None, k=0):
988988
M = M_
989989
# Currently DebugMode does not support None as inputs even if this is
990990
# allowed.
991-
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
991+
if M is None: # and config.mode in ["DebugMode", "DEBUG_MODE"]:
992992
M = N
993993
N_symb = iscalar()
994-
f = function([N_symb], tri(N_symb, M=M, k=k, dtype=dtype))
995-
# kwargs = {}
996-
result = f(N)
994+
M_symb = iscalar()
995+
k_symb = iscalar()
996+
f = function(
997+
[N_symb, M_symb, k_symb], tri(N_symb, M_symb, k_symb, dtype=dtype)
998+
)
999+
result = f(N, M, k)
9971000
assert np.allclose(result, np.tri(N, M_, k, dtype=dtype))
9981001
assert result.dtype == np.dtype(dtype)
9991002

@@ -1025,15 +1028,15 @@ def test_tril_triu(self):
10251028

10261029
def check_l(m, k=0):
10271030
m_symb = matrix(dtype=m.dtype)
1028-
# k_symb = iscalar()
1029-
f = function([m_symb], tril(m_symb, k=k))
1031+
k_symb = iscalar()
1032+
f = function([m_symb, k_symb], tril(m_symb, k_symb))
10301033
f_indx = function(
1031-
[m_symb], tril_indices(m_symb.shape[0], k=k, m=m_symb.shape[1])
1034+
[m_symb, k_symb], tril_indices(m_symb.shape[0], k_symb, m_symb.shape[1])
10321035
)
1033-
f_indx_from = function([m_symb], tril_indices_from(m_symb))
1034-
result = f(m)
1035-
result_indx = f_indx(m, k=k)
1036-
result_from = f_indx_from(m, k=k)
1036+
f_indx_from = function([m_symb, k_symb], tril_indices_from(m_symb, k_symb))
1037+
result = f(m, k)
1038+
result_indx = f_indx(m, k)
1039+
result_from = f_indx_from(m, k)
10371040
assert np.allclose(result, np.tril(m, k))
10381041
assert np.allclose(result_indx, np.tril_indices(m.shape[0], k, m.shape[1]))
10391042
assert np.allclose(result_from, np.tril_indices_from(m, k))
@@ -1043,7 +1046,7 @@ def check_l(m, k=0):
10431046
def check_u(m, k=0):
10441047
m_symb = matrix(dtype=m.dtype)
10451048
k_symb = iscalar()
1046-
f = function([m_symb, k_symb], triu(m_symb, k=k))
1049+
f = function([m_symb, k_symb], triu(m_symb, k_symb))
10471050
f_indx = function(
10481051
[m_symb, k_symb], triu_indices(m_symb.shape[0], k_symb, m_symb.shape[1])
10491052
)
@@ -1075,7 +1078,7 @@ def check_u_batch(m):
10751078
assert np.allclose(result, np.triu(m, k))
10761079
assert result.dtype == np.dtype(dtype)
10771080

1078-
for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
1081+
for dtype in ["int32", "int64", "float32", "float64", "uint16"]:
10791082
m = random_of_dtype((10, 10), dtype)
10801083
check_l(m, 0)
10811084
check_l(m, 1)

0 commit comments

Comments
 (0)