|
| 1 | +(libdoc_xtensor)= |
| 2 | +# `xtensor` -- XTensor operations |
| 3 | + |
| 4 | +This module implements as abstraction layer on regular tensor operations, that behaves like Xarray. |
| 5 | + |
| 6 | +A new type {class}`pytensor.xtensor.type.XTensorType`, generalizes the {class}`pytensor.tensor.TensorType` |
| 7 | +with the addition of a `dims` attribute, that labels the dimensions of the tensor. |
| 8 | + |
| 9 | +Variables of XTensorType (i.e., {class}`pytensor.xtensor.type.XTensorVariable`s) are the symbolic counterpart |
| 10 | +to xarray DataArray objects. |
| 11 | + |
| 12 | +The module implements several PyTensor operations {class}`pytensor.xtensor.basic.XOp`s, whose signature mimics that of |
| 13 | +xarray (and xarray_einstats) DataArray operations. These operations, unlike most regular PyTensor operations, cannot |
| 14 | +be directly evaluated, but require a rewrite (lowering) into a regular tensor graph that can itself be evaluated as usual. |
| 15 | + |
| 16 | +Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray. |
| 17 | +If the existing XOps can be composed to produce the desired result, then we can use them directly. |
| 18 | + |
| 19 | +## Coordinates |
| 20 | +For now, there's no analogous of xarray coordinates, so you won't be able to do coordinate operations like `.sel`. |
| 21 | +The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor. |
| 22 | +Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor. |
| 23 | + |
| 24 | +## Example |
| 25 | + |
| 26 | + |
| 27 | +:::{testcode} |
| 28 | + |
| 29 | +import pytensor.tensor as pt |
| 30 | +import pytensor.xtensor as ptx |
| 31 | + |
| 32 | +a = pt.tensor("a", shape=(3,)) |
| 33 | +b = pt.tensor("b", shape=(4,)) |
| 34 | + |
| 35 | +ax = ptx.as_xtensor(a, dims=["x"]) |
| 36 | +bx = ptx.as_xtensor(b, dims=["y"]) |
| 37 | + |
| 38 | +zx = ax + bx |
| 39 | +assert zx.type == ptx.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4)) |
| 40 | + |
| 41 | +z = zx.values |
| 42 | +z.dprint() |
| 43 | +::: |
| 44 | + |
| 45 | + |
| 46 | +:::{testoutput} |
| 47 | + |
| 48 | +TensorFromXTensor [id A] |
| 49 | + └─ XElemwise{scalar_op=Add()} [id B] |
| 50 | + ├─ XTensorFromTensor{dims=('x',)} [id C] |
| 51 | + │ └─ a [id D] |
| 52 | + └─ XTensorFromTensor{dims=('y',)} [id E] |
| 53 | + └─ b [id F] |
| 54 | +::: |
| 55 | + |
| 56 | +Once we compile the graph, no XOps are left. |
| 57 | + |
| 58 | +:::{testcode} |
| 59 | + |
| 60 | +import pytensor |
| 61 | + |
| 62 | +with pytensor.config.change_flags(optimizer_verbose=True): |
| 63 | + fn = pytensor.function([a, b], z) |
| 64 | + |
| 65 | +::: |
| 66 | + |
| 67 | +:::{testoutput} |
| 68 | + |
| 69 | +rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0) |
| 70 | +rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None |
| 71 | +rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None |
| 72 | +rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0) |
| 73 | + |
| 74 | +::: |
| 75 | + |
| 76 | +:::{testcode} |
| 77 | + |
| 78 | +fn.dprint() |
| 79 | +::: |
| 80 | + |
| 81 | +:::{testoutput} |
| 82 | + |
| 83 | +Add [id A] 2 |
| 84 | + ├─ ExpandDims{axis=1} [id B] 1 |
| 85 | + │ └─ a [id C] |
| 86 | + └─ ExpandDims{axis=0} [id D] 0 |
| 87 | + └─ b [id E] |
| 88 | +::: |
| 89 | + |
| 90 | + |
| 91 | +## Index |
| 92 | + |
| 93 | +:::{toctree} |
| 94 | +:maxdepth: 1 |
| 95 | + |
| 96 | +module_functions |
| 97 | +math |
| 98 | +linalg |
| 99 | +random |
| 100 | +type |
| 101 | +::: |
0 commit comments