Skip to content

Commit 6857bea

Browse files
Add OpFromGraph wrapper around alloc_diag
1 parent 05d376f commit 6857bea

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

pytensor/tensor/basic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pytensor.scalar.sharedvar
2222
from pytensor import compile, config, printing
2323
from pytensor import scalar as ps
24+
from pytensor.compile.builders import OpFromGraph
2425
from pytensor.gradient import DisconnectedType, grad_undefined
2526
from pytensor.graph import RewriteDatabaseQuery
2627
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
@@ -3831,6 +3832,12 @@ def __setstate__(self, state):
38313832
self.axis2 = 1
38323833

38333834

3835+
class AllocDiag2(OpFromGraph):
3836+
"""
3837+
Wrapper Op for alloc_diag graphs
3838+
"""
3839+
3840+
38343841
def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38353842
"""Insert a vector on the diagonal of a zero-ed matrix.
38363843
@@ -3865,7 +3872,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38653872
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
38663873
result = result.transpose(axes)
38673874

3868-
return result
3875+
return AllocDiag2(inputs=[diag], outputs=[result])(diag)
38693876

38703877

38713878
def diag(v, k=0):

0 commit comments

Comments
 (0)