|
| 1 | +#include <torch/csrc/inductor/aoti_torch/c/shim.h> |
| 2 | +#include <torch/csrc/stable/accelerator.h> |
| 3 | +#include <torch/csrc/stable/device.h> |
| 4 | +#include <torch/csrc/stable/library.h> |
| 5 | +#include <torch/csrc/stable/tensor.h> |
| 6 | +#include <torch/csrc/stable/ops.h> |
| 7 | +#include <torch/headeronly/util/Exception.h> |
| 8 | +#include <torch/headeronly/core/ScalarType.h> |
| 9 | + |
| 10 | +#ifdef LAE_USE_CUDA |
| 11 | +#include <cuda_runtime.h> |
| 12 | +#endif |
| 13 | + |
| 14 | +#include <optional> |
| 15 | + |
| 16 | +using torch::stable::Tensor; |
| 17 | + |
| 18 | +std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) { |
| 19 | + std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)}; |
| 20 | + aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data()); |
| 21 | + return torch::stable::detail::to<std::vector<Tensor>>(stack[0]); |
| 22 | +} |
| 23 | + |
| 24 | +void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) { |
| 25 | + std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)}; |
| 26 | + aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data()); |
| 27 | +} |
| 28 | + |
| 29 | +Tensor my_clone(Tensor t) { |
| 30 | + return clone(t); |
| 31 | +} |
| 32 | + |
| 33 | +std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) { |
| 34 | + // This function tests that my__foreach_mul can take in std::initializer_lists |
| 35 | + // in addition to std::vectors. |
| 36 | + Tensor t1_1 = my_clone(t1); |
| 37 | + Tensor t1_2 = my_clone(t1); |
| 38 | + Tensor t2_1 = my_clone(t2); |
| 39 | + Tensor t2_2 = my_clone(t2); |
| 40 | + return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2}); |
| 41 | +} |
| 42 | + |
| 43 | +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { |
| 44 | + m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]"); |
| 45 | + m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()"); |
| 46 | + m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]"); |
| 47 | +} |
| 48 | + |
| 49 | +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { |
| 50 | + m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul)); |
| 51 | + m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_)); |
| 52 | + m.impl("make_tensor_clones_and_call_foreach", TORCH_BOX(&make_tensor_clones_and_call_foreach)); |
| 53 | +} |
| 54 | + |
| 55 | +// Test functions for torch::stable::Tensor device method |
| 56 | + |
| 57 | +torch::stable::Device test_tensor_device(torch::stable::Tensor tensor) { |
| 58 | + return tensor.device(); |
| 59 | +} |
| 60 | + |
| 61 | +// Test functions for torch::stable::Device |
| 62 | + |
| 63 | +torch::stable::Device test_device_constructor( |
| 64 | + bool is_cuda, |
| 65 | + torch::stable::DeviceIndex index, |
| 66 | + bool use_str) { |
| 67 | + using torch::stable::Device; |
| 68 | + using torch::stable::DeviceType; |
| 69 | + |
| 70 | + if (use_str) { |
| 71 | + std::string device_str; |
| 72 | + if (is_cuda) { |
| 73 | + device_str = "cuda:" + std::to_string(index); |
| 74 | + } else { |
| 75 | + device_str = "cpu"; |
| 76 | + } |
| 77 | + return Device(device_str); |
| 78 | + } else { |
| 79 | + if (is_cuda) { |
| 80 | + return Device(DeviceType::CUDA, index); |
| 81 | + } else { |
| 82 | + return Device(DeviceType::CPU); |
| 83 | + } |
| 84 | + } |
| 85 | +} |
| 86 | + |
| 87 | +bool test_device_equality(torch::stable::Device d1, torch::stable::Device d2) { |
| 88 | + return d1 == d2; |
| 89 | +} |
| 90 | + |
| 91 | +torch::stable::Device test_device_set_index( |
| 92 | + torch::stable::Device device, |
| 93 | + torch::stable::DeviceIndex index) { |
| 94 | + device.set_index(index); |
| 95 | + return device; |
| 96 | +} |
| 97 | + |
| 98 | +torch::stable::DeviceIndex test_device_index(torch::stable::Device device) { |
| 99 | + return device.index(); |
| 100 | +} |
| 101 | + |
| 102 | +bool test_device_is_cuda(torch::stable::Device device) { |
| 103 | + return device.is_cuda(); |
| 104 | +} |
| 105 | + |
| 106 | +bool test_device_is_cpu(torch::stable::Device device) { |
| 107 | + return device.is_cpu(); |
| 108 | +} |
| 109 | + |
| 110 | +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { |
| 111 | + m.def("test_tensor_device(Tensor t) -> Device"); |
| 112 | + m.def( |
| 113 | + "test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device"); |
| 114 | + m.def("test_device_equality(Device d1, Device d2) -> bool"); |
| 115 | + m.def("test_device_set_index(Device device, DeviceIndex index) -> Device"); |
| 116 | + m.def("test_device_index(Device device) -> DeviceIndex"); |
| 117 | + m.def("test_device_is_cuda(Device device) -> bool"); |
| 118 | + m.def("test_device_is_cpu(Device device) -> bool"); |
| 119 | +} |
| 120 | + |
| 121 | +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { |
| 122 | + m.impl("test_tensor_device", TORCH_BOX(&test_tensor_device)); |
| 123 | + m.impl("test_device_constructor", TORCH_BOX(&test_device_constructor)); |
| 124 | + m.impl("test_device_equality", TORCH_BOX(&test_device_equality)); |
| 125 | + m.impl("test_device_set_index", TORCH_BOX(&test_device_set_index)); |
| 126 | + m.impl("test_device_index", TORCH_BOX(&test_device_index)); |
| 127 | + m.impl("test_device_is_cuda", TORCH_BOX(&test_device_is_cuda)); |
| 128 | + m.impl("test_device_is_cpu", TORCH_BOX(&test_device_is_cpu)); |
| 129 | +} |
| 130 | + |
| 131 | +Tensor test_parallel_for(int64_t size, int64_t grain_size) { |
| 132 | + AtenTensorHandle tensor_handle; |
| 133 | + int64_t stride = 1; |
| 134 | + |
| 135 | + aoti_torch_empty_strided( |
| 136 | + 1, |
| 137 | + &size, |
| 138 | + &stride, |
| 139 | + aoti_torch_dtype_int64(), |
| 140 | + aoti_torch_device_type_cpu(), |
| 141 | + 0, |
| 142 | + &tensor_handle); |
| 143 | + |
| 144 | + Tensor tensor(tensor_handle); |
| 145 | + int64_t* data_ptr = reinterpret_cast<int64_t*>(tensor.data_ptr()); |
| 146 | + |
| 147 | + torch::stable::zero_(tensor); |
| 148 | + |
| 149 | + // Use parallel_for to fill each element with its index |
| 150 | + // If using a parallel path, the thread id is encoded in the upper 32 bits |
| 151 | + torch::stable::parallel_for( |
| 152 | + 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { |
| 153 | + for (auto i = begin; i < end; i++) { |
| 154 | + STD_TORCH_CHECK(i <= UINT32_MAX); |
| 155 | + uint32_t thread_id; |
| 156 | + torch_get_thread_idx(&thread_id); |
| 157 | + data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32); |
| 158 | + } |
| 159 | + }); |
| 160 | + |
| 161 | + return tensor; |
| 162 | +} |
| 163 | + |
| 164 | +uint32_t test_get_num_threads() { |
| 165 | + return torch::stable::get_num_threads(); |
| 166 | +} |
| 167 | + |
| 168 | +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { |
| 169 | + m.def("test_parallel_for(int size, int grain_size) -> Tensor"); |
| 170 | + m.def("test_get_num_threads() -> int"); |
| 171 | +} |
| 172 | + |
| 173 | +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { |
| 174 | + m.impl("test_parallel_for", TORCH_BOX(&test_parallel_for)); |
| 175 | + m.impl("test_get_num_threads", TORCH_BOX(&test_get_num_threads)); |
| 176 | +} |
| 177 | + |
| 178 | +Tensor my_empty( |
| 179 | + torch::headeronly::HeaderOnlyArrayRef<int64_t> size, |
| 180 | + std::optional<torch::headeronly::ScalarType> dtype, |
| 181 | + std::optional<torch::stable::Device> device, |
| 182 | + std::optional<bool> pin_memory) { |
| 183 | + return empty(size, dtype, device, pin_memory); |
| 184 | +} |
| 185 | + |
| 186 | +Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) { |
| 187 | + return reshape(t, shape); |
| 188 | +} |
| 189 | + |
| 190 | +Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) { |
| 191 | + return view(t, size); |
| 192 | +} |
| 193 | + |
| 194 | +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { |
| 195 | + m.def( |
| 196 | + "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor"); |
| 197 | + m.def("my_reshape(Tensor t, int[] shape) -> Tensor"); |
| 198 | + m.def("my_view(Tensor t, int[] size) -> Tensor"); |
| 199 | +} |
| 200 | + |
| 201 | +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { |
| 202 | + m.impl("my_empty", TORCH_BOX(&my_empty)); |
| 203 | + m.impl("my_reshape", TORCH_BOX(&my_reshape)); |
| 204 | + m.impl("my_view", TORCH_BOX(&my_view)); |
| 205 | +} |
| 206 | + |
| 207 | +uint64_t get_any_data_ptr(Tensor t, bool mutable_) { |
| 208 | + if (mutable_) { |
| 209 | + return reinterpret_cast<uint64_t>(t.mutable_data_ptr()); |
| 210 | + } else { |
| 211 | + return reinterpret_cast<uint64_t>(t.const_data_ptr()); |
| 212 | + } |
| 213 | +} |
| 214 | + |
| 215 | +uint64_t get_template_any_data_ptr(Tensor t, c10::ScalarType dtype, bool mutable_) { |
| 216 | +#define DEFINE_CASE(T, name) \ |
| 217 | + case torch::headeronly::ScalarType::name: { \ |
| 218 | + if (mutable_) { \ |
| 219 | + return reinterpret_cast<uint64_t>(t.mutable_data_ptr<T>()); \ |
| 220 | + } else { \ |
| 221 | + return reinterpret_cast<uint64_t>(t.const_data_ptr<T>()); \ |
| 222 | + } \ |
| 223 | + } |
| 224 | + switch (dtype) { |
| 225 | + // per aten/src/ATen/templates/TensorMethods.cpp: |
| 226 | + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) |
| 227 | + DEFINE_CASE(uint16_t, UInt16) |
| 228 | + DEFINE_CASE(uint32_t, UInt32) |
| 229 | + DEFINE_CASE(uint64_t, UInt64) |
| 230 | + default: |
| 231 | + return 0; |
| 232 | + } |
| 233 | +#undef DEFINE_CASE |
| 234 | +} |
| 235 | + |
| 236 | +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { |
| 237 | + m.def("get_any_data_ptr(Tensor t, bool mutable_) -> int"); |
| 238 | + m.def("get_template_any_data_ptr(Tensor t, ScalarType dtype, bool mutable_) -> int"); |
| 239 | +} |
| 240 | + |
| 241 | +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { |
| 242 | + m.impl("get_any_data_ptr", TORCH_BOX(&get_any_data_ptr)); |
| 243 | + m.impl("get_template_any_data_ptr", TORCH_BOX(&get_template_any_data_ptr)); |
| 244 | +} |
0 commit comments