27 changes: 15 additions & 12 deletions backends/arm/process_node.py
@@ -15,6 +15,14 @@
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
+from torch._export.utils import (
+    get_buffer,
+    get_lifted_tensor_constant,
+    get_param,
+    is_buffer,
+    is_lifted_tensor_constant,
+    is_param,
+)
 from torch.export.exported_program import ExportedProgram


@@ -99,8 +107,7 @@ def process_inputs_to_parameters(
             f"Failed processing parameter placeholder: {node.name}. "
             "Is the original torch function supported?"
         ) from e
-    parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
-    parameter_data = edge_program.state_dict[parameter_name]
+    parameter_data = get_param(edge_program, node)
 
     assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
     parameter_values = parameter_data.detach().numpy()
@@ -128,8 +135,7 @@ def process_inputs_to_buffers(
             f"Failed processing buffer placeholder: {node.name}. "
             "Is the original torch function supported?"
         ) from e
-    buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
-    buffer_data = edge_program.state_dict[buffer_name]
+    buffer_data = get_buffer(edge_program, node)
 
     assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"
     buffer_values = buffer_data.detach().numpy()
@@ -156,11 +162,8 @@ def process_inputs_to_lifted_tensor_constants(
             f"Failed processing lifted tensor constant placeholder: {node.name}. "
             "Is the original torch function supported?"
         ) from e
-    tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
-        tosa_arg.name
-    ]
-    tensor = edge_program.tensor_constants[tensor_name]
-    tensor_data = tensor.detach().numpy()
+    tensor = get_lifted_tensor_constant(edge_program, node)
+    tensor_data = tensor.detach().numpy()  # type: ignore[union-attr]
 
     tosa_graph.addConst(
         tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
@@ -179,11 +182,11 @@ def process_placeholder(
 
     if node.name in edge_program.graph_signature.user_inputs:
         process_inputs(node, tosa_graph, tosa_spec)
-    elif node.name in edge_program.graph_signature.inputs_to_parameters:
+    elif is_param(edge_program, node):
         process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
-    elif node.name in edge_program.graph_signature.inputs_to_buffers:
+    elif is_buffer(edge_program, node):
         process_inputs_to_buffers(node, tosa_graph, edge_program)
-    elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
+    elif is_lifted_tensor_constant(edge_program, node):
         process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
     elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
         raise NotImplementedError(
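The helpers imported above come from torch._export.utils; their advantage over the old graph_signature + state_dict lookups is that they also resolve buffers registered with persistent=False, which torch.export lifts into ExportedProgram.constants rather than state_dict. A minimal sketch of that behavior, assuming a recent PyTorch where non-persistent buffers land in constants (the Demo module and tensor names are illustrative, not from this PR):

import torch
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
from torch.export import export


class Demo(torch.nn.Module):  # illustrative module, not from this PR
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(2))
        self.register_buffer("scale", torch.tensor(2.0), persistent=False)

    def forward(self, x):
        return x * self.weight * self.scale


ep = export(Demo(), (torch.ones(2),))
for node in ep.graph.nodes:
    if node.op != "placeholder":
        continue
    if is_param(ep, node):
        # Same tensor the old inputs_to_parameters/state_dict lookup found.
        print(node.name, "param", get_param(ep, node).shape)
    elif is_buffer(ep, node):
        # Non-persistent buffer: absent from state_dict, resolved via ep.constants.
        print(node.name, "buffer", get_buffer(ep, node))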
49 changes: 49 additions & 0 deletions backends/arm/test/misc/test_non_persistent_buffers.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from executorch.backends.arm.test.common import parametrize
+from executorch.backends.arm.test.tester.test_pipeline import (
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+
+class NonPersistentBuffer(nn.Module):
+    """
+    Min code version registering a non-persistent input buffer.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.register_buffer("test_buff", torch.rand(2, 2, 2, 2), persistent=False)
+
+    def forward(self, x):
+        return x - self.test_buff
+
+
+test_input = {"input": (torch.ones(2, 2, 2, 2),)}
+
+input_t = tuple[torch.Tensor]
+
+
+@parametrize("test_data", test_input)
+def test_non_persistent_buffer_MI(test_data: input_t):
+    """
+    Test validates Arm backend handling of non-persistent buffers
+    and ensures that there are no asserts or errors when they are used.
+    """
+    TosaPipelineMI[input_t](NonPersistentBuffer(), test_data, "").run()
+
+
+@parametrize("test_data", test_input)
+def test_non_persistent_buffer_BI(test_data: input_t):
+    """
+    Test validates Arm backend handling of non-persistent buffers
+    and ensures that there are no asserts or errors when they are used.
+    """
+    TosaPipelineBI[input_t](NonPersistentBuffer(), test_data, "").run()
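For reference, a hypothetical stand-alone repro (not part of the PR, and independent of the Arm test pipeline) of the failure mode this test guards against: with the previous direct state_dict lookup, resolving a non-persistent buffer raised a KeyError, because torch.export records its value in ep.constants instead. This assumes a recent PyTorch; the module name M is illustrative:

import torch
from torch.export import export


class M(torch.nn.Module):  # mirrors NonPersistentBuffer above
    def __init__(self):
        super().__init__()
        self.register_buffer("test_buff", torch.rand(2, 2, 2, 2), persistent=False)

    def forward(self, x):
        return x - self.test_buff


ep = export(M(), (torch.ones(2, 2, 2, 2),))
((input_name, buf_name),) = ep.graph_signature.inputs_to_buffers.items()
assert buf_name not in ep.state_dict  # the old lookup path would KeyError here
assert buf_name in ep.constants       # where the non-persistent value actually lives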
18 changes: 0 additions & 18 deletions backends/arm/test/models/test_llama.py
@@ -79,24 +79,6 @@ def prepare_model(self):
 
         llama_model, llama_inputs, llama_meta = get_llama_model(args)
 
-        # TODO: Remove workaround since attention mask should not be persistent,
-        # it only works if input shape is always the same
-        freqs_c = "freqs_cos"
-        freqs_s = "freqs_sin"
-        for i in range(llama_model.n_layers):
-            val = llama_model.layers[i].attention.get_buffer("mask")
-            llama_model.layers[i].attention.register_buffer(
-                "mask", val, persistent=True
-            )
-            val = llama_model.layers[i].attention.rope.get_buffer(freqs_c)
-            llama_model.layers[i].attention.rope.register_buffer(
-                freqs_c, val, persistent=True
-            )
-            val = llama_model.layers[i].attention.rope.get_buffer(freqs_s)
-            llama_model.layers[i].attention.rope.register_buffer(
-                freqs_s, val, persistent=True
-            )
-
         return llama_model, llama_inputs, llama_meta
 
     def test_llama_tosa_MI(self):