PyTorch wraps the C++ ATen tensor library, which offers a wide range of operations implemented on GPU and CPU. PyTorch/XLA is a PyTorch extension; one of its purposes is to convert PyTorch operations to XLA operations. Lowering refers to the process of converting a higher-level representation of an operation to a lower-level representation. PyTorch/XLA forwards operations for which an XLA lowering hasn't been defined to the CPU, where the ATen implementation of the operation is called. Operations that are forwarded to the CPU cause a significant slowdown, so to achieve the best performance, all operations used in the model must have a lowering defined. For more information, see Edward Yang's blog post on the PyTorch dispatcher.
Here's an example of what you might see from the PyTorch/XLA debugging tool (the pt-xla-profiler report printed when running with PT_XLA_DEBUG=1) for an operation that has not been lowered:
```
pt-xla-profiler: Op(s) not lowered: aten::_ctc_loss, aten::_ctc_loss_backward, Please open a GitHub issue with the above op lowering requests.
```
Furthermore, where possible, we want to lower operations using full codegen; see our codegen migration guide for more instructions.
You should follow the instructions here to install the required dependencies and build PyTorch and PyTorch/XLA from source. You do not need access to a TPU to implement a lowering; it is recommended to experiment on a workstation configured to use the XLA:CPU device. You can configure PyTorch/XLA to use XLA:CPU by running:
```
export PJRT_DEVICE=CPU
```
You can find the definitions of the C++ ATen operations in native_functions.yaml. After you build PyTorch/XLA from source, you will also find our default implementation (a boxed kernel that forwards calls to the PyTorch native kernels) in xla/torch_xla/csrc/aten_fallback.h/.cpp. PyTorch operations can usually be mapped to the PyTorch tensor API easily; if that is not the case, searching the PyTorch native implementation under the PyTorch repo is recommended. The goal is to lower each PyTorch operation into a sequence of the XLA operations defined here.
All files mentioned below live under the xla/torch_xla/csrc folder, with the exception of codegen/xla_native_functions.yaml:
- xla_native_functions.yaml contains the list of all operators (from the Core ATen list) that are explicitly lowered; composite operators are not listed here. Each operator name here must directly match a PyTorch operator listed in native_functions.yaml. This file serves as the interface for adding new XLA operators, and is an input to PyTorch's codegen machinery. It generates the three files below: XLANativeFunctions.h, RegisterXLA.cpp, and RegisterAutogradXLA.cpp.
- XLANativeFunctions.h and aten_xla_type.cpp are the entry points from PyTorch into the pytorch_xla world, and contain the manually written lowerings to XLA for each operator. XLANativeFunctions.h is auto-generated from a combination of xla_native_functions.yaml and the PyTorch core native_functions.yaml file, and contains declarations for the kernels that need to be defined in aten_xla_type.cpp. The kernels written here need to construct an XLATensor from the input at::Tensor and other parameters; the resulting XLATensor needs to be converted back to an at::Tensor before returning to the PyTorch world (see the sketch after this list).
- RegisterXLA.cpp and RegisterAutogradXLA.cpp are auto-generated files that register all lowerings to the PyTorch Dispatcher. They also include auto-generated wrapper implementations of out= and inplace operators.
- aten_fallback.h/.cpp contain our boxed fallback implementation. The boxed fallback kernel will be used if a lowering is not explicitly defined in xla_native_functions.yaml and aten_xla_type.cpp, and the operator is not composite.
- tensor_methods.h contains the XLATensor declarations. These declarations are usually a one-to-one mapping of the at::Tensor kernels we declared in XLANativeFunctions.h.
- tensor_methods.cpp contains the implementations of the XLATensor methods defined in tensor_methods.h. We construct the corresponding ir::op from the parameters' ir::Value and wrap it inside an XLATensor. "IR" stands for intermediate representation.
- The ops/ directory contains all the ir::ops declarations and definitions. Smaller nodes can be put in ops/ops.h/.cpp; more complicated nodes can be put into a separate file. All ops inherit from ir::ops::Node and provide a way to lower input ir::Value to a sequence of XlaOp.
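To make the flow concrete, here is a minimal sketch of what a hand-written kernel in aten_xla_type.cpp typically looks like, using abs as an example. The counter macro and bridge helper names vary between releases, so treat the identifiers below as illustrative rather than exact:

```
// Sketch of a hand-written kernel in aten_xla_type.cpp (illustrative;
// helper and macro names vary between releases).
at::Tensor XLANativeFunctions::abs(const at::Tensor& self) {
  TORCH_LAZY_FN_COUNTER("xla::");  // bumps the xla::abs counter
  // Unwrap the at::Tensor into an XLATensor, call the IR-building method,
  // then wrap the resulting XLATensor back into an at::Tensor.
  return bridge::AtenFromXlaTensor(
      tensor_methods::abs(bridge::GetXlaTensor(self)));
}
```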
Our CI runs the PyTorch native Python tests on every change and every day. Those tests will use the XLA implementation if we provide a lowering. We usually don't need to add additional Python tests for PyTorch/XLA unless we want to verify some XLA behaviors (like dynamic shape) or we skipped the PyTorch native test for some reason. If required, the Python test should be added to xla/test/test_operations.py. We also need to add C++ tests in xla/test/cpp/test_aten_xla_tensor.cpp. These tests should call the PyTorch C++ API and verify that our implementation yields the same result as the PyTorch native implementation. We also need to verify that the XLA implementation is called when the tensor is an XLA tensor, by checking the aten::op and xla::op counters, as shown in the sketch below.
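Here is a sketch of such a C++ test, modeled on the existing tests in test_aten_xla_tensor.cpp; the helpers used (ForEachDevice, CopyToDevice, AllClose, and the counter checks) are the ones those tests rely on, but their exact signatures may differ by version:

```
// Sketch of a C++ test for a lerp lowering, modeled on existing tests in
// xla/test/cpp/test_aten_xla_tensor.cpp (helper names may differ by version).
TEST_F(AtenXlaTensorTest, TestLerp) {
  torch::Tensor start = torch::rand({3, 4});
  torch::Tensor end = torch::rand({3, 4});
  torch::Tensor res = torch::lerp(start, end, 0.5);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_start = CopyToDevice(start, device);
    torch::Tensor xla_end = CopyToDevice(end, device);
    torch::Tensor xla_res = torch::lerp(xla_start, xla_end, 0.5);
    AllClose(res, xla_res);  // same result as the native implementation
  });
  // No aten:: counter should move (no CPU fallback), while the xla::lerp
  // counter should, proving our lowering was dispatched.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters());
}
```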
The process of lowering is breaking down a PyTorch operation into a sequence of XlaOp. To provide a good lowering of a PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp documentation and looking at how similar ops are lowered is the best way to achieve that. You can find a minimal op lowering example in this PR, and a slightly more complicated example with backward lowering in this PR.
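As a concrete illustration of "a sequence of XlaOp", lerp(start, end, weight) can be decomposed arithmetically. A minimal sketch using the XLA client builder API, assuming the operands have already been brought to a common shape and dtype:

```
// Sketch: lerp(start, end, weight) = start + weight * (end - start),
// expressed as a sequence of XlaOps via the XLA client builder API.
// Assumes the operands already share shape and dtype.
xla::XlaOp BuildLerp(xla::XlaOp start, xla::XlaOp end, xla::XlaOp weight) {
  return xla::Add(start, xla::Mul(weight, xla::Sub(end, start)));
}
```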
We have auto-generated wrapper implementations of out= and inplace operators for some operators in RegisterXLA.cpp. We only need to lower the vanilla op in this case. An example would be the lerp operator, which has 6 variants in native_functions.yaml:
```
- lerp_.Scalar
- lerp_.Tensor
- lerp.Scalar_out
- lerp.Tensor_out
- lerp.Scalar
- lerp.Tensor
```
and will generate these function prototypes:
```
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight);
at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out);
at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight);
at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out);
```
in XLANativeFunctions.h if we add all of them to xla_native_functions.yaml. However, if we only lower lerp.Scalar and lerp.Tensor and check RegisterXLA.cpp, we will see:
```
namespace {

at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
  // No device check
  // DeviceGuard omitted
  return torch_xla::lerp(self, end, weight);
}

} // anonymous namespace

at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) {
  auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight);
  at::_copy_from(wrapper_Scalar_lerp__tmp, self);
  return self;
}

...

m.impl("lerp_.Scalar",
       TORCH_FN(wrapper_Scalar_lerp_));
```
The codegen will automatically generate lowerings for lerp_.Scalar and lerp.Scalar_out that use our lerp.Scalar implementation, without us having to provide an explicit lowering.
In general, if there is an operator in PyTorch core that has both an out-of-place and an out= variant, it's better to write a lowering for the out-of-place variant, since you'll get a code-generated out= lowering for free.
For each node we need to pass an ir::OpKind. Here is an example. You can find the OpKind definition in interned_strings.h. If the aten symbol is missing, you can submit a PR like this one.
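For instance, a node constructor typically tags itself with an OpKind built from the matching aten symbol. A minimal sketch follows; the base-class name and extra constructor arguments vary across versions, so this only illustrates where the OpKind is passed:

```
// Sketch: tagging an IR node with an ir::OpKind built from the aten::lerp
// symbol declared in interned_strings.h. Base-class and argument details
// vary across releases.
Lerp::Lerp(const torch::lazy::Value& start, const torch::lazy::Value& end,
           const torch::lazy::Value& weight)
    : XlaNode(torch::lazy::OpKind(at::aten::lerp), {start, end, weight},
              GetXlaShape(start), /*num_outputs=*/1) {}
```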
In certain cases, we may need to manually override the implementation registered for the XLA dispatch key. Ideally, code generation would handle this, but it is useful to know how to handle the unfortunate edge cases. If you need to override the XLA dispatch key, you can do this through the registration macros in the xla_manual_registration.cpp file.
You can use the pytorch#8801 PR for reference on what files to change.
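A minimal sketch of what such a manual registration looks like, using PyTorch's standard TORCH_LIBRARY_IMPL dispatcher-registration macro; the operator and kernel names below are hypothetical, for illustration only:

```
// Sketch of a manual registration in xla_manual_registration.cpp.
// "do_something" / do_something_xla are hypothetical names.
TORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("do_something", TORCH_FN(do_something_xla));
}
```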