Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
# Get sample code to customize main.cpp

# Get impulses ID from model_variables.h
with open(os.path.join(target_dir, 'model-parameters/model_variables.h'), 'r') as file:
with open(os.path.join(target_dir, 'model-parameters/model_variables.h'), 'r', encoding='utf-8') as file:
file_content = file.read()
impulses_id_set = set(re.findall(r"impulse_(\d+)_(\d+)", file_content))
impulses_id = {}
Expand Down Expand Up @@ -216,7 +216,7 @@
{newline}"""

# Insert custom code in main.cpp
with open(os.path.join(target_dir, 'source/main.cpp'), 'r') as file1:
with open(os.path.join(target_dir, 'source/main.cpp'), 'r', encoding='utf-8') as file1:
main_template = file1.readlines()

idx = main_template.index("// get_signal declaration inserted here\n") +1
Expand All @@ -229,7 +229,7 @@
main_template[idx:idx] = callback_function_code

logger.info("Editing main.cpp")
with open(os.path.join(target_dir, 'source/main.cpp'), 'w') as file1:
with open(os.path.join(target_dir, 'source/main.cpp'), 'w', encoding='utf-8') as file1:
file1.writelines(main_template)
logger.info("main.cpp edited")

Expand Down
44 changes: 22 additions & 22 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
def insert_define_statement(file_path, define_statement):
logger.info(f"Inserting {define_statement} into {file_path}")
try:
with open(file_path, 'r') as file:
with open(file_path, 'r', encoding='utf-8') as file:
file_content = file.readlines()

include_idx = None
Expand All @@ -54,7 +54,7 @@ def insert_define_statement(file_path, define_statement):
if include_idx is not None and define_idx is not None:
file_content.insert(include_idx + 1, define_statement + '\n')

with open(file_path, 'w') as file:
with open(file_path, 'w', encoding='utf-8') as file:
file.writelines(file_content)

logger.info(f"Inserted {define_statement} into {file_path}")
Expand All @@ -68,7 +68,7 @@ def insert_define_statement(file_path, define_statement):
def insert_after_line(file_path, search_line, lines_to_insert):
logger.info(f"Inserting lines into {file_path} after {search_line}")
try:
with open(file_path, 'r') as file:
with open(file_path, 'r', encoding='utf-8') as file:
file_content = file.readlines()

insert_idx = None
Expand All @@ -83,7 +83,7 @@ def insert_after_line(file_path, search_line, lines_to_insert):
for line in reversed(lines_to_insert):
file_content.insert(insert_idx, line + '\n')

with open(file_path, 'w') as file:
with open(file_path, 'w', encoding='utf-8') as file:
file.writelines(file_content)

logger.info(f"Lines inserted into {file_path}")
Expand All @@ -97,12 +97,12 @@ def insert_after_line(file_path, search_line, lines_to_insert):
def replace_line(file_path, search_line, replacement_line):
logger.info(f"Replacing line in {file_path}: {search_line}")
try:
with open(file_path, 'r') as file:
with open(file_path, 'r', encoding='utf-8') as file:
file_content = file.readlines()

file_content = [line if search_line not in line else replacement_line + '\n' for line in file_content]

with open(file_path, 'w') as file:
with open(file_path, 'w', encoding='utf-8') as file:
file.writelines(file_content)

logger.info(f"Replaced line in {file_path}")
Expand All @@ -116,12 +116,12 @@ def replace_line(file_path, search_line, replacement_line):
def remove_line(file_path, search_string):
logger.info(f"Removing line from {file_path} containing {search_string}")
try:
with open(file_path, 'r') as file:
with open(file_path, 'r', encoding='utf-8') as file:
file_content = file.readlines()

file_content = [line for line in file_content if search_string not in line]

with open(file_path, 'w') as file:
with open(file_path, 'w', encoding='utf-8') as file:
file.writelines(file_content)

logger.info(f"Removed line containing {search_string} from {file_path}")
Expand All @@ -137,7 +137,7 @@ def remove_line(file_path, search_string):
def edit_file(file_path, patterns, suffix):
logger.info("Editing " + file_path)
try:
with open(file_path, 'r') as file:
with open(file_path, 'r', encoding='utf-8') as file:
file_content = file.read()

# function to add suffix to search patterns
Expand All @@ -156,7 +156,7 @@ def add_suffix(term):
logger.debug("pattern: " + pattern)
file_content = re.sub(pattern, add_suffix, file_content)

with open(file_path, 'w') as file:
with open(file_path, 'w', encoding='utf-8') as file:
file.write(file_content)

logger.debug(f"{file_path} edited")
Expand Down Expand Up @@ -290,9 +290,9 @@ def find_value(file_content, macro_string):
def merge_model_metadata(src_file, dest_file):
try:
# Open the first file for reading
with open(src_file, 'r') as file1:
with open(src_file, 'r', encoding='utf-8') as file1:
src_file_contents = file1.readlines()
with open(dest_file, 'r') as file2:
with open(dest_file, 'r', encoding='utf-8') as file2:
dest_file_contents = file2.readlines()

compare_version(src_file_contents, dest_file_contents)
Expand All @@ -315,7 +315,7 @@ def merge_model_metadata(src_file, dest_file):
dest_file_contents = find_common_type(src_file_contents, dest_file_contents, "EI_CLASSIFIER_OBJECT_DETECTION_LAST_LAYER", object_detection_types)
dest_file_contents = find_common_type(src_file_contents, dest_file_contents, "EI_CLASSIFIER_HAS_ANOMALY", anomaly_types)

with open(dest_file, 'w') as file2:
with open(dest_file, 'w', encoding='utf-8') as file2:
file2.writelines("".join(dest_file_contents))

except FileNotFoundError as e:
Expand All @@ -329,7 +329,7 @@ def merge_model_variables(src_file, dest_file):
include_line_str = '#include "tflite-model/tflite_learn'
try:
# Open the first file for reading
with open(src_file, 'r') as file1:
with open(src_file, 'r', encoding='utf-8') as file1:
file1_contents = file1.readlines()

start_line = None
Expand All @@ -348,7 +348,7 @@ def merge_model_variables(src_file, dest_file):
raise ValueError("Start or end string not found model_variables.h")

# Open the second file for reading
with open(dest_file, 'r') as file2:
with open(dest_file, 'r', encoding='utf-8') as file2:
file2_contents = file2.readlines()

# Find the line number for the insertion string in the second file
Expand All @@ -374,7 +374,7 @@ def merge_model_variables(src_file, dest_file):
file2_contents[insert_line:insert_line] = portion_to_copy

# Open the second file for writing and overwrite its contents
with open(dest_file, 'w') as file2:
with open(dest_file, 'w', encoding='utf-8') as file2:
file2.writelines(file2_contents)

logger.info("Portion copied and inserted successfully!")
Expand All @@ -385,17 +385,17 @@ def merge_model_variables(src_file, dest_file):
# Function to keep intersection of model_ops_define.h
def merge_model_ops(src_file, dest_file):
try:
with open(src_file, 'r') as file1:
with open(src_file, 'r', encoding='utf-8') as file1:
lines_file1 = [line.strip() for line in file1.readlines()]

with open(dest_file, 'r') as file2:
with open(dest_file, 'r', encoding='utf-8') as file2:
lines_file2 = [line.strip() for line in file2.readlines()]

# Find the intersection of lines
intersection = [line for line in lines_file1 if line in lines_file2]

# Write the intersection back to src_file
with open(dest_file, 'w') as file:
with open(dest_file, 'w', encoding='utf-8') as file:
for line in intersection:
file.write(line + '\n')

Expand All @@ -406,10 +406,10 @@ def merge_model_ops(src_file, dest_file):

def merge_tflite_resolver(src_file, dest_file):
try:
with open(src_file, 'r') as file1:
with open(src_file, 'r', encoding='utf-8') as file1:
lines_file1 = [line.strip() for line in file1.readlines()]

with open(dest_file, 'r') as file2:
with open(dest_file, 'r', encoding='utf-8') as file2:
lines_file2 = [line.strip() for line in file2.readlines()]

# Find the union of lines
Expand All @@ -421,7 +421,7 @@ def merge_tflite_resolver(src_file, dest_file):
union.append(line2)

# Write the union back to src_file
with open(dest_file, 'w') as file:
with open(dest_file, 'w', encoding='utf-8') as file:
for line in union:
if line.startswith('resolver.') and not line.endswith('\\'):
line = line + ' \\'
Expand Down