Skip to content

Commit f7cf273

Browse files
committed
Extract ViewOp functionality into a base TypeCastOp
1 parent ddcd988 commit f7cf273

File tree

4 files changed

+30
-23
lines changed

4 files changed

+30
-23
lines changed

pytensor/compile/ops.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@ def register_view_op_c_code(type, code, version=()):
3333
ViewOp.c_code_and_version[type] = (code, version)
3434

3535

36-
class ViewOp(COp):
37-
"""
38-
Returns an inplace view of the input. Used internally by PyTensor.
39-
40-
"""
36+
class TypeCastingOp(COp):
37+
"""Op that performs a graph-level type cast operation, but has no effect computation-wise (identity function)."""
4138

4239
view_map = {0: [0]}
4340
# Mapping from Type to C code (and version) to use.
@@ -47,13 +44,8 @@ class ViewOp(COp):
4744
__props__: tuple = ()
4845
_f16_ok: bool = True
4946

50-
def make_node(self, x):
51-
return Apply(self, [x], [x.type()])
52-
53-
def perform(self, node, inp, out):
54-
(x,) = inp
55-
(z,) = out
56-
z[0] = x
47+
def perform(self, node, inputs, outputs_storage):
48+
outputs_storage[0][0] = inputs[0]
5749

5850
def __str__(self):
5951
return f"{self.__class__.__name__}"
@@ -90,6 +82,13 @@ def c_code_cache_version(self):
9082

9183
return tuple(version)
9284

85+
86+
class ViewOp(TypeCastingOp):
87+
"""Returns an inplace view of the input. Used internally by PyTensor."""
88+
89+
def make_node(self, x):
90+
return Apply(self, [x], [x.type()])
91+
9392
def infer_shape(self, fgraph, node, input_shapes):
9493
return input_shapes
9594

pytensor/link/jax/dispatch/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from pytensor.compile import JAX
1010
from pytensor.compile.builders import OpFromGraph
11-
from pytensor.compile.ops import DeepCopyOp, ViewOp
11+
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
1212
from pytensor.configdefaults import config
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.ifelse import IfElse
@@ -111,12 +111,12 @@ def deepcopyop(x):
111111
return deepcopyop
112112

113113

114-
@jax_funcify.register(ViewOp)
115-
def jax_funcify_ViewOp(op, **kwargs):
116-
def viewop(x):
114+
@jax_funcify.register(TypeCastingOp)
115+
def jax_funcify_TypeCastingOp(op, **kwargs):
116+
def type_cast(x):
117117
return x
118118

119-
return viewop
119+
return type_cast
120120

121121

122122
@jax_funcify.register(OpFromGraph)

pytensor/link/numba/dispatch/scalar.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from pytensor.compile.ops import ViewOp
5+
from pytensor.compile.ops import TypeCastingOp
66
from pytensor.graph.basic import Variable
77
from pytensor.link.numba.dispatch import basic as numba_basic
88
from pytensor.link.numba.dispatch.basic import (
@@ -198,14 +198,14 @@ def cast(x):
198198

199199

200200
@numba_basic.numba_njit
201-
def viewop(x):
201+
def identity(x):
202202
return x
203203

204204

205205
@numba_funcify.register(Identity)
206-
@numba_funcify.register(ViewOp)
207-
def numba_funcify_ViewOp(op, **kwargs):
208-
return numba_basic.global_numba_func(viewop)
206+
@numba_funcify.register(TypeCastingOp)
207+
def numba_funcify_type_casting(op, **kwargs):
208+
return numba_basic.global_numba_func(identity)
209209

210210

211211
@numba_basic.numba_njit

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.compile import PYTORCH
1010
from pytensor.compile.builders import OpFromGraph
1111
from pytensor.compile.function.types import add_supervisor_to_fgraph
12-
from pytensor.compile.ops import DeepCopyOp
12+
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
1313
from pytensor.graph.basic import Constant
1414
from pytensor.graph.fg import FunctionGraph
1515
from pytensor.ifelse import IfElse
@@ -71,6 +71,14 @@ def pytorch_funcify_FunctionGraph(
7171
)
7272

7373

74+
@pytorch_funcify.register(TypeCastingOp)
75+
def pytorch_funcify_CastingOp(op, node, **kwargs):
76+
def type_cast(x):
77+
return x
78+
79+
return type_cast
80+
81+
7482
@pytorch_funcify.register(CheckAndRaise)
7583
def pytorch_funcify_CheckAndRaise(op, **kwargs):
7684
error = op.exc_type

0 commit comments

Comments
 (0)