Skip to content

Commit b9a5a44

Browse files
authored
Merge branch 'main' into export-D83113293
2 parents 68d2629 + 9283b4e commit b9a5a44

File tree

18 files changed

+316
-377
lines changed

18 files changed

+316
-377
lines changed

.ci/scripts/unittest-buck2.sh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,17 @@ BUILDABLE_KERNELS_PRIM_OPS_TARGETS=$(buck2 query //kernels/prim_ops/... | grep -
3535
for op in "build" "test"; do
3636
buck2 $op $BUILDABLE_OPTIMIZED_OPS \
3737
//examples/selective_build:select_all_dtype_selective_lib_portable_lib \
38+
//extension/llm/custom_ops/spinquant/test:fast_hadamard_transform_test \
39+
//extension/llm/runner/test:test_multimodal_input \
40+
//extension/llm/runner/test:test_generation_config \
3841
//kernels/portable/... \
3942
$BUILDABLE_KERNELS_PRIM_OPS_TARGETS //runtime/backend/... //runtime/core/... \
4043
//runtime/executor: //runtime/kernel/... //runtime/platform/...
4144
done
4245

4346
# Build only without testing
44-
buck2 build //codegen/tools/... # Needs torch for testing which we don't have in our OSS buck setup.
47+
buck2 build //codegen/tools/... \
48+
//extension/llm/runner/io_manager:io_manager \
49+
//extension/llm/modules/... \
50+
//extension/llm/runner:multimodal_runner_lib \
51+
//extension/llm/runner:text_decoder_runner

.lintrunner.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ exclude_patterns = [
206206
'**/*.png',
207207
'**/*.webp',
208208
'**/*.jpeg',
209+
'**/*.mp3',
209210
'**/*.mp4',
210211
'**/*.pte',
211212
'**/*.pth',
@@ -216,6 +217,8 @@ exclude_patterns = [
216217
'**/*.jpg',
217218
'**/*.jar',
218219
'**/*.gif',
220+
'extension/llm/tokenizers',
221+
'extension/llm/tokenizers/**',
219222
# File contains @generated
220223
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
221224
'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',

backends/arm/operator_support/to_dim_order_copy_support.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``_to_dim_order_copy`` in TOSA.
6+
7+
Provide dtype-compatibility checks for casting when converting to a specific
8+
dimension order. Supported input/output dtype pairs depend on the active TOSA
9+
profile (integer and/or float).
10+
11+
"""
512

613
# pyre-unsafe
714
import copy
@@ -25,6 +32,16 @@
2532

2633
@register_tosa_support_check
2734
class ToCopySupported(SupportedTOSAOperatorCheck):
35+
"""Provide TOSA support check for ``_to_dim_order_copy``.
36+
37+
Attributes:
38+
SUPPORTED_INT_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
39+
Allowed output dtypes for each integer input dtype.
40+
SUPPORTED_FP_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
41+
Allowed output dtypes for each floating input dtype.
42+
43+
"""
44+
2845
targets = [
2946
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
3047
]
@@ -40,21 +57,31 @@ def _merge_supported_types(
4057
dtypes1: SupportedTypeDict,
4158
dtypes2: SupportedTypeDict,
4259
) -> SupportedTypeDict:
60+
"""Return a merged mapping of supported dtype transitions.
61+
62+
Args:
63+
dtypes1 (dict[torch.dtype, list[torch.dtype]]): Base mapping.
64+
dtypes2 (dict[torch.dtype, list[torch.dtype]]): Mapping to merge in.
65+
66+
Returns:
67+
dict[torch.dtype, list[torch.dtype]]: Combined mapping.
68+
69+
"""
4370
merged_dtypes = copy.deepcopy(
4471
dtypes1
45-
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
72+
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_PROFILE_DTYPES
4673
for k, v in dtypes2.items():
4774
merged_dtypes[k] = merged_dtypes.get(k, []) + v
4875
return merged_dtypes
4976

50-
SUPPORTED_INT_TYPES: SupportedTypeDict = {
77+
SUPPORTED_INT_PROFILE_DTYPES: SupportedTypeDict = {
5178
torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
5279
torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
5380
torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32],
5481
torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
5582
torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
5683
}
57-
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
84+
SUPPORTED_FP_PROFILE_DTYPES: SupportedTypeDict = {
5885
torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
5986
torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32],
6087
torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32],
@@ -92,22 +119,25 @@ def _merge_supported_types(
92119
torch.float32,
93120
],
94121
}
95-
ALL_SUPPORTED_TYPES = _merge_supported_types(
96-
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
97-
)
98122

99123
def is_node_tosa_supported(
100124
self, node: fx.Node, tosa_spec: TosaSpecification
101125
) -> bool:
126+
"""Return True if the node is supported by TOSA.
127+
128+
Check FakeTensor metadata, validate input dtype is supported for the
129+
active profile, and ensure the output dtype is allowed for the given
130+
input dtype.
102131
132+
"""
103133
supported_dtypes: SupportedTypeDict = {}
104134
if tosa_spec.support_integer():
105135
supported_dtypes = self._merge_supported_types(
106-
self.SUPPORTED_INT_TYPES, supported_dtypes
136+
self.SUPPORTED_INT_PROFILE_DTYPES, supported_dtypes
107137
)
108138
if tosa_spec.support_float():
109139
supported_dtypes = self._merge_supported_types(
110-
self.SUPPORTED_FLOAT_TYPES, supported_dtypes
140+
self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
111141
)
112142

113143
if len(node.all_input_nodes) != 1:

0 commit comments

Comments
 (0)