From be12417727ef3f2094f28706ab1448214539023b Mon Sep 17 00:00:00 2001 From: Olivia Liu Date: Mon, 28 Oct 2024 18:20:28 -0700 Subject: [PATCH] Add ScalarType 22 `BITS16` support in etdump gen and deserialization (#6504) Summary: Add support in etdump gen and etdump parsing for ScalarType 22 `BITS16`. Here's where the type is defined: https://www.internalfb.com/code/fbsource/[64a8013e3ebfbbaa70c220fc90102683d3181ad8]/fbcode/executorch/runtime/core/portable_type/scalar_type.h?lines=85 Note: `fbcode/executorch/schema/scalar_type.fbs`, `fbcode/executorch/devtools/etdump/scalar_type.fbs` and `fbcode/executorch/devtools/bundled_program/schema/scalar_type.fbs` are copies of each other to get around some build issue and there's a script to make sure they're in sync. Reviewed By: larryliu0820, Gasoonjia Differential Revision: D64812253 --- devtools/bundled_program/schema/scalar_type.fbs | 5 +++++ devtools/etdump/etdump_flatcc.cpp | 2 ++ devtools/etdump/scalar_type.fbs | 5 +++++ exir/scalar_type.py | 4 +++- exir/tensor.py | 2 +- schema/scalar_type.fbs | 5 +++++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/devtools/bundled_program/schema/scalar_type.fbs b/devtools/bundled_program/schema/scalar_type.fbs index a8da080c679..fc299ac691e 100644 --- a/devtools/bundled_program/schema/scalar_type.fbs +++ b/devtools/bundled_program/schema/scalar_type.fbs @@ -24,9 +24,14 @@ enum ScalarType : byte { QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, + BITS16 = 22, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, // BFLOAT16 = 15, + // BITS1x8 = 18, + // BITS2x4 = 19, + // BITS4x2 = 20, + // BITS8 = 21, } diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp index 4c05bb5acee..cfd1d2ae14d 100644 --- a/devtools/etdump/etdump_flatcc.cpp +++ b/devtools/etdump/etdump_flatcc.cpp @@ -55,6 +55,8 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type( return executorch_flatbuffer_ScalarType_DOUBLE; case exec_aten::ScalarType::Bool: return executorch_flatbuffer_ScalarType_BOOL; + case exec_aten::ScalarType::Bits16: + return executorch_flatbuffer_ScalarType_BITS16; default: ET_CHECK_MSG( 0, diff --git a/devtools/etdump/scalar_type.fbs b/devtools/etdump/scalar_type.fbs index a8da080c679..fc299ac691e 100644 --- a/devtools/etdump/scalar_type.fbs +++ b/devtools/etdump/scalar_type.fbs @@ -24,9 +24,14 @@ enum ScalarType : byte { QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, + BITS16 = 22, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, // BFLOAT16 = 15, + // BITS1x8 = 18, + // BITS2x4 = 19, + // BITS4x2 = 20, + // BITS8 = 21, } diff --git a/exir/scalar_type.py b/exir/scalar_type.py index b789a09f3a8..5d41038610b 100644 --- a/exir/scalar_type.py +++ b/exir/scalar_type.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from enum import IntEnum @@ -26,4 +28,4 @@ class ScalarType(IntEnum): BFLOAT16 = 15 QUINT4x2 = 16 QUINT2x4 = 17 - Bits16 = 22 + BITS16 = 22 diff --git a/exir/tensor.py b/exir/tensor.py index d63ed5d2627..a40bef4e5e0 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -262,7 +262,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int: torch.qint32: ScalarType.QINT32, torch.bfloat16: ScalarType.BFLOAT16, torch.quint4x2: ScalarType.QUINT4x2, - torch.uint16: ScalarType.Bits16, + torch.uint16: ScalarType.BITS16, } diff --git a/schema/scalar_type.fbs b/schema/scalar_type.fbs index a8da080c679..fc299ac691e 100644 --- a/schema/scalar_type.fbs +++ b/schema/scalar_type.fbs @@ -24,9 +24,14 @@ enum ScalarType : byte { QINT32 = 14, QUINT4X2 = 16, QUINT2X4 = 17, + BITS16 = 22, // Types currently not implemented. // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, // BFLOAT16 = 15, + // BITS1x8 = 18, + // BITS2x4 = 19, + // BITS4x2 = 20, + // BITS8 = 21, }