Skip to content

Commit 5d5f754

Browse files
committed
remove torch related code / comments
1 parent e116fa1 commit 5d5f754

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

pytensor/link/mlx/linker.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ def fgraph_convert(
4545
# just the subgraph
4646
generator = unique_name_generator(["mlx_linker"])
4747

48-
# Ensure that torch is aware of the generated
49-
# code so we can compile without graph breaks
5048
def conversion_func_register(*args, **kwargs):
5149
functor = mlx_funcify(*args, **kwargs)
5250
name = kwargs["unique_name"](functor)
@@ -85,14 +83,12 @@ def __call__(self, *inputs, **kwargs):
8583
# MLX doesn't support np.ndarray as input
8684
outs = self.fn(*(mlx_typify(inp) for inp in inputs), **kwargs)
8785

88-
return outs
89-
9086
# unset attrs
9187
for n, _ in self.gen_functors:
9288
if getattr(pytensor.link.utils, n[1:], False):
9389
delattr(pytensor.link.utils, n[1:])
9490

95-
return tuple(out.cpu().numpy() for out in outs)
91+
return outs
9692

9793
def __del__(self):
9894
del self.gen_functors

0 commit comments

Comments
 (0)