Skip to content

Commit a6444c7

Browse files
committed
Modify np.tri Op, wrap around OpFromGraph
1 parent 89d5366 commit a6444c7

File tree

1 file changed

+12
-28
lines changed

1 file changed

+12
-28
lines changed

pytensor/tensor/basic.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
get_vector_length,
4444
)
4545
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
46+
from pytensor.tensor.einsum import _iota
4647
from pytensor.tensor.elemwise import (
4748
DimShuffle,
4849
Elemwise,
@@ -1084,35 +1085,18 @@ def nonzero_values(a):
10841085
return _a.flatten()[flatnonzero(_a)]
10851086

10861087

1087-
class Tri(Op):
1088+
class Tri(OpFromGraph):
1089+
"""
1090+
Wrapper Op for np.tri graphs
1091+
"""
1092+
10881093
__props__ = ("dtype",)
10891094

1090-
def __init__(self, dtype=None):
1091-
if dtype is None:
1092-
dtype = config.floatX
1095+
def __init__(self, *args, M, k, dtype, **kwargs):
1096+
self.M = M
1097+
self.k = k
10931098
self.dtype = dtype
1094-
1095-
def make_node(self, N, M, k):
1096-
N = as_tensor_variable(N)
1097-
M = as_tensor_variable(M)
1098-
k = as_tensor_variable(k)
1099-
return Apply(
1100-
self,
1101-
[N, M, k],
1102-
[TensorType(dtype=self.dtype, shape=(None, None))()],
1103-
)
1104-
1105-
def perform(self, node, inp, out_):
1106-
N, M, k = inp
1107-
(out,) = out_
1108-
out[0] = np.tri(N, M, k, dtype=self.dtype)
1109-
1110-
def infer_shape(self, fgraph, node, in_shapes):
1111-
out_shape = [node.inputs[0], node.inputs[1]]
1112-
return [out_shape]
1113-
1114-
def grad(self, inp, grads):
1115-
return [grad_undefined(self, i, inp[i]) for i in range(3)]
1099+
super().__init__(*args, **kwargs, strict=True)
11161100

11171101

11181102
def tri(N, M=None, k=0, dtype=None):
@@ -1144,8 +1128,8 @@ def tri(N, M=None, k=0, dtype=None):
11441128
dtype = config.floatX
11451129
if M is None:
11461130
M = N
1147-
op = Tri(dtype)
1148-
return op(N, M, k)
1131+
output = ((_iota(M) + k) > _iota(N)).astype(int)
1132+
return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N)
11491133

11501134

11511135
def tril(m, k=0):

0 commit comments

Comments
 (0)