Skip to content

Commit b0abe17

Browse files
Save arguments passed to alloc_diag as properties in AllocDiag2
1 parent 5604d9a commit b0abe17

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

pytensor/tensor/basic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3837,6 +3837,17 @@ class AllocDiag2(OpFromGraph):
38373837
Wrapper Op for alloc_diag graphs
38383838
"""
38393839

3840+
__props__ = ("offset", "axis1", "axis2", "inline")
3841+
3842+
def __init__(self, *args, offset, axis1, axis2, **kwargs):
3843+
inline = kwargs.pop("inline", False)
3844+
self.offset = offset
3845+
self.axis1 = axis1
3846+
self.axis2 = axis2
3847+
self.inline = inline
3848+
3849+
super().__init__(*args, **kwargs, strict=True, inline=inline)
3850+
38403851

38413852
def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38423853
"""Insert a vector on the diagonal of a zero-ed matrix.
@@ -3872,7 +3883,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
38723883
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
38733884
result = result.transpose(axes)
38743885

3875-
return AllocDiag2(inputs=[diag], outputs=[result])(diag)
3886+
return AllocDiag2(
3887+
inputs=[diag], outputs=[result], offset=offset, axis1=axis1, axis2=axis2
3888+
)(diag)
38763889

38773890

38783891
def diag(v, k=0):

0 commit comments

Comments
 (0)