Implement ScalarLoop in torch backend #958
Changes from 26 commits
@@ -9,6 +9,19 @@
         super().__init__(*args, **kwargs)
         self.gen_functors = []
 
+    def input_filter(self, inp):
+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
+        return pytorch_typify(inp)
+
+    def output_filter(self, var, out):
+        from torch import is_tensor
+
+        if is_tensor(out):
+            return out.cpu()
+        else:
+            return out
+
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.pytorch.dispatch import pytorch_funcify

Review notes on this hunk:

- On `def input_filter`: Ch0ronomato marked this conversation as resolved.
- On `def output_filter`: Ch0ronomato marked this conversation as resolved.
- On `return out.cpu()`: "This will probably create conflict when one of my other PRs gets merged, as an FYI."
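Taken together, `input_filter` and `output_filter` give the linker a numpy-to-torch round trip: inputs are typified into torch tensors before the compiled function runs, and tensor outputs are moved back to host memory afterwards. A minimal standalone sketch of that round trip, using `torch.as_tensor` as a stand-in for what `pytorch_typify` does (this is not pytensor's actual dispatch code):

```python
import numpy as np
import torch

def input_filter(inp):
    # Stand-in for pytorch_typify: wrap array-likes as torch tensors.
    return torch.as_tensor(inp)

def output_filter(out):
    # Tensor results may live on the GPU; move them back to host memory.
    # Non-tensor results (e.g. Python scalars) pass through unchanged.
    if torch.is_tensor(out):
        return out.cpu()
    return out

x = input_filter(np.arange(3.0))  # float64 tensor on the default device
y = output_filter(x * 2)          # CPU tensor, safe to hand back via .numpy()
```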
@@ -77,11 +90,11 @@
         self.gen_functors = []
 
         # Torch does not accept numpy inputs and may return GPU objects
-        def fn(*inputs, inner_fn=inner_fn):
+        def create_outputs(*inputs, inner_fn=inner_fn):
             outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
             return tuple(out.cpu().numpy() for out in outs)
 
-        return fn
+        return create_outputs
 
     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []

Review thread on the rename of `fn` to `create_outputs`:

- Reviewer: Why the new name? Seems less clear.
- Author: Yeah, this …
- Reviewer: Sure, but can we use a different name? This doesn't "create_outputs"; it converts inputs to torch tensors and converts the outputs back to pytensor-compatible types.
- Author: Sure thing; I can also just keep the shadowing lol. It's not the end of the world. From your description I would probably have called it …
- Reviewer: We can also put it inside the wrapper.
- Author: Sure, that makes sense too.
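To make the naming discussion concrete, this is the shape the wrapper takes under the reviewer's description: it adapts calling conventions on both sides rather than "creating" anything. A sketch only; the name `convert_inputs_and_outputs` is illustrative, not the one the PR settled on:

```python
from pytensor.link.pytorch.dispatch import pytorch_typify

def jit_compile_sketch(inner_fn):
    # Illustrative wrapper mirroring the diff above.
    def convert_inputs_and_outputs(*inputs, inner_fn=inner_fn):
        # numpy -> torch on the way in (torch does not accept numpy inputs)
        outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
        # torch -> numpy on the way out, via .cpu() in case results sit on GPU
        return tuple(out.cpu().numpy() for out in outs)

    return convert_inputs_and_outputs
```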
Review thread on `create_thunk_inputs`:

- Reviewer: Is there no `torch.empty`?
- Author: My bad; I had an old version of torch on my machine (2.2) which didn't have it, but 2.3+ does. Reverted to `torch.empty`.
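For reference, `torch.empty` allocates a tensor without initializing its contents, analogous to `numpy.empty`, which makes it cheaper than `torch.zeros` when every element will be overwritten anyway. A quick illustration:

```python
import torch

# Contents are arbitrary until written; only shape and dtype are guaranteed.
buf = torch.empty(3, dtype=torch.float64)
buf[:] = torch.arange(3, dtype=torch.float64)
print(buf)  # tensor([0., 1., 2.], dtype=torch.float64)
```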