
Commit 738393f

Merge branch 'main' into ifelse_torch

2 parents: 881bca1 + b66d859

17 files changed: +1095 −319 lines


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ repos:
           )$
       - id: check-merge-conflict
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.6
+    rev: v0.6.3
    hooks:
       - id: ruff
         args: ["--fix", "--output-format=full"]
```

pytensor/compile/mode.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -471,6 +471,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "BlasOpt",
             "fusion",
             "inplace",
+            "local_uint_constant_indices",
         ],
     ),
 )
```
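For context, this exclusion list feeds the rewrite-database query behind the PyTorch compilation mode, so `local_uint_constant_indices` is now skipped for that backend — presumably because the rewrite casts constant indices to unsigned integer dtypes that torch's indexing rejects. A minimal sketch of exercising the mode (the computation is an arbitrary stand-in):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Compiling with mode="PYTORCH" applies the rewrite query shown above,
# including its exclusion list, before handing the graph to torch.
fn = pytensor.function([x], x[0] + x[-1], mode="PYTORCH")
```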

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -7,7 +7,8 @@
 import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.math
 import pytensor.link.pytorch.dispatch.extra_ops
+import pytensor.link.pytorch.dispatch.nlinalg
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
-import pytensor.link.pytorch.dispatch.nlinalg
+import pytensor.link.pytorch.dispatch.subtensor
 # isort: on
```
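These imports exist purely for their side effects: each dispatch module registers its converters on the single-dispatch functions at import time, which is why the new `subtensor` module must be listed here at all. A self-contained sketch of that pattern, with `FakeOp` standing in for a real PyTensor Op:

```python
from functools import singledispatch


@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No conversion registered for {type(op)}")


class FakeOp:
    """Stand-in for a PyTensor Op."""


# In PyTensor this decorator runs when the dispatch module is imported,
# so forgetting the import in __init__.py means the Op never dispatches.
@funcify.register(FakeOp)
def funcify_fake_op(op, **kwargs):
    def fake_op(x):
        return x

    return fake_op


assert funcify(FakeOp())(42) == 42
```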

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 28 additions & 5 deletions
```diff
@@ -1,25 +1,41 @@
 from functools import singledispatch
 from types import NoneType
 
+import numpy as np
 import torch
 
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
+from pytensor.tensor.basic import (
+    Alloc,
+    AllocEmpty,
+    ARange,
+    Eye,
+    Join,
+    MakeVector,
+    TensorFromScalar,
+)
 
 
 @singledispatch
-def pytorch_typify(data, dtype=None, **kwargs):
-    r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
+def pytorch_typify(data, **kwargs):
+    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
+
+
+@pytorch_typify.register(np.ndarray)
+@pytorch_typify.register(torch.Tensor)
+def pytorch_typify_tensor(data, dtype=None, **kwargs):
     return torch.as_tensor(data, dtype=dtype)
 
 
+@pytorch_typify.register(slice)
 @pytorch_typify.register(NoneType)
-def pytorch_typify_None(data, **kwargs):
-    return None
+@pytorch_typify.register(np.number)
+def pytorch_typify_no_conversion_needed(data, **kwargs):
+    return data
 
 
 @singledispatch
```
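The refactor makes the generic `pytorch_typify` a hard failure and moves actual conversions into registered overloads: arrays and tensors go through `torch.as_tensor`, while slices, `None`, and NumPy scalars pass through untouched. A standalone sketch of the same single-dispatch behavior (`typify` is a local stand-in, not the PyTensor function):

```python
from functools import singledispatch

import numpy as np
import torch


@singledispatch
def typify(data, **kwargs):
    raise NotImplementedError(f"typify is not implemented for {type(data)}")


@typify.register(np.ndarray)
@typify.register(torch.Tensor)
def typify_tensor(data, dtype=None, **kwargs):
    return torch.as_tensor(data, dtype=dtype)


@typify.register(slice)
@typify.register(type(None))
@typify.register(np.number)
def typify_passthrough(data, **kwargs):
    return data


assert isinstance(typify(np.zeros(3)), torch.Tensor)  # converted
assert typify(slice(0, 2)) == slice(0, 2)             # passed through
assert typify(np.float64(1.5)) == 1.5                 # numpy scalar kept as-is
```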
```diff
@@ -146,3 +162,10 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
         return torch.stack(true_and_false[n_outs:])
 
     return ifelse
+
+
+@pytorch_funcify.register(TensorFromScalar)
+def pytorch_funcify_TensorFromScalar(op, **kwargs):
+    def tensorfromscalar(x):
+        return torch.as_tensor(x)
+
+    return tensorfromscalar
```
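`TensorFromScalar` only needs to lift a Python or NumPy scalar into a 0-d tensor, which `torch.as_tensor` already does. A quick check of the behavior the new converter relies on:

```python
import torch

t = torch.as_tensor(3.5)
# A scalar becomes a 0-d tensor holding the same value.
assert t.ndim == 0 and t.item() == 3.5
```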
pytensor/link/pytorch/dispatch/subtensor.py (new file)

Lines changed: 124 additions & 0 deletions

```diff
@@ -0,0 +1,124 @@
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedIncSubtensor1,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    IncSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)
+from pytensor.tensor.type_other import MakeSlice, SliceType
+
+
+def check_negative_steps(indices):
+    for index in indices:
+        if isinstance(index, slice):
+            if index.step is not None and index.step < 0:
+                raise NotImplementedError(
+                    "Negative step sizes are not supported in PyTorch"
+                )
+
+
+@pytorch_funcify.register(Subtensor)
+def pytorch_funcify_Subtensor(op, node, **kwargs):
+    idx_list = op.idx_list
+
+    def subtensor(x, *flattened_indices):
+        indices = indices_from_subtensor(flattened_indices, idx_list)
+        check_negative_steps(indices)
+        return x[indices]
+
+    return subtensor
+
+
+@pytorch_funcify.register(MakeSlice)
+def pytorch_funcify_makeslice(op, **kwargs):
+    def makeslice(*x):
+        return slice(*x)  # unpack (start, stop, step); slice(x) would wrap the tuple
+
+    return makeslice
+
+
+@pytorch_funcify.register(AdvancedSubtensor1)
+@pytorch_funcify.register(AdvancedSubtensor)
+def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
+    def advsubtensor(x, *indices):
+        check_negative_steps(indices)
+        return x[indices]
+
+    return advsubtensor
+
+
+@pytorch_funcify.register(IncSubtensor)
+def pytorch_funcify_IncSubtensor(op, node, **kwargs):
+    idx_list = op.idx_list
+    inplace = op.inplace
+    if op.set_instead_of_inc:
+
+        def set_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] = y
+            return x
+
+        return set_subtensor
+
+    else:
+
+        def inc_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] += y
+            return x
+
+        return inc_subtensor
+
+
+@pytorch_funcify.register(AdvancedIncSubtensor)
+@pytorch_funcify.register(AdvancedIncSubtensor1)
+def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    inplace = op.inplace
+    ignore_duplicates = getattr(op, "ignore_duplicates", False)
+
+    if op.set_instead_of_inc:
+
+        def adv_set_subtensor(x, y, *indices):
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] = y.type_as(x)
+            return x
+
+        return adv_set_subtensor
+
+    elif ignore_duplicates:
+
+        def adv_inc_subtensor_no_duplicates(x, y, *indices):
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] += y.type_as(x)
+            return x
+
+        return adv_inc_subtensor_no_duplicates
+
+    else:
+        if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
+            raise NotImplementedError(
+                "IncSubtensor with potentially duplicated indices and slice indexing is not implemented in PyTorch"
+            )
+
+        def adv_inc_subtensor(x, y, *indices):
+            # Not needed because slices aren't supported
+            # check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x.index_put_(indices, y.type_as(x), accumulate=True)
+            return x
+
+        return adv_inc_subtensor
```
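The duplicate-accumulation branch is the subtle one: plain `x[idx] += y` writes each duplicated index only once, while `index_put_(..., accumulate=True)` sums every occurrence, which is the semantics `IncSubtensor` needs. A standalone PyTorch check of the distinction, independent of PyTensor:

```python
import torch

idx = torch.tensor([0, 0, 2])
y = torch.ones(3)

# Buffered advanced indexing: the duplicate write to index 0 lands once.
x1 = torch.zeros(3)
x1[idx] += y
assert x1.tolist() == [1.0, 0.0, 1.0]

# index_put_ with accumulate=True sums both contributions to index 0.
x2 = torch.zeros(3)
x2.index_put_((idx,), y, accumulate=True)
assert x2.tolist() == [2.0, 0.0, 1.0]
```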

pytensor/tensor/basic.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -3780,15 +3780,16 @@ class AllocDiag(OpFromGraph):
     Wrapper Op for alloc_diag graphs
     """
 
-    __props__ = ("axis1", "axis2")
-
     def __init__(self, *args, axis1, axis2, offset, **kwargs):
         self.axis1 = axis1
         self.axis2 = axis2
         self.offset = offset
 
         super().__init__(*args, **kwargs, strict=True)
 
+    def __str__(self):
+        return f"AllocDiag{{{self.axis1=}, {self.axis2=}, {self.offset=}}}"
+
     @staticmethod
     def is_offset_zero(node) -> bool:
         """
```

pytensor/tensor/einsum.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -52,14 +52,15 @@ class Einsum(OpFromGraph):
     desired. We haven't decided whether we want to provide this functionality.
     """
 
-    __props__ = ("subscripts", "path", "optimized")
-
     def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs):
         self.subscripts = subscripts
         self.path = path
         self.optimized = optimized
         super().__init__(*args, **kwargs, strict=True)
 
+    def __str__(self):
+        return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}"
+
 
 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
     """
```
