
Commit 4c127f1

mikaylagawarecki authored and pytorchmergebot committed

Split libtorch agnostic tests by feature version (pytorch#167803)

Tests are split into libtorch_agnostic_2_9_extension and libtorch_agnostic_2_10_extension depending on the minimum version they should compile+run in.

Pull Request resolved: pytorch#167803
Approved by: https://github.com/janeyx99
ghstack dependencies: pytorch#168025, pytorch#167802

1 parent 3beb378 commit 4c127f1

File tree

14 files changed: +1179 −1039 lines
Lines changed: 244 additions & 0 deletions

@@ -0,0 +1,244 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/core/ScalarType.h>

#ifdef LAE_USE_CUDA
#include <cuda_runtime.h>
#endif

#include <optional>

using torch::stable::Tensor;

std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
  std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
  aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
  return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}

void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
  std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
  aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
}

Tensor my_clone(Tensor t) {
  return clone(t);
}

std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
  // This function tests that my__foreach_mul can take in std::initializer_lists
  // in addition to std::vectors.
  Tensor t1_1 = my_clone(t1);
  Tensor t1_2 = my_clone(t1);
  Tensor t2_1 = my_clone(t2);
  Tensor t2_2 = my_clone(t2);
  return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
  m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
  m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
  m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
  m.impl("my__foreach_mul", TORCH_BOX(&my__foreach_mul));
  m.impl("my__foreach_mul_", TORCH_BOX(&my__foreach_mul_));
  m.impl("make_tensor_clones_and_call_foreach", TORCH_BOX(&make_tensor_clones_and_call_foreach));
}

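As the comment in make_tensor_clones_and_call_foreach notes, HeaderOnlyArrayRef binds to both std::vector and std::initializer_list. A minimal sketch of a hypothetical caller, not part of this diff (a, b, c, d are assumed to be Tensors):

// Hypothetical caller, not part of the diff: both argument forms are accepted
// because HeaderOnlyArrayRef converts from either container type.
std::vector<Tensor> lhs = {a, b};
std::vector<Tensor> rhs = {c, d};
std::vector<Tensor> r1 = my__foreach_mul(lhs, rhs);        // from std::vector
std::vector<Tensor> r2 = my__foreach_mul({a, b}, {c, d});  // from initializer_list
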
// Test functions for torch::stable::Tensor device method

torch::stable::Device test_tensor_device(torch::stable::Tensor tensor) {
  return tensor.device();
}

// Test functions for torch::stable::Device

torch::stable::Device test_device_constructor(
    bool is_cuda,
    torch::stable::DeviceIndex index,
    bool use_str) {
  using torch::stable::Device;
  using torch::stable::DeviceType;

  if (use_str) {
    std::string device_str;
    if (is_cuda) {
      device_str = "cuda:" + std::to_string(index);
    } else {
      device_str = "cpu";
    }
    return Device(device_str);
  } else {
    if (is_cuda) {
      return Device(DeviceType::CUDA, index);
    } else {
      return Device(DeviceType::CPU);
    }
  }
}

bool test_device_equality(torch::stable::Device d1, torch::stable::Device d2) {
  return d1 == d2;
}

torch::stable::Device test_device_set_index(
    torch::stable::Device device,
    torch::stable::DeviceIndex index) {
  device.set_index(index);
  return device;
}

torch::stable::DeviceIndex test_device_index(torch::stable::Device device) {
  return device.index();
}

bool test_device_is_cuda(torch::stable::Device device) {
  return device.is_cuda();
}

bool test_device_is_cpu(torch::stable::Device device) {
  return device.is_cpu();
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
  m.def("test_tensor_device(Tensor t) -> Device");
  m.def(
      "test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device");
  m.def("test_device_equality(Device d1, Device d2) -> bool");
  m.def("test_device_set_index(Device device, DeviceIndex index) -> Device");
  m.def("test_device_index(Device device) -> DeviceIndex");
  m.def("test_device_is_cuda(Device device) -> bool");
  m.def("test_device_is_cpu(Device device) -> bool");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
  m.impl("test_tensor_device", TORCH_BOX(&test_tensor_device));
  m.impl("test_device_constructor", TORCH_BOX(&test_device_constructor));
  m.impl("test_device_equality", TORCH_BOX(&test_device_equality));
  m.impl("test_device_set_index", TORCH_BOX(&test_device_set_index));
  m.impl("test_device_index", TORCH_BOX(&test_device_index));
  m.impl("test_device_is_cuda", TORCH_BOX(&test_device_is_cuda));
  m.impl("test_device_is_cpu", TORCH_BOX(&test_device_is_cpu));
}

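test_device_constructor exercises two construction paths (string vs. DeviceType enum) that should agree. A minimal sketch of a hypothetical check built only from the helpers above, not part of this diff:

// Hypothetical check, not part of the diff: string and enum construction of
// the same device are expected to compare equal.
torch::stable::Device by_enum = test_device_constructor(/*is_cuda=*/true, /*index=*/3, /*use_str=*/false);
torch::stable::Device by_str = test_device_constructor(/*is_cuda=*/true, /*index=*/3, /*use_str=*/true);
STD_TORCH_CHECK(test_device_equality(by_enum, by_str));
STD_TORCH_CHECK(test_device_is_cuda(by_enum) && test_device_index(by_enum) == 3);
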
Tensor test_parallel_for(int64_t size, int64_t grain_size) {
  AtenTensorHandle tensor_handle;
  int64_t stride = 1;

  aoti_torch_empty_strided(
      1,
      &size,
      &stride,
      aoti_torch_dtype_int64(),
      aoti_torch_device_type_cpu(),
      0,
      &tensor_handle);

  Tensor tensor(tensor_handle);
  int64_t* data_ptr = reinterpret_cast<int64_t*>(tensor.data_ptr());

  torch::stable::zero_(tensor);

  // Use parallel_for to fill each element with its index
  // If using a parallel path, the thread id is encoded in the upper 32 bits
  torch::stable::parallel_for(
      0, size, grain_size, [data_ptr](int64_t begin, int64_t end) {
        for (auto i = begin; i < end; i++) {
          STD_TORCH_CHECK(i <= UINT32_MAX);
          uint32_t thread_id;
          torch_get_thread_idx(&thread_id);
          data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32);
        }
      });

  return tensor;
}

uint32_t test_get_num_threads() {
  return torch::stable::get_num_threads();
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
  m.def("test_parallel_for(int size, int grain_size) -> Tensor");
  m.def("test_get_num_threads() -> int");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
  m.impl("test_parallel_for", TORCH_BOX(&test_parallel_for));
  m.impl("test_get_num_threads", TORCH_BOX(&test_get_num_threads));
}

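test_parallel_for packs the element index into the low 32 bits and the visiting thread id into the high 32 bits of each int64 element. A minimal decoding sketch, hypothetical and not part of this diff (it assumes thread ids stay below get_num_threads()):

// Hypothetical check, not part of the diff: unpack the values written by
// test_parallel_for and verify the low bits recover the index.
Tensor packed = test_parallel_for(/*size=*/8, /*grain_size=*/2);
const int64_t* p = reinterpret_cast<const int64_t*>(packed.const_data_ptr());
for (int64_t i = 0; i < 8; i++) {
  int64_t index = p[i] & 0xFFFFFFFF;                       // low 32 bits
  uint32_t thread_id = static_cast<uint32_t>(p[i] >> 32);  // high 32 bits
  STD_TORCH_CHECK(index == i);
  STD_TORCH_CHECK(thread_id < test_get_num_threads());
}
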
Tensor my_empty(
    torch::headeronly::HeaderOnlyArrayRef<int64_t> size,
    std::optional<torch::headeronly::ScalarType> dtype,
    std::optional<torch::stable::Device> device,
    std::optional<bool> pin_memory) {
  return empty(size, dtype, device, pin_memory);
}

Tensor my_reshape(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> shape) {
  return reshape(t, shape);
}

Tensor my_view(Tensor t, torch::headeronly::HeaderOnlyArrayRef<int64_t> size) {
  return view(t, size);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
  m.def(
      "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor");
  m.def("my_reshape(Tensor t, int[] shape) -> Tensor");
  m.def("my_view(Tensor t, int[] size) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
  m.impl("my_empty", TORCH_BOX(&my_empty));
  m.impl("my_reshape", TORCH_BOX(&my_reshape));
  m.impl("my_view", TORCH_BOX(&my_view));
}

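my_empty mirrors the "ScalarType? dtype=None, Device? device=None, bool? pin_memory=None" schema with std::optional parameters, so defaulted arguments map to std::nullopt on the C++ side. A hypothetical caller, not part of this diff:

// Hypothetical usage, not part of the diff: pass std::nullopt for any
// schema argument that should take its default.
Tensor t = my_empty(
    {2, 3},
    torch::headeronly::ScalarType::Float,
    torch::stable::Device(torch::stable::DeviceType::CPU),
    /*pin_memory=*/std::nullopt);
Tensor flat = my_view(my_reshape(t, {3, 2}), {6});  // 2x3 -> 3x2 -> 6
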
uint64_t get_any_data_ptr(Tensor t, bool mutable_) {
  if (mutable_) {
    return reinterpret_cast<uint64_t>(t.mutable_data_ptr());
  } else {
    return reinterpret_cast<uint64_t>(t.const_data_ptr());
  }
}

uint64_t get_template_any_data_ptr(Tensor t, c10::ScalarType dtype, bool mutable_) {
#define DEFINE_CASE(T, name)                                        \
  case torch::headeronly::ScalarType::name: {                       \
    if (mutable_) {                                                 \
      return reinterpret_cast<uint64_t>(t.mutable_data_ptr<T>());   \
    } else {                                                        \
      return reinterpret_cast<uint64_t>(t.const_data_ptr<T>());     \
    }                                                               \
  }
  switch (dtype) {
    // per aten/src/ATen/templates/TensorMethods.cpp:
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
    DEFINE_CASE(uint16_t, UInt16)
    DEFINE_CASE(uint32_t, UInt32)
    DEFINE_CASE(uint64_t, UInt64)
    default:
      return 0;
  }
#undef DEFINE_CASE
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
  m.def("get_any_data_ptr(Tensor t, bool mutable_) -> int");
  m.def("get_template_any_data_ptr(Tensor t, ScalarType dtype, bool mutable_) -> int");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
  m.impl("get_any_data_ptr", TORCH_BOX(&get_any_data_ptr));
  m.impl("get_template_any_data_ptr", TORCH_BOX(&get_template_any_data_ptr));
}
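
The typed and untyped accessors should land on the same address when the requested dtype matches the tensor's. A hypothetical agreement check, not part of this diff (t is assumed to be a float Tensor):

// Hypothetical check, not part of the diff: for a float tensor, the raw and
// dtype-templated const data pointers should be identical.
uint64_t raw = get_any_data_ptr(t, /*mutable_=*/false);
uint64_t typed = get_template_any_data_ptr(
    t, torch::headeronly::ScalarType::Float, /*mutable_=*/false);
STD_TORCH_CHECK(raw == typed);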
