@@ -50,36 +50,14 @@ AOTITorchError aoti_torch_get_storage_offset(
5050}
5151
5252AOTITorchError aoti_torch_get_strides (Tensor* tensor, int64_t ** ret_strides) {
53- auto it = internal::tensor_to_strides.find (tensor);
54- bool needs_update = false ;
55-
56- if (it == internal::tensor_to_strides.end ()) {
57- needs_update = true ;
58- } else {
59- // Check if cached values are still valid
60- auto tensor_strides = tensor->strides ();
61- if (it->second .size () != static_cast <size_t >(tensor->dim ())) {
62- needs_update = true ;
63- } else {
64- for (int i = 0 ; i < tensor->dim (); i++) {
65- if (it->second [i] != tensor_strides[i]) {
66- needs_update = true ;
67- break ;
68- }
69- }
70- }
71- }
72-
73- if (needs_update) {
74- std::vector<int64_t > strides (tensor->dim ());
75- auto tensor_strides = tensor->strides ();
76- for (int i = 0 ; i < tensor->dim (); i++) {
77- strides[i] = tensor_strides[i];
78- }
79- it =
80- internal::tensor_to_strides.insert_or_assign (tensor, std::move (strides))
81- .first ;
53+ std::vector<int64_t > strides (tensor->dim ());
54+ auto tensor_strides = tensor->strides ();
55+ for (ssize_t i = 0 ; i < tensor->dim (); i++) {
56+ strides[i] = static_cast <int64_t >(tensor_strides[i]);
8257 }
58+ auto it =
59+ internal::tensor_to_strides.insert_or_assign (tensor, std::move (strides))
60+ .first ;
8361
8462 // For 0D tensors, data() returns nullptr on empty vectors, but we need to
8563 // return a valid pointer
@@ -100,35 +78,13 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
10078}
10179
10280AOTITorchError aoti_torch_get_sizes (Tensor* tensor, int64_t ** ret_sizes) {
103- auto it = internal::tensor_to_sizes.find (tensor);
104- bool needs_update = false ;
105-
106- if (it == internal::tensor_to_sizes.end ()) {
107- needs_update = true ;
108- } else {
109- // Check if cached values are still valid
110- auto tensor_sizes = tensor->sizes ();
111- if (it->second .size () != static_cast <size_t >(tensor->dim ())) {
112- needs_update = true ;
113- } else {
114- for (int i = 0 ; i < tensor->dim (); i++) {
115- if (it->second [i] != tensor_sizes[i]) {
116- needs_update = true ;
117- break ;
118- }
119- }
120- }
121- }
122-
123- if (needs_update) {
124- std::vector<int64_t > sizes (tensor->dim ());
125- auto tensor_sizes = tensor->sizes ();
126- for (int i = 0 ; i < tensor->dim (); i++) {
127- sizes[i] = tensor_sizes[i];
128- }
129- it = internal::tensor_to_sizes.insert_or_assign (tensor, std::move (sizes))
130- .first ;
81+ std::vector<int64_t > sizes (tensor->dim ());
82+ auto tensor_sizes = tensor->sizes ();
83+ for (ssize_t i = 0 ; i < tensor->dim (); i++) {
84+ sizes[i] = static_cast <int64_t >(tensor_sizes[i]);
13185 }
86+ auto it = internal::tensor_to_sizes.insert_or_assign (tensor, std::move (sizes))
87+ .first ;
13288
13389 // For 0D tensors, data() returns nullptr on empty vectors, but we need to
13490 // return a valid pointer
0 commit comments