|
43 | 43 | get_vector_length, |
44 | 44 | ) |
45 | 45 | from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback |
| 46 | +from pytensor.tensor.einsum import _iota |
46 | 47 | from pytensor.tensor.elemwise import ( |
47 | 48 | DimShuffle, |
48 | 49 | Elemwise, |
@@ -1084,35 +1085,18 @@ def nonzero_values(a): |
1084 | 1085 | return _a.flatten()[flatnonzero(_a)] |
1085 | 1086 |
|
1086 | 1087 |
|
1087 | | -class Tri(Op): |
| 1088 | +class Tri(OpFromGraph): |
| 1089 | + """ |
| 1090 | + Wrapper Op for np.tri graphs |
| 1091 | + """ |
| 1092 | + |
1088 | 1093 | __props__ = ("dtype",) |
1089 | 1094 |
|
1090 | | - def __init__(self, dtype=None): |
1091 | | - if dtype is None: |
1092 | | - dtype = config.floatX |
| 1095 | + def __init__(self, *args, M, k, dtype, **kwargs): |
| 1096 | + self.M = M |
| 1097 | + self.k = k |
1093 | 1098 | self.dtype = dtype |
1094 | | - |
1095 | | - def make_node(self, N, M, k): |
1096 | | - N = as_tensor_variable(N) |
1097 | | - M = as_tensor_variable(M) |
1098 | | - k = as_tensor_variable(k) |
1099 | | - return Apply( |
1100 | | - self, |
1101 | | - [N, M, k], |
1102 | | - [TensorType(dtype=self.dtype, shape=(None, None))()], |
1103 | | - ) |
1104 | | - |
1105 | | - def perform(self, node, inp, out_): |
1106 | | - N, M, k = inp |
1107 | | - (out,) = out_ |
1108 | | - out[0] = np.tri(N, M, k, dtype=self.dtype) |
1109 | | - |
1110 | | - def infer_shape(self, fgraph, node, in_shapes): |
1111 | | - out_shape = [node.inputs[0], node.inputs[1]] |
1112 | | - return [out_shape] |
1113 | | - |
1114 | | - def grad(self, inp, grads): |
1115 | | - return [grad_undefined(self, i, inp[i]) for i in range(3)] |
| 1099 | + super().__init__(*args, **kwargs, strict=True) |
1116 | 1100 |
|
1117 | 1101 |
|
1118 | 1102 | def tri(N, M=None, k=0, dtype=None): |
@@ -1144,8 +1128,8 @@ def tri(N, M=None, k=0, dtype=None): |
1144 | 1128 | dtype = config.floatX |
1145 | 1129 | if M is None: |
1146 | 1130 | M = N |
1147 | | - op = Tri(dtype) |
1148 | | - return op(N, M, k) |
| 1131 | + output = ((_iota(M) + k) > _iota(N)).astype(int) |
| 1132 | + return Tri(inputs=[N], outputs=[output], M=M, k=k, dtype=dtype)(N) |
1149 | 1133 |
|
1150 | 1134 |
|
1151 | 1135 | def tril(m, k=0): |
|
0 commit comments