
Commit d4ea20b

[slimtensor] Add utility functions to common_shims_slim
Add utility functions to the header-only common_shims_slim library:

1. DType constants:
   - `aoti_torch_dtype_float32()` - Returns 6 (ScalarType::Float)
   - `aoti_torch_dtype_bfloat16()` - Returns 15 (ScalarType::BFloat16)
   - `aoti_torch_dtype_int64()` - Returns 4 (ScalarType::Long)
   - `aoti_torch_dtype_int32()` - Returns 3 (ScalarType::Int)
   - `aoti_torch_dtype_int16()` - Returns 2 (ScalarType::Short)
   - `aoti_torch_dtype_int8()` - Returns 1 (ScalarType::Char)
   - `aoti_torch_dtype_bool()` - Returns 11 (ScalarType::Bool)

2. Device type constants:
   - `aoti_torch_device_type_cpu()` - Returns 0 (DeviceType::CPU)
   - `aoti_torch_device_type_cuda()` - Returns 1 (DeviceType::CUDA)

3. Grad mode functions (not supported in ExecuTorch):
   - `aoti_torch_grad_mode_is_enabled()` - Always returns false
   - `aoti_torch_grad_mode_set_enabled()` - Returns Ok for false, NotSupported for true

Differential Revision: [D90126250](https://our.internmc.facebook.com/intern/diff/D90126250/)

ghstack-source-id: 331923139
Pull Request resolved: #16457
1 parent 4f41a69 commit d4ea20b
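
For orientation, a minimal caller-side sketch (hypothetical, not part of this commit) of the new shims. It assumes the header is included as `executorch/backends/aoti/common_shims_slim.h` and that `Error`/`AOTITorchError` are visible through the `executorch::backends::aoti` namespace, as in the header diff below.

// Hypothetical usage sketch, not part of this commit.
#include <executorch/backends/aoti/common_shims_slim.h> // assumed include path

#include <cstdio>

int main() {
  using namespace executorch::backends::aoti;

  // DType and device codes mirror PyTorch's ScalarType / DeviceType values.
  std::printf(
      "float32=%d bfloat16=%d cpu=%d cuda=%d\n",
      aoti_torch_dtype_float32(),
      aoti_torch_dtype_bfloat16(),
      aoti_torch_device_type_cpu(),
      aoti_torch_device_type_cuda());

  // Grad mode: only "disabled" is accepted, since ExecuTorch has no autograd.
  if (aoti_torch_grad_mode_set_enabled(true) != Error::Ok) {
    std::printf("enabling grad mode is not supported\n");
  }
  return aoti_torch_grad_mode_is_enabled() ? 1 : 0;
}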

File tree

2 files changed: 100 additions, 0 deletions


backends/aoti/common_shims_slim.h

Lines changed: 59 additions & 0 deletions
@@ -288,6 +288,65 @@ inline AOTITorchError aoti_torch_get_device_index(
   return Error::Ok;
 }
 
+// ============================================================
+// DType Constants - These return PyTorch ScalarType enum values
+// ============================================================
+
+inline int32_t aoti_torch_dtype_float32() {
+  return 6; // ScalarType::Float
+}
+
+inline int32_t aoti_torch_dtype_bfloat16() {
+  return 15; // ScalarType::BFloat16
+}
+
+inline int32_t aoti_torch_dtype_int64() {
+  return 4; // ScalarType::Long
+}
+
+inline int32_t aoti_torch_dtype_int32() {
+  return 3; // ScalarType::Int
+}
+
+inline int32_t aoti_torch_dtype_int16() {
+  return 2; // ScalarType::Short
+}
+
+inline int32_t aoti_torch_dtype_int8() {
+  return 1; // ScalarType::Char
+}
+
+inline int32_t aoti_torch_dtype_bool() {
+  return 11; // ScalarType::Bool
+}
+
+// ============================================================
+// Device Type Constants
+// ============================================================
+
+inline int32_t aoti_torch_device_type_cpu() {
+  return 0; // DeviceType::CPU
+}
+
+inline int32_t aoti_torch_device_type_cuda() {
+  return 1; // DeviceType::CUDA
+}
+
+// ============================================================
+// Grad Mode Functions (not supported in ExecuTorch)
+// ============================================================
+
+inline bool aoti_torch_grad_mode_is_enabled() {
+  return false; // ExecuTorch doesn't support autograd
+}
+
+inline AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled) {
+  if (enabled) {
+    return Error::NotSupported; // Grad mode not supported in ExecuTorch
+  }
+  return Error::Ok;
+}
+
 } // namespace aoti
 } // namespace backends
 } // namespace executorch

backends/aoti/tests/test_common_shims_slim.cpp

Lines changed: 41 additions & 0 deletions
@@ -589,3 +589,44 @@ TEST_F(CommonShimsSlimTest, ConsistentPointerReturn) {
   EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok);
   EXPECT_EQ(strides_ptr1, strides_ptr2);
 }
+
+// ============================================================================
+// DType Constants Tests
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, DTypeConstants) {
+  // Verify dtype constants match expected PyTorch ScalarType values
+  EXPECT_EQ(aoti_torch_dtype_float32(), 6); // ScalarType::Float
+  EXPECT_EQ(aoti_torch_dtype_bfloat16(), 15); // ScalarType::BFloat16
+  EXPECT_EQ(aoti_torch_dtype_int64(), 4); // ScalarType::Long
+  EXPECT_EQ(aoti_torch_dtype_int32(), 3); // ScalarType::Int
+  EXPECT_EQ(aoti_torch_dtype_int16(), 2); // ScalarType::Short
+  EXPECT_EQ(aoti_torch_dtype_int8(), 1); // ScalarType::Char
+  EXPECT_EQ(aoti_torch_dtype_bool(), 11); // ScalarType::Bool
+}
+
+// ============================================================================
+// Device Type Constants Tests
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, DeviceTypeConstants) {
+  EXPECT_EQ(aoti_torch_device_type_cpu(), 0); // DeviceType::CPU
+  EXPECT_EQ(aoti_torch_device_type_cuda(), 1); // DeviceType::CUDA
+}
+
+// ============================================================================
+// Grad Mode Tests
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, GradModeIsEnabled) {
+  // ExecuTorch doesn't support autograd, so should always return false
+  EXPECT_EQ(aoti_torch_grad_mode_is_enabled(), false);
+}
+
+TEST_F(CommonShimsSlimTest, GradModeSetEnabled) {
+  // Setting to false should succeed
+  EXPECT_EQ(aoti_torch_grad_mode_set_enabled(false), Error::Ok);
+
+  // Setting to true should fail (not supported in ExecuTorch)
+  EXPECT_EQ(aoti_torch_grad_mode_set_enabled(true), Error::NotSupported);
+}
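
As a further illustration (hypothetical, not in this diff), callers can branch on the dtype codes, for example to look up element sizes; the sketch below assumes `common_shims_slim.h` is included and uses the standard PyTorch `ScalarType` sizes for the seven dtypes exposed above.

// Hypothetical helper, not part of the commit: maps the dtype codes returned
// by the shims above to element sizes in bytes. Returns 0 for codes not
// covered by this commit.
#include <cstddef>
#include <cstdint>

inline std::size_t element_size_for_dtype(int32_t dtype) {
  using namespace executorch::backends::aoti;
  if (dtype == aoti_torch_dtype_int64()) {
    return 8; // int64
  }
  if (dtype == aoti_torch_dtype_float32() ||
      dtype == aoti_torch_dtype_int32()) {
    return 4; // float32, int32
  }
  if (dtype == aoti_torch_dtype_bfloat16() ||
      dtype == aoti_torch_dtype_int16()) {
    return 2; // bfloat16, int16
  }
  if (dtype == aoti_torch_dtype_int8() || dtype == aoti_torch_dtype_bool()) {
    return 1; // int8, bool
  }
  return 0; // unknown / unsupported dtype code
}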
