Skip to content

Commit 54af621

Browse files
Add shape inference for AtenAsStrided (#4076)
Added shape inference for AtenAsStridedOp
1 parent 4c6e463 commit 54af621

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

projects/ltc/csrc/base_lazy_backend/shape_inference.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,13 @@ std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
513513
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
514514
}
515515

516+
std::vector<torch::lazy::Shape>
517+
compute_shape_as_strided(const at::Tensor &self, at::IntArrayRef size,
518+
at::IntArrayRef stride,
519+
c10::optional<int64_t> storage_offset) {
520+
return {Shape(self.scalar_type(), size.vec())};
521+
}
522+
516523
std::vector<torch::lazy::Shape> compute_shape_roll(const at::Tensor &self,
517524
at::IntArrayRef shifts,
518525
at::IntArrayRef dims) {

0 commit comments

Comments
 (0)