diff --git a/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h index e53999cd355..d514a101383 100644 --- a/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h +++ b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h @@ -7,3 +7,40 @@ */ #import + +#ifdef __cplusplus + +#import + +namespace executorch::extension::utils { +using namespace aten; + +/** + * Deduces the scalar type for a given NSNumber based on its type encoding. + * + * @param number The NSNumber instance whose scalar type is to be deduced. + * @return The corresponding ScalarType. + */ +static inline ScalarType deduceScalarType(NSNumber *number) { + auto type = [number objCType][0]; + type = (type >= 'A' && type <= 'Z') ? type + ('a' - 'A') : type; + if (type == 'c') { + return ScalarType::Byte; + } else if (type == 's') { + return ScalarType::Short; + } else if (type == 'i') { + return ScalarType::Int; + } else if (type == 'q' || type == 'l') { + return ScalarType::Long; + } else if (type == 'f') { + return ScalarType::Float; + } else if (type == 'd') { + return ScalarType::Double; + } + ET_CHECK_MSG(false, "Unsupported type: %c", type); + return ScalarType::Undefined; +} + +} // namespace executorch::extension::utils + +#endif // __cplusplus