Skip to content

Commit 4f41a69

Browse files
committed
[slimtensor] Add storage and device property getters to common_shims_slim
Add storage and device property getter AOTI shim functions to the header-only common_shims_slim library: 1. `aoti_torch_get_storage_offset()` - Returns the storage offset (SlimTensor: real offset, ETensor: always 0) 2. `aoti_torch_get_storage_size()` - Returns storage size in bytes 3. `aoti_torch_get_device_type()` - Returns device type (SlimTensor: real type, ETensor: CPU=0) 4. `aoti_torch_get_device_index()` - Returns device index (SlimTensor: real index, ETensor: 0) Differential Revision: [D90126251](https://our.internmc.facebook.com/intern/diff/D90126251/) ghstack-source-id: 331923140 Pull Request resolved: #16455
1 parent 048ba58 commit 4f41a69

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

backends/aoti/common_shims_slim.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,84 @@ inline AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
210210
return Error::Ok;
211211
}
212212

213+
// ============================================================
214+
// Storage & Device Property Getters - Inline implementations
215+
// ============================================================
216+
217+
inline AOTITorchError aoti_torch_get_storage_offset(
218+
Tensor* tensor,
219+
int64_t* ret_storage_offset) {
220+
if (tensor == nullptr) {
221+
return Error::InvalidArgument;
222+
}
223+
if (ret_storage_offset == nullptr) {
224+
return Error::InvalidArgument;
225+
}
226+
227+
#ifdef CUDA_AVAILABLE
228+
// SlimTensor supports real storage offset
229+
*ret_storage_offset = tensor->storage_offset();
230+
#else
231+
// ETensor doesn't support storage_offset, return 0
232+
*ret_storage_offset = 0;
233+
#endif
234+
return Error::Ok;
235+
}
236+
237+
inline AOTITorchError aoti_torch_get_storage_size(
238+
Tensor* tensor,
239+
int64_t* ret_size) {
240+
if (tensor == nullptr) {
241+
return Error::InvalidArgument;
242+
}
243+
if (ret_size == nullptr) {
244+
return Error::InvalidArgument;
245+
}
246+
247+
*ret_size = static_cast<int64_t>(tensor->nbytes());
248+
return Error::Ok;
249+
}
250+
251+
inline AOTITorchError aoti_torch_get_device_type(
252+
Tensor* tensor,
253+
int32_t* ret_device_type) {
254+
if (tensor == nullptr) {
255+
return Error::InvalidArgument;
256+
}
257+
if (ret_device_type == nullptr) {
258+
return Error::InvalidArgument;
259+
}
260+
261+
#ifdef CUDA_AVAILABLE
262+
// SlimTensor supports real device type
263+
*ret_device_type = static_cast<int32_t>(tensor->device_type());
264+
#else
265+
// ETensor is always CPU in default mode
266+
*ret_device_type = 0; // CPU
267+
#endif
268+
return Error::Ok;
269+
}
270+
271+
inline AOTITorchError aoti_torch_get_device_index(
272+
Tensor* tensor,
273+
int32_t* ret_device_index) {
274+
if (tensor == nullptr) {
275+
return Error::InvalidArgument;
276+
}
277+
if (ret_device_index == nullptr) {
278+
return Error::InvalidArgument;
279+
}
280+
281+
#ifdef CUDA_AVAILABLE
282+
// SlimTensor supports real device index
283+
*ret_device_index = static_cast<int32_t>(tensor->device_index());
284+
#else
285+
// ETensor doesn't support multi-device, return 0
286+
*ret_device_index = 0;
287+
#endif
288+
return Error::Ok;
289+
}
290+
213291
} // namespace aoti
214292
} // namespace backends
215293
} // namespace executorch

backends/aoti/tests/test_common_shims_slim.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,93 @@ void runGetDimTest(slim_c10::DeviceType device_type) {
289289
}
290290
}
291291

