|
| 1 | +(libdoc_xtensor)= |
| 2 | +# XTensor Module |
| 3 | + |
| 4 | +:::{toctree} |
| 5 | +:maxdepth: 1 |
| 6 | +xtensor/linalg |
| 7 | +xtensor/math |
| 8 | +xtensor/random |
| 9 | +xtensor/type |
| 10 | +::: |
| 11 | + |
| 12 | +This module implements as abstraction layer on regular tensor operations, that behaves like Xarray. |
| 13 | + |
| 14 | +A new type :class:`pytensor.xtensor.type.XTensorType`, generalizes the :class:`pytensor.tensor.type.TensorType` |
| 15 | +with the addition of a `dims` attribute, that labels the dimensions of the tensor. |
| 16 | + |
| 17 | +Variables of XTensorType (i.e., :class:`pytensor.xtensor.type.XTensorVariable`s) are the symbolic counterpart |
| 18 | +to xarray DataArray objects. |
| 19 | + |
| 20 | +The module implements several PyTensor operations :class:`pytensor.xtensor.basic.XOp`s, whose signature mimics that of |
| 21 | +xarray (and xarray_einstants) DataArray operations. These operations, unlike most regular PyTensor operations, cannot |
| 22 | +be directly evaluated, but require a rewrite (lowering) into a regular tensor graph that can itself be evaluated as usual. |
| 23 | + |
| 24 | +Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray. |
| 25 | +If the existing XOps can be composed to produce the desired result, then we can use them directly. |
| 26 | + |
| 27 | +## Coordinates |
| 28 | +For now, there's no analogous of xarray coordinates, so you won't be able to do coordinate operations like `.sel`. |
| 29 | +The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor. |
| 30 | +Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor. |
| 31 | + |
| 32 | +## Example |
| 33 | + |
| 34 | +```python |
| 35 | +import pytensor.tensor as pt |
| 36 | +import pytensor.xtensor as ptx |
| 37 | + |
| 38 | +a = pt.tensor("a", shape=(3,)) |
| 39 | +b = pt.tensor("b", shape=(4,)) |
| 40 | + |
| 41 | +ax = ptx.as_xtensor(a, dims=["x"]) |
| 42 | +bx = ptx.as_xtensor(b, dims=["y"]) |
| 43 | + |
| 44 | +zx = ax + bx |
| 45 | +assert zx.type == ptx.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4)) |
| 46 | + |
| 47 | +z = zx.values |
| 48 | +z.dprint() |
| 49 | +# TensorFromXTensor [id A] |
| 50 | +# └─ XElemwise{scalar_op=Add()} [id B] |
| 51 | +# ├─ XTensorFromTensor{dims=('x',)} [id C] |
| 52 | +# │ └─ a [id D] |
| 53 | +# └─ XTensorFromTensor{dims=('y',)} [id E] |
| 54 | +# └─ b [id F] |
| 55 | +``` |
| 56 | + |
| 57 | +Once we compile the graph, no `XOp`s are left. |
| 58 | + |
| 59 | +```python |
| 60 | +import pytensor |
| 61 | + |
| 62 | +with pytensor.config.change_flags(optimizer_verbose=True): |
| 63 | + fn = pytensor.function([a, b], z) |
| 64 | + |
| 65 | +# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0) |
| 66 | +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None |
| 67 | +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None |
| 68 | +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0) |
| 69 | + |
| 70 | +fn.dprint() |
| 71 | +# Add [id A] 2 |
| 72 | +# ├─ ExpandDims{axis=1} [id B] 1 |
| 73 | +# │ └─ a [id C] |
| 74 | +# └─ ExpandDims{axis=0} [id D] 0 |
| 75 | +# └─ b [id E] |
| 76 | +``` |
0 commit comments