Skip to content

Commit 63632fc

Browse files
janeyx99 authored and pytorchmergebot committed
Add new_zeros dtype variant to the shim and as a stable op (pytorch#161597)
In case we want this before 2.9. Pull Request resolved: pytorch#161597. Approved by: https://github.com/mikaylagawarecki
1 parent 05d0f11 commit 63632fc

File tree

6 files changed

+74
-1
lines changed

6 files changed

+74
-1
lines changed

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ void boxed_my_narrow(
343343

344344
// Returns a new uninitialized [2, 5] tensor with dtype BFloat16, using `t`
// only as the reference for device and other properties (via new_empty).
// Fix: resolved diff residue — the span contained both the old
// at::ScalarType line and the new torch::headeronly::ScalarType line; the
// committed (+) version is kept so the extension stays libtorch-agnostic.
Tensor my_new_empty_dtype_variant(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
  return new_empty(t, sizes, dtype);
}
349349

@@ -352,13 +352,25 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
352352
stack[0] = from(res);
353353
}
354354

355+
// Returns a new zero-filled [2, 5] tensor with dtype Float, using `t` as the
// reference for device and other properties (via new_zeros).
// Fix: use torch::headeronly::ScalarType rather than at::ScalarType — this
// same commit migrates the sibling my_new_empty_dtype_variant to the
// header-only enum, and the whole point of this test extension is to stay
// libtorch-agnostic, so the new op should not reintroduce an at:: dependency.
Tensor my_new_zeros_dtype_variant(Tensor t) {
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::Float);
  return new_zeros(t, sizes, dtype);
}
360+
361+
// Boxed calling-convention wrapper: unpacks the input tensor from the stack,
// invokes my_new_zeros_dtype_variant, and writes the result back to stack[0].
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor input = to<Tensor>(stack[0]);
  Tensor result = my_new_zeros_dtype_variant(input);
  stack[0] = from(result);
}
365+
355366
// Operator schema declarations for the libtorch_agnostic test extension.
// Implementations are registered separately in the STABLE_TORCH_LIBRARY_IMPL
// blocks below.
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
  m.def("my_empty_like(Tensor t) -> Tensor");
  m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
  m.def("my_pad(Tensor t) -> Tensor");
  m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
  m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
  m.def("my_new_zeros_dtype_variant(Tensor t) -> Tensor");
}
363375

364376
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -367,6 +379,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
367379
m.impl("fill_infinity", &boxed_fill_infinity);
368380
m.impl("my_is_cpu", &boxed_my_is_cpu);
369381
m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
382+
m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
370383
}
371384

372385
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {

test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,15 @@ def my_new_empty_dtype_variant(t) -> Tensor:
295295
Returns: New empty tensor with shape [2, 5] and dtype bfloat16
296296
"""
297297
return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t)
298+
299+
300+
def my_new_zeros_dtype_variant(t) -> Tensor:
    """
    Create a new [2, 5] tensor filled with zeros, with dtype Float, via the
    stable custom op.

    Args:
        t: input tensor used as a reference for device and other properties

    Returns: New zero-filled tensor
    """
    zeros_op = torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default
    return zeros_op(t)

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,14 @@ def test_my_new_empty_dtype_variant(self, device):
337337
finally:
338338
torch.use_deterministic_algorithms(deterministic)
339339

340+
def test_my_new_zeros_dtype_variant(self, device):
    # Compare the stable new_zeros shim against Tensor.new_zeros: both should
    # produce a [2, 5] float zeros tensor on the same device as the input.
    import libtorch_agnostic

    source = torch.randn(3, 4, device=device)
    result = libtorch_agnostic.ops.my_new_zeros_dtype_variant(source)
    expected = source.new_zeros((2, 5), dtype=torch.float)
    self.assertEqual(result, expected, exact_device=True)
347+
340348
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
341349

342350
if __name__ == "__main__":

torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, con
1818
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
1919
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
2020
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
21+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_zeros(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
2122
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
2223

2324
#ifdef __cplusplus

torch/csrc/stable/ops.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,44 @@ inline Tensor new_empty(
9090
return Tensor(ret0);
9191
}
9292

93+
// We expect this to be a stable version of the new_zeros op that takes in
94+
// only dtype information.
95+
inline Tensor new_zeros(
96+
const Tensor& self,
97+
std::vector<int64_t> size,
98+
std::optional<c10::ScalarType> dtype = std::nullopt) {
99+
int32_t device_type;
100+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
101+
102+
int32_t device_index;
103+
TORCH_ERROR_CODE_CHECK(
104+
aoti_torch_get_device_index(self.get(), &device_index));
105+
106+
int32_t target_dtype;
107+
if (dtype.has_value()) {
108+
target_dtype = to<int32_t>(from(dtype.value()));
109+
} else {
110+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
111+
}
112+
113+
int32_t layout;
114+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));
115+
116+
AtenTensorHandle ath;
117+
TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_zeros(
118+
self.get(),
119+
size.data(),
120+
static_cast<int64_t>(size.size()),
121+
&target_dtype,
122+
&layout,
123+
&device_type,
124+
device_index,
125+
nullptr, // pin_memory (nullptr for default)
126+
&ath));
127+
128+
return Tensor(ath);
129+
}
130+
93131
// We expect this to be the stable version of the pad.default op.
94132
// pad.default takes in a SymInt[] as the pad argument however pad is typed as
95133
// use std::vector<int64_t> because

torchgen/aoti/fallback_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,5 @@
187187
"aten.narrow.default": {},
188188
"aten.amax.default": {},
189189
"aten.new_empty.default": {},
190+
"aten.new_zeros.default": {},
190191
}

0 commit comments

Comments
 (0)