Skip to content

Commit 1aacb46

Browse files
Update shape_inference.cpp for randn.generator (#4062)
PR adds support for `aten.randn.generator`, the generator object (if passed) will not affect shape inference.
1 parent 0bb263e commit 1aacb46

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

projects/ltc/csrc/base_lazy_backend/shape_inference.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,14 @@ compute_shape_randn(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype,
470470
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
471471
}
472472

473+
std::vector<torch::lazy::Shape> compute_shape_randn(
474+
at::IntArrayRef size, ::std::optional<at::Generator> generator,
475+
::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout,
476+
::std::optional<at::Device> device, ::std::optional<bool> pin_memory) {
477+
return {
478+
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
479+
}
480+
473481
std::vector<torch::lazy::Shape> compute_shape_randint(
474482
int64_t high, at::IntArrayRef size, ::std::optional<at::ScalarType> dtype,
475483
::std::optional<at::Layout> layout, ::std::optional<at::Device> device,

0 commit comments

Comments
 (0)