Commit 4b42871

Update base for Update on "[ET-VK] Introduce AOT operator registry"
## Changes

Move the following files to the root directory of the Vulkan backend:

* `backends/vulkan/partitioner/supported_ops.py` -> `backends/vulkan/op_registry.py`
* `backends/vulkan/_passes/custom_ops_defs.py` -> `backends/vulkan/custom_ops_lib.py`

In the new `op_registry.py` file, the way operator features are specified is reworked to provide much more detail about each operator's Vulkan implementation. See the new `OpFeatures` class for more details. An example of registering a new operator with the export flow:

```
@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    features.texture_impl = TextureImplFeatures(
        uses_packed_dim=True,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        dim_list = node.args[1]
        assert isinstance(dim_list, list)
        if len(dim_list) != 1:
            return False

        keepdim = node.args[2]
        assert isinstance(keepdim, bool)
        if not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features
```

## Rationale

The purpose of these changes is to centralize operator definitions so that there is a common source of truth about the capabilities of operator implementations in Vulkan. This way, the partitioner does not have to implement ad-hoc functions for specific operators (e.g. `is_valid_to_copy`) and graph transforms do not have to maintain their own operator metadata (`USES_WEIGHTS` in `insert_prepack_nodes`).

Differential Revision: [D64915640](https://our.internmc.facebook.com/intern/diff/D64915640/)

[ghstack-poisoned]
2 parents fdb7392 + 16b633b commit 4b42871
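
As an illustration of the rationale above, the sketch below shows how a consumer such as the Vulkan partitioner could drive its decisions purely from the centralized registry instead of ad-hoc per-operator helpers. This is a hypothetical sketch: the `has_impl`/`get_op_features` lookup helpers, the import path, and the `is_node_supported` wrapper are assumed names for illustration; only `OpFeatures` and `check_node_fn` come from the commit message.

```python
# Hypothetical sketch only. `has_impl`, `get_op_features`, and `is_node_supported`
# are assumed names for illustration; they are not APIs confirmed by this commit.
import torch

from executorch.backends.vulkan.op_registry import (  # assumed import path
    get_op_features,
    has_impl,
)


def is_node_supported(node: torch.fx.Node) -> bool:
    """Decide partitioning for a node from registry metadata alone."""
    if node.op != "call_function":
        return False
    # The registry knows whether a Vulkan implementation exists at all.
    if not has_impl(node.target):
        return False

    features = get_op_features(node.target)
    # Op-specific validity rules live with the op registration
    # (e.g. `check_reduce_node` above), not in the partitioner.
    if features.check_node_fn is not None:
        return features.check_node_fn(node)
    return True
```

With this shape, adding Vulkan support for a new op means touching only `op_registry.py`.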

20 files changed: +327 lines, -38 lines

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-export-D64151426
+bd5482c7c3e1197e10c46ff739027f917d9c1fcc
```

build/packaging/smoke_test.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -15,6 +15,14 @@
 # will fail and the process will exit.
 from executorch.extension.pybindings import portable_lib  # usort: skip
 
+# Import custom ops. This requires portable_lib to be loaded first.
+from executorch.extension.llm.custom_ops import (  # noqa: F401, F403
+    sdpa_with_kv_cache,
+)  # usort: skip
+
+# Import quantized ops. This requires portable_lib to be loaded first.
+from executorch.kernels import quantized  # usort: skip # noqa: F401, F403
+
 # Import this after importing the ExecuTorch pybindings. If the pybindings
 # links against a different torch.so than this uses, there will be a set of
 # symbol comflicts; the process will either exit now, or there will be issues
@@ -75,6 +83,15 @@ def main():
     assert len(ops) > 0, "Empty operator list"
     print(f"Found {len(ops)} operators; first element '{ops[0]}'")
 
+    # Make sure custom ops are registered.
+    assert (
+        "llama::sdpa_with_kv_cache" in ops
+    ), f"sdpa_with_kv_cache not registered, Got ops: {ops}"
+
+    # Make sure quantized ops are registered.
+    assert (
+        "quantized_decomposed::add.out" in ops
+    ), f"quantized_decomposed::add.out not registered, Got ops: {ops}"
     # Export LinearModel to .pte data.
     pte_data: bytes = export_linear_model()
```

examples/models/llama/llama_transformer.py

Lines changed: 10 additions & 9 deletions

```diff
@@ -265,21 +265,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
```

examples/models/llama/source_transformation/lora.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -70,6 +70,8 @@ def __init__(
             precision=precision,
             scales_precision=scales_precision,
         )
+        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
+        self.zeros = torch.zeros_like(self.zeros)
         self.adaptor = LoRAAdaptorLinear(
             in_features,
             out_features,
```

examples/models/llama/source_transformation/pre_quantization.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -46,6 +46,8 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             precision=precision,
             scales_precision=scales_precision,
         )
+        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
+        new_linear.zeros = torch.zeros_like(new_linear.zeros)
         return new_linear
 
     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
```

examples/models/llama/source_transformation/quantize.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -375,7 +375,7 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         self.register_buffer(
-            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+            "weight", torch.zeros((out_features, in_features), dtype=torch.int8)
         )
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
@@ -448,18 +448,18 @@ def __init__(
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
+            torch.zeros((out_features, in_features), dtype=torch.int8),
         )
         self.register_buffer(
             "scales",
-            torch.empty(
+            torch.zeros(
                 (out_features),
                 dtype=torch.float32,
             ),
         )
         self.register_buffer(
             "zeros",
-            torch.empty(
+            torch.zeros(
                 (out_features),
                 dtype=torch.float32,
             ),
@@ -632,15 +632,15 @@ def __init__(
         if not packed:
             self.register_buffer(
                 "weight",
-                torch.empty(
+                torch.zeros(
                     (vocab_size, embedding_dim), dtype=torch.int8, device=device
                 ),
             )
         else:  # packed
             if bitwidth == 2:
                 self.register_buffer(
                     "weight",
-                    torch.empty(
+                    torch.zeros(
                         (vocab_size, embedding_dim // 4),
                         dtype=torch.uint8,
                         device=device,
@@ -649,7 +649,7 @@ def __init__(
             elif bitwidth == 4:
                 self.register_buffer(
                     "weight",
-                    torch.empty(
+                    torch.zeros(
                         (vocab_size, embedding_dim // 2),
                         dtype=torch.uint8,
                         device=device,
```

examples/qualcomm/oss_scripts/llama2/llama.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -564,6 +564,7 @@ def post_process():
         exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
 
     if args.compile_only:
+        compile(args)
         exit(f"Finish compile_only and save to {args.artifact}")
 
     try:
```

exir/program/_program.py

Lines changed: 2 additions & 5 deletions

```diff
@@ -453,7 +453,6 @@ def to_executorch(
     def __deepcopy__(
         self, memo: Optional[Dict[int, Any]] = None
     ) -> "ExirExportedProgram":
-
         new_eep = ExirExportedProgram(
             copy.deepcopy(self.exported_program, memo),
             self.after_to_edge_passes,
@@ -764,7 +763,6 @@ def _replace_aten_ops_with_transformed_ops(
     program: ExportedProgram,
     partitioner,
 ):
-
     ops_to_not_decompose = set()
     partitioners = partitioner.get(name)
     if partitioners is None:
@@ -1020,9 +1018,9 @@ def to_edge_transform_and_lower(
     aten_programs = programs
 
     if not isinstance(partitioner, dict) and partitioner is not None:
-        partitioner = {"forward": partitioner}
+        partitioner = {name: partitioner for name in aten_programs.keys()}
     elif partitioner is None:
-        partitioner = {"forward": []}
+        partitioner = {name: [] for name in aten_programs.keys()}
 
     edge_manager = _gen_edge_manager_for_partitioners(
         partitioner, aten_programs, config, constant_methods
@@ -1037,7 +1035,6 @@ def to_edge_transform_and_lower(
         edge_manager = edge_manager.to_backend({name: curr_partitioner})
 
     for name, program in edge_manager._edge_programs.items():
-
         ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
         partitioners = partitioner.get(name, [])
         for curr_partitioner in partitioners:
```
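
A side note on the `to_edge_transform_and_lower` change above: a partitioner passed as a plain list is now fanned out to every method in the exported-programs dict instead of being keyed only under `"forward"`. Below is a rough usage sketch under assumptions: `TinyModel`, the `"encode"`/`"decode"` method names, and the use of the XNNPACK partitioner are illustrative and not part of this commit.

```python
# Illustrative sketch; the model and method names are made up for this example.
import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)


programs = {
    "encode": export(TinyModel(), (torch.randn(1, 4),)),
    "decode": export(TinyModel(), (torch.randn(1, 4),)),
}

# Previously the list below was stored as {"forward": [...]} and would not match
# "encode"/"decode"; with this change it is expanded to one entry per method name.
edge_manager = to_edge_transform_and_lower(
    programs,
    partitioner=[XnnpackPartitioner()],
)
```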

exir/schema.py

Lines changed: 13 additions & 2 deletions

```diff
@@ -35,14 +35,24 @@ class OptionalTensorList:
 
 class TensorShapeDynamism(IntEnum):
     """
-    Check schema.fbs for explanations of this enum.
+    Check program.fbs for explanations of this enum.
    """
 
     STATIC = 0
     DYNAMIC_BOUND = 1
     DYNAMIC_UNBOUND = 2
 
 
+@dataclass
+class ExtraTensorInfo:
+    """
+    Check program.fbs for explanations of this enum.
+    """
+
+    mutable_data_segments_idx: Optional[int] = None
+    fully_qualified_name: Optional[str] = None
+
+
 @dataclass
 class Tensor:
     scalar_type: ScalarType
@@ -54,8 +64,9 @@ class Tensor:
     data_buffer_idx: int
     allocation_info: Optional[AllocationDetails]
 
-    # check schema.fbs for explanations
+    # check program.fbs for explanations.
     shape_dynamism: TensorShapeDynamism
+    extra_tensor_info: Optional[ExtraTensorInfo] = None
 
 
 @dataclass
```

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -69,7 +69,7 @@ class MemoryFormatOpsPassTestUtils:
     def memory_format_test_runner(
         test_class: unittest.TestCase, test_set: MemoryFormatTestSet
     ):
-        before = export(test_set.module, test_set.sample_input)
+        before = export(test_set.module, test_set.sample_input).run_decompositions({})
 
         if test_set.use_xnnpack:
             epm = to_edge_transform_and_lower(
```
