Skip to content

Commit 5794fff

Browse files
committed
Linting
1 parent 17212ed commit 5794fff

File tree

2 files changed

+109
-64
lines changed

2 files changed

+109
-64
lines changed

extension/embedded/export_add.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import torch
2-
from torch.export import export
32
from executorch.exir import to_edge
3+
from torch.export import export
4+
45

56
# Start with a PyTorch model that adds two input tensors (matrices)
67
class Add(torch.nn.Module):
7-
def __init__(self):
8-
super(Add, self).__init__()
8+
def __init__(self):
9+
super(Add, self).__init__()
10+
11+
def forward(self, x: torch.Tensor, y: torch.Tensor):
12+
return x + y
913

10-
def forward(self, x: torch.Tensor, y: torch.Tensor):
11-
return x + y
1214

1315
# 1. torch.export: Defines the program with the ATen operator set.
1416
aten_dialect = export(Add(), (torch.ones(1), torch.ones(1)))
@@ -22,4 +24,3 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
2224
# 4. Save the compiled .pte program
2325
with open("add.pte", "wb") as file:
2426
file.write(executorch_program.buffer)
25-

extension/embedded/pte_to_header.py

Lines changed: 102 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,22 @@
33
Generates .pte model, operator definitions, and header files
44
"""
55

6+
import argparse
67
import os
7-
import sys
88
import subprocess
9-
import argparse
9+
import sys
1010
from pathlib import Path
1111

12+
1213
def run_command(cmd, cwd=None, description=""):
1314
"""Run a command and handle errors"""
1415
print(f"Running: {description}")
1516
print(f"Command: {' '.join(cmd)}")
16-
17+
1718
try:
18-
result = subprocess.run(cmd, cwd=cwd, check=True, capture_output=True, text=True)
19+
result = subprocess.run(
20+
cmd, cwd=cwd, check=True, capture_output=True, text=True
21+
)
1922
print(f"✓ {description} completed successfully")
2023
if result.stdout:
2124
print(f"Output: {result.stdout}")
@@ -25,120 +28,160 @@ def run_command(cmd, cwd=None, description=""):
2528
print(f"Error: {e.stderr}")
2629
sys.exit(1)
2730

31+
2832
def main():
29-
parser = argparse.ArgumentParser(description="Build ExecuTorch ARM Hello World model")
30-
parser.add_argument("--executorch-root", default="~/optional/modules/lib/executorch",
31-
help="Path to ExecuTorch root directory")
32-
parser.add_argument("--model-name", default="add",
33-
help="Name of the model (default: add)")
34-
parser.add_argument("--clean", action="store_true",
35-
help="Clean generated files before building")
36-
33+
parser = argparse.ArgumentParser(
34+
description="Build ExecuTorch ARM Hello World model"
35+
)
36+
parser.add_argument(
37+
"--project-root",
38+
default="~/",
39+
help="Path to project root (should be zephry/../)",
40+
required=True,
41+
)
42+
parser.add_argument(
43+
"--model-name", default="add", help="Name of the model (default: add)"
44+
)
45+
parser.add_argument(
46+
"--clean", action="store_true", help="Clean generated files before building"
47+
)
48+
3749
args = parser.parse_args()
38-
50+
3951
# Paths
4052
script_dir = Path(__file__).parent
41-
project_root = script_dir.parent.parent.parent.parent.parent.parent # Go up to petriok root
42-
executorch_root = project_root / args.executorch_root
43-
example_files_dir = "/home/zephyruser/zephyr/samples/modules/executorch/arm/hello_world/example_files"
53+
project_root = Path(args.project_root)
54+
executorch_root = project_root / "modules" / "lib" / "executorch"
55+
4456
src_dir = script_dir / "src"
45-
57+
4658
model_name = args.model_name
4759
pte_file = f"{model_name}.pte"
4860
ops_def_file = "gen_ops_def.yml"
4961
header_file = "model_pte.h"
50-
62+
5163
print(f"Building ExecuTorch model: {model_name}")
5264
print(f"ExecuTorch root: {executorch_root}")
5365
print(f"Working directory: {script_dir}")
54-
66+
5567
# Clean previous build if requested
5668
if args.clean:
5769
files_to_clean = [pte_file, ops_def_file, src_dir / header_file]
5870
for file_path in files_to_clean:
5971
if Path(file_path).exists():
6072
Path(file_path).unlink()
6173
print(f"Cleaned: {file_path}")
62-
74+
6375
# Step 1: Generate the .pte model file
64-
export_script = os.path.join(example_files_dir, f"export_{model_name}.py")
65-
if not os.path.exists(export_script):
76+
export_script = (
77+
executorch_root / "extension" / "embedded" / "export_{model_name}.py"
78+
)
79+
if not export_script.exists():
6680
print(f"Error: Export script not found: {export_script}")
6781
sys.exit(1)
68-
82+
6983
try:
7084
run_command(
7185
[sys.executable, str(export_script)],
7286
cwd=script_dir,
73-
description="Generating .pte model file"
87+
description="Generating .pte model file",
7488
)
7589
except SystemExit:
76-
print(f"\n❌ Model generation failed. This is likely because PyTorch/ExecuTorch is not installed.")
77-
print(f"For now, using dummy model_pte.h for compilation testing.")
78-
print(f"To generate a real model, install PyTorch and ExecuTorch:")
79-
print(f" pip install torch")
80-
print(f" # Install ExecuTorch according to documentation")
81-
print(f" python build_model.py")
90+
print(
91+
"\n❌ Model generation failed. This is likely because PyTorch/ExecuTorch is not installed."
92+
)
93+
print("For now, using dummy model_pte.h for compilation testing.")
94+
print("To generate a real model, install PyTorch and ExecuTorch:")
95+
print(" pip install torch")
96+
print(" # Install ExecuTorch according to documentation")
97+
print(" python build_model.py")
8298
return
83-
99+
84100
if not Path(script_dir / pte_file).exists():
85101
print(f"Error: Model file {pte_file} was not generated")
86102
sys.exit(1)
87-
103+
88104
# Step 2: Generate operator definitions
89105

90-
gen_ops_script = "/home/zephyruser/optional/modules/lib/executorch/codegen/tools/gen_ops_def.py"
106+
gen_ops_script = (
107+
"/home/zephyruser/optional/modules/lib/executorch/codegen/tools/gen_ops_def.py"
108+
)
91109
if not os.path.exists(gen_ops_script):
92110
print(f"Error: gen_ops_def.py not found at {gen_ops_script}")
93111
sys.exit(1)
94-
112+
95113
run_command(
96-
[sys.executable, str(gen_ops_script),
97-
"--output_path", ops_def_file,
98-
"--model_file_path", pte_file],
114+
[
115+
sys.executable,
116+
str(gen_ops_script),
117+
"--output_path",
118+
ops_def_file,
119+
"--model_file_path",
120+
pte_file,
121+
],
99122
cwd=script_dir,
100-
description="Generating operator definitions"
123+
description="Generating operator definitions",
101124
)
102-
125+
103126
# Step 3: Convert .pte to header file
104-
#pte_to_header_script = executorch_root / "examples" / "arm" / "executor_runner" / "pte_to_header.py"
127+
# pte_to_header_script = executorch_root / "examples" / "arm" / "executor_runner" / "pte_to_header.py"
105128
pte_to_header_script = "/home/zephyruser/optional/modules/lib/executorch/examples/arm/executor_runner/pte_to_header.py"
106129
if not os.path.exists(pte_to_header_script):
107130
print(f"Error: pte_to_header.py not found at {pte_to_header_script}")
108131
sys.exit(1)
109-
132+
110133
run_command(
111-
[sys.executable, str(pte_to_header_script),
112-
"--pte", pte_file,
113-
"--outdir", "src"],
134+
[
135+
sys.executable,
136+
str(pte_to_header_script),
137+
"--pte",
138+
pte_file,
139+
"--outdir",
140+
"src",
141+
],
114142
cwd=script_dir,
115-
description="Converting .pte to header file"
143+
description="Converting .pte to header file",
116144
)
117-
145+
118146
# Step 4: Make the generated array const and remove section attribute
119147
header_path = src_dir / header_file
120148
if header_path.exists():
121149
content = header_path.read_text()
122-
150+
123151
# Remove section attribute and replace with Zephyr alignment macro
124152
import re
153+
125154
# Replace section+aligned pattern with Zephyr __ALIGN macro
126-
content = re.sub(r'__attribute__\s*\(\s*\(\s*section\s*\([^)]*\)\s*,\s*aligned\s*\(([^)]*)\)\s*\)\s*\)\s*', r'__ALIGN(\1) ', content)
127-
# Remove any remaining section-only attributes
128-
content = re.sub(r'__attribute__\s*\(\s*\(\s*section\s*\([^)]*\)\s*\)\s*\)\s*', '', content)
155+
content = re.sub(
156+
r"__attribute__\s*\(\s*\(\s*section\s*\([^)]*\)\s*,\s*aligned\s*\(([^)]*)\)\s*\)\s*\)\s*",
157+
r"__ALIGN(\1) ",
158+
content,
159+
)
160+
# Remove any remaining section-only attributes
161+
content = re.sub(
162+
r"__attribute__\s*\(\s*\(\s*section\s*\([^)]*\)\s*\)\s*\)\s*", "", content
163+
)
129164
# Also replace any standalone __attribute__((aligned(n))) with __ALIGN(n)
130-
content = re.sub(r'__attribute__\s*\(\s*\(\s*aligned\s*\(([^)]*)\)\s*\)\s*\)\s*', r'__ALIGN(\1) ', content)
131-
165+
content = re.sub(
166+
r"__attribute__\s*\(\s*\(\s*aligned\s*\(([^)]*)\)\s*\)\s*\)\s*",
167+
r"__ALIGN(\1) ",
168+
content,
169+
)
170+
132171
# Replace 'char model_pte_data[]' with 'const char model_pte_data[]'
133-
content = content.replace('char model_pte_data[]', 'const char model_pte_data[]')
172+
content = content.replace(
173+
"char model_pte_data[]", "const char model_pte_data[]"
174+
)
134175
# Also handle 'char model_pte[]' variant
135-
content = content.replace('char model_pte[]', 'const char model_pte[]')
136-
176+
content = content.replace("char model_pte[]", "const char model_pte[]")
177+
137178
header_path.write_text(content)
138-
print(f"✓ Made model data const and removed section attributes in {header_file}")
179+
print(
180+
f"✓ Made model data const and removed section attributes in {header_file}"
181+
)
139182
else:
140183
print(f"Warning: Header file {header_file} not found")
141-
184+
142185
print("\n=== Build Summary ===")
143186
print(f"✓ Generated: {pte_file}")
144187
print(f"✓ Generated: {ops_def_file}")
@@ -147,5 +190,6 @@ def main():
147190
print("1. Review gen_ops_def.yml and customize if needed")
148191
print("2. Build the Zephyr application with west build")
149192

193+
150194
if __name__ == "__main__":
151-
main()
195+
main()

0 commit comments

Comments
 (0)