-
Notifications
You must be signed in to change notification settings - Fork 146
Implement OpFromGraph in PyTorch backend #956
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
c4b20ec
9c64320
fdd5d5c
d98e68e
10a841f
b29be45
cefec02
0f18d8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
OPT_O3, | ||
OPT_STABILIZE, | ||
OPT_UNSAFE, | ||
PYTORCH, | ||
AddDestroyHandler, | ||
AddFeatureOptimizer, | ||
Mode, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
|
||
import torch | ||
|
||
from pytensor.compile import PYTORCH | ||
from pytensor.compile.builders import OpFromGraph | ||
from pytensor.compile.ops import DeepCopyOp | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.link.utils import fgraph_to_python | ||
|
@@ -132,3 +134,17 @@ def makevector(*x): | |
return torch.tensor(x, dtype=torch_dtype) | ||
|
||
return makevector | ||
|
||
|
||
@pytorch_funcify.register(OpFromGraph) | ||
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): | ||
_ = kwargs.pop("storage_map", None) | ||
|
||
PYTORCH.optimizer(op.fgraph) | ||
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) | ||
|
||
|
||
def opfromgraph(*inputs, dim=op.fgraph.outputs): | ||
Ch0ronomato marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
res = fgraph_fn(*inputs) | ||
|
||
return res[0] | ||
|
||
return opfromgraph |
Uh oh!
There was an error while loading. Please reload this page.