Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/_serialize/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ runtime.python_library(
"_dataclass.py",
"_flatbuffer.py",
"_program.py",
"utils.py",
],
resources = {
"//executorch/schema:program.fbs": "program.fbs",
Expand Down
47 changes: 8 additions & 39 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import re

from dataclasses import dataclass
from typing import ClassVar, List, Literal, Optional, Tuple
from typing import ClassVar, List, Optional, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
Expand All @@ -21,6 +21,13 @@
_program_json_to_flatbuffer,
)

from executorch.exir._serialize.utils import (
_aligned_size,
_HEADER_BYTEORDER,
_pad_to,
_padding_required,
)

from executorch.exir.schema import (
BackendDelegateDataReference,
BackendDelegateInlineData,
Expand All @@ -33,12 +40,6 @@
from executorch.exir.tensor import ALIGNMENT


# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"


def _program_to_json(program: Program) -> str:
"""Returns the JSON representation of the given Program."""
return json.dumps(program, cls=_DataclassEncoder)
Expand All @@ -50,19 +51,6 @@ def _json_to_program(program_json: bytes) -> Program:
return _json_to_dataclass(json.loads(program_json), cls=Program)


def _padding_required(offset: int, alignment: int) -> int:
"""Returns the padding required to align `offset` to `alignment`."""
remainder: int = offset % alignment
if remainder != 0:
return alignment - remainder
return 0


def _aligned_size(input_size: int, alignment: int) -> int:
"""Returns input_size padded up to the next whole multiple of alignment."""
return input_size + _padding_required(input_size, alignment)


def _insert_flatbuffer_header(
flatbuffer_data: bytes, magic_regex: str, header_data: bytes
) -> bytes:
Expand Down Expand Up @@ -211,25 +199,6 @@ def to_bytes(self) -> bytes:
return data


def _pad_to(data: bytes, length: int) -> bytes:
"""Returns the input followed by enough zero bytes to become the requested length.

Args:
data: The data to pad.
length: The length of the returned data.
Returns:
The padded data.
Raises:
ValueError: If the requested length is less than the input length.
"""
if length < len(data):
raise ValueError(f"Data length {len(data)} > padded length {length}")
if length > len(data):
data = data + b"\x00" * (length - len(data))
assert len(data) == length
return data


def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
"""Returns the extended header of the program data, if present and valid."""
try:
Expand Down
42 changes: 42 additions & 0 deletions exir/_serialize/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

from typing import Literal

# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"


def _pad_to(data: bytes, length: int) -> bytes:
"""Returns the input followed by enough zero bytes to become the requested length.

Args:
data: The data to pad.
length: The length of the returned data.
Returns:
The padded data.
Raises:
ValueError: If the requested length is less than the input length.
"""
if length < len(data):
raise ValueError(f"Data length {len(data)} > padded length {length}")
if length > len(data):
data = data + b"\x00" * (length - len(data))
assert len(data) == length
return data


def _padding_required(offset: int, alignment: int) -> int:
"""Returns the padding required to align `offset` to `alignment`."""
remainder: int = offset % alignment
if remainder != 0:
return alignment - remainder
return 0


def _aligned_size(input_size: int, alignment: int) -> int:
"""Returns input_size padded up to the next whole multiple of alignment."""
return input_size + _padding_required(input_size, alignment)
Loading