File tree Expand file tree Collapse file tree 2 files changed +15
-0
lines changed
pytensor/link/pytorch/dispatch Expand file tree Collapse file tree 2 files changed +15
-0
lines changed Original file line number Diff line number Diff line change 1919 Eye ,
2020 Join ,
2121 MakeVector ,
22+ Split ,
2223 TensorFromScalar ,
2324)
2425
@@ -185,3 +186,11 @@ def tensorfromscalar(x):
185186 return torch .as_tensor (x )
186187
187188 return tensorfromscalar
189+
190+
191+ @pytorch_funcify .register (Split )
192+ def pytorch_funcify_Split (op , node , ** kwargs ):
193+ def inner_fn (x , dim , split_amounts ):
194+ return x .split (split_amounts .tolist (), dim = dim .item ())
195+
196+ return inner_fn
Original file line number Diff line number Diff line change 55from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
66from pytensor .scalar .basic import (
77 Cast ,
8+ Invert ,
89 ScalarOp ,
910)
1011from pytensor .scalar .loop import ScalarLoop
1112from pytensor .scalar .math import Softplus
1213
1314
15+ @pytorch_funcify .register (Invert )
16+ def pytorch_funcify_invert (op , node , ** kwargs ):
17+ return torch .bitwise_not
18+
19+
1420@pytorch_funcify .register (ScalarOp )
1521def pytorch_funcify_ScalarOp (op , node , ** kwargs ):
1622 """Return pytorch function that implements the same computation as the Scalar Op.
You can’t perform that action at this time.
0 commit comments