Skip to content

Commit 00ab6b2

Browse files
committed
Fix triu docstring; remove dtype prop in JAX version
1 parent f5fc9df commit 00ab6b2

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,6 @@ def tri(*args):
203203
x if const_x is None else const_x
204204
for x, const_x in zip(args, const_args, strict=True)
205205
]
206-
return jnp.tri(*args, dtype=op.dtype)
206+
return jnp.tri(*args)
207207

208208
return tri

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1268,7 +1268,7 @@ def triu(m, k=0):
12681268
[ 0, 8, 9],
12691269
[ 0, 0, 12]])
12701270
1271-
>>> pt.triu(np.arange(3 * 4 * 5).reshape((3, 4, 5))).eval()
1271+
>>> pt.triu(pt.arange(3 * 4 * 5).reshape((3, 4, 5))).eval()
12721272
array([[[ 0, 1, 2, 3, 4],
12731273
[ 0, 6, 7, 8, 9],
12741274
[ 0, 0, 12, 13, 14],

0 commit comments

Comments
 (0)