@@ -50,14 +50,36 @@ AOTITorchError aoti_torch_get_storage_offset(
5050}
5151
5252AOTITorchError aoti_torch_get_strides (Tensor* tensor, int64_t ** ret_strides) {
53- std::vector<int64_t > strides (tensor->dim ());
54- auto tensor_strides = tensor->strides ();
55- for (ssize_t i = 0 ; i < tensor->dim (); i++) {
56- strides[i] = static_cast <int64_t >(tensor_strides[i]);
53+ auto it = internal::tensor_to_strides.find (tensor);
54+ bool needs_update = false ;
55+
56+ if (it == internal::tensor_to_strides.end ()) {
57+ needs_update = true ;
58+ } else {
59+ // Check if cached values are still valid
60+ auto tensor_strides = tensor->strides ();
61+ if (it->second .size () != static_cast <size_t >(tensor->dim ())) {
62+ needs_update = true ;
63+ } else {
64+ for (int i = 0 ; i < tensor->dim (); i++) {
65+ if (it->second [i] != tensor_strides[i]) {
66+ needs_update = true ;
67+ break ;
68+ }
69+ }
70+ }
71+ }
72+
73+ if (needs_update) {
74+ std::vector<int64_t > strides (tensor->dim ());
75+ auto tensor_strides = tensor->strides ();
76+ for (int i = 0 ; i < tensor->dim (); i++) {
77+ strides[i] = tensor_strides[i];
78+ }
79+ it =
80+ internal::tensor_to_strides.insert_or_assign (tensor, std::move (strides))
81+ .first ;
5782 }
58- auto it =
59- internal::tensor_to_strides.insert_or_assign (tensor, std::move (strides))
60- .first ;
6183
6284 // For 0D tensors, data() returns nullptr on empty vectors, but we need to
6385 // return a valid pointer
@@ -78,13 +100,35 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
78100}
79101
80102AOTITorchError aoti_torch_get_sizes (Tensor* tensor, int64_t ** ret_sizes) {
81- std::vector<int64_t > sizes (tensor->dim ());
82- auto tensor_sizes = tensor->sizes ();
83- for (ssize_t i = 0 ; i < tensor->dim (); i++) {
84- sizes[i] = static_cast <int64_t >(tensor_sizes[i]);
103+ auto it = internal::tensor_to_sizes.find (tensor);
104+ bool needs_update = false ;
105+
106+ if (it == internal::tensor_to_sizes.end ()) {
107+ needs_update = true ;
108+ } else {
109+ // Check if cached values are still valid
110+ auto tensor_sizes = tensor->sizes ();
111+ if (it->second .size () != static_cast <size_t >(tensor->dim ())) {
112+ needs_update = true ;
113+ } else {
114+ for (int i = 0 ; i < tensor->dim (); i++) {
115+ if (it->second [i] != tensor_sizes[i]) {
116+ needs_update = true ;
117+ break ;
118+ }
119+ }
120+ }
121+ }
122+
123+ if (needs_update) {
124+ std::vector<int64_t > sizes (tensor->dim ());
125+ auto tensor_sizes = tensor->sizes ();
126+ for (int i = 0 ; i < tensor->dim (); i++) {
127+ sizes[i] = tensor_sizes[i];
128+ }
129+ it = internal::tensor_to_sizes.insert_or_assign (tensor, std::move (sizes))
130+ .first ;
85131 }
86- auto it = internal::tensor_to_sizes.insert_or_assign (tensor, std::move (sizes))
87- .first ;
88132
89133 // For 0D tensors, data() returns nullptr on empty vectors, but we need to
90134 // return a valid pointer
0 commit comments