292+
// ============================================================================
293+
// Storage & Device Property Tests
294+
// ============================================================================
295+
296+
void runGetStorageOffsetTest(slim_c10::DeviceType device_type) {
297+
std::vector<int64_t> sizes = {2, 3};
298+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
299+
slim_c10::Device device(device_type, 0);
300+
301+
Tensor* tensor = new Tensor(slim::empty_strided(
302+
slim::makeArrayRef(sizes),
303+
slim::makeArrayRef(strides),
304+
slim_c10::ScalarType::Float,
305+
device));
306+
307+
int64_t ret_storage_offset = -1;
308+
AOTITorchError error =
309+
aoti_torch_get_storage_offset(tensor, &ret_storage_offset);
310+
311+
EXPECT_EQ(error, Error::Ok);
312+
// Default storage offset for newly created tensor is 0
313+
EXPECT_EQ(ret_storage_offset, 0);
314+
315+
delete tensor;
316+
}
317+
318+
void runGetStorageSizeTest(slim_c10::DeviceType device_type) {
319+
std::vector<int64_t> sizes = {2, 3};
320+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
321+
slim_c10::Device device(device_type, 0);
322+
323+
Tensor* tensor = new Tensor(slim::empty_strided(
324+
slim::makeArrayRef(sizes),
325+
slim::makeArrayRef(strides),
326+
slim_c10::ScalarType::Float,
327+
device));
328+
329+
int64_t ret_size = -1;
330+
AOTITorchError error = aoti_torch_get_storage_size(tensor, &ret_size);
331+
332+
EXPECT_EQ(error, Error::Ok);
333+
// 2 * 3 * sizeof(float) = 6 * 4 = 24 bytes
334+
EXPECT_EQ(ret_size, 24);
335+
336+
delete tensor;
337+
}
338+
339+
void runGetDeviceTypeTest(slim_c10::DeviceType device_type) {
340+
std::vector<int64_t> sizes = {2, 3};
341+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
342+
slim_c10::Device device(device_type, 0);
343+
344+
Tensor* tensor = new Tensor(slim::empty_strided(
345+
slim::makeArrayRef(sizes),
346+
slim::makeArrayRef(strides),
347+
slim_c10::ScalarType::Float,
348+
device));
349+
350+
int32_t ret_device_type = -1;
351+
AOTITorchError error = aoti_torch_get_device_type(tensor, &ret_device_type);
352+
353+
EXPECT_EQ(error, Error::Ok);
354+
EXPECT_EQ(ret_device_type, static_cast<int32_t>(device_type));
355+
356+
delete tensor;
357+
}
358+
359+
void runGetDeviceIndexTest(slim_c10::DeviceType device_type) {
360+
std::vector<int64_t> sizes = {2, 3};
361+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
362+
slim_c10::Device device(device_type, 0);
363+
364+
Tensor* tensor = new Tensor(slim::empty_strided(
365+
slim::makeArrayRef(sizes),
366+
slim::makeArrayRef(strides),
367+
slim_c10::ScalarType::Float,
368+
device));
369+
370+
int32_t ret_device_index = -1;
371+
AOTITorchError error = aoti_torch_get_device_index(tensor, &ret_device_index);
372+
373+
EXPECT_EQ(error, Error::Ok);
374+
EXPECT_EQ(ret_device_index, 0);
375+
376+
delete tensor;
377+
}
378+
292379
// ============================================================================
293380
// CPU Tests
294381
// ============================================================================
@@ -313,6 +400,22 @@ TEST_F(CommonShimsSlimTest, GetDim_CPU) {
313400
runGetDimTest(slim_c10::DeviceType::CPU);
314401
}
315402

403+
TEST_F(CommonShimsSlimTest, GetStorageOffset_CPU) {
404+
runGetStorageOffsetTest(slim_c10::DeviceType::CPU);
405+
}
406+
407+
TEST_F(CommonShimsSlimTest, GetStorageSize_CPU) {
408+
runGetStorageSizeTest(slim_c10::DeviceType::CPU);
409+
}
410+
411+
TEST_F(CommonShimsSlimTest, GetDeviceType_CPU) {
412+
runGetDeviceTypeTest(slim_c10::DeviceType::CPU);
413+
}
414+
415+
TEST_F(CommonShimsSlimTest, GetDeviceIndex_CPU) {
416+
runGetDeviceIndexTest(slim_c10::DeviceType::CPU);
417+
}
418+
316419
// ============================================================================
317420
// CUDA Tests
318421
// ============================================================================
@@ -352,6 +455,34 @@ TEST_F(CommonShimsSlimTest, GetDim_CUDA) {
352455
}
353456
runGetDimTest(slim_c10::DeviceType::CUDA);
354457
}
458+
459+
TEST_F(CommonShimsSlimTest, GetStorageOffset_CUDA) {
460+
if (!isCudaAvailable()) {
461+
GTEST_SKIP() << "CUDA not available";
462+
}
463+
runGetStorageOffsetTest(slim_c10::DeviceType::CUDA);
464+
}
465+
466+
TEST_F(CommonShimsSlimTest, GetStorageSize_CUDA) {
467+
if (!isCudaAvailable()) {
468+
GTEST_SKIP() << "CUDA not available";
469+
}
470+
runGetStorageSizeTest(slim_c10::DeviceType::CUDA);
471+
}
472+
473+
TEST_F(CommonShimsSlimTest, GetDeviceType_CUDA) {
474+
if (!isCudaAvailable()) {
475+
GTEST_SKIP() << "CUDA not available";
476+
}
477+
runGetDeviceTypeTest(slim_c10::DeviceType::CUDA);
478+
}
479+
480+
TEST_F(CommonShimsSlimTest, GetDeviceIndex_CUDA) {
481+
if (!isCudaAvailable()) {
482+
GTEST_SKIP() << "CUDA not available";
483+
}
484+
runGetDeviceIndexTest(slim_c10::DeviceType::CUDA);
485+
}
355486
#endif
356487

357488
// ============================================================================

0 commit comments

Comments
 (0)