diff --git a/llama_patch/llama_patch.py b/llama_patch/llama_patch.py index 77609a6..ad2a9ae 100644 --- a/llama_patch/llama_patch.py +++ b/llama_patch/llama_patch.py @@ -19,23 +19,25 @@ class LlamaPatchException(Exception): def apply_patch(llmpatch, out, cwd, verbose): if verbose: logger.setLevel(logging.DEBUG) - + os.chdir(cwd) - + patch_content = llmpatch.read() patches = parse_patches(patch_content) try: + diff_output = [] for patch in patches: file_path, element_type, element_name, changes = patch logger.debug(f"Processing patch for {element_type} {element_name} in {file_path}") - apply_patch_to_file(file_path, element_type, element_name, changes) - + diff_output.append(generate_diff_for_patch(file_path, element_type, element_name, changes)) + if out.name != '': - generate_patch_file(out.name) + with open(out.name, 'w') as file: + file.write('\n'.join(diff_output)) else: - sys.stdout.write(generate_patch_diff()) - + sys.stdout.write('\n'.join(diff_output)) + except LlamaPatchException as e: logger.error(f"Error applying patch: {e}") sys.exit(1) @@ -46,49 +48,45 @@ def apply_patch(llmpatch, out, cwd, verbose): def parse_patches(patch_content): patch_pattern = re.compile(r'^--- (.+?)\n\?\? (\w+)\s*(\w+)?\n((?:[+\-].*?\n)+)', re.DOTALL | re.MULTILINE) matches = patch_pattern.findall(patch_content) - + if not matches: raise LlamaPatchException("No valid patches found in the provided patch content.") - + patches = [] for match in matches: path, element_type, element_name, changes = match changes = changes.strip().split('\n') patches.append((path, element_type, element_name, changes)) - + return patches -def apply_patch_to_file(file_path, element_type, element_name, changes): +def generate_diff_for_patch(file_path, element_type, element_name, changes): if not os.path.isfile(file_path): raise LlamaPatchException(f"File not found: {file_path}") with open(file_path, 'r') as file: original_code = file.read() - try: - if element_type in ["function", "def"]: - updated_code = apply_function_patch(original_code, element_name, changes) - elif element_type == "class": - updated_code = apply_class_patch(original_code, element_name, changes) - elif element_type == "struct": - updated_code = apply_struct_patch(original_code, element_name, changes) - elif element_type == "<<": - updated_code = prepend_code(original_code, changes) - elif element_type == ">>": - updated_code = append_code(original_code, changes) - elif element_type == "call": - updated_code = apply_call_patch(original_code, element_name, changes) - else: - raise LlamaPatchException(f"Unsupported element type: {element_type}") - - with open(file_path, 'w') as file: - file.write(updated_code) - except LlamaPatchException as e: - logger.error(f"Error applying patch for {element_type} {element_name} in {file_path}: {str(e)}") - logger.debug(f"Context code:\n{original_code}\nChanges:\n{changes}") - raise + if element_type in ["function", "def"]: + updated_code = apply_function_patch(original_code, element_name, changes) + elif element_type == "class": + updated_code = apply_class_patch(original_code, element_name, changes) + elif element_type == "struct": + updated_code = apply_struct_patch(original_code, element_name, changes) + elif element_type == "<<": + updated_code = prepend_code(original_code, changes) + elif element_type == ">>": + updated_code = append_code(original_code, changes) + elif element_type == "call": + updated_code = apply_call_patch(original_code, element_name, changes) + else: + raise LlamaPatchException(f"Unsupported element type: {element_type}") + original_lines = original_code.splitlines(keepends=True) + updated_lines = updated_code.splitlines(keepends=True) + diff = difflib.unified_diff(original_lines, updated_lines, fromfile=file_path, tofile=file_path) + return ''.join(diff) def apply_function_patch(original_code, function_name, changes): function_pattern = re.compile(rf'def {function_name}\(.*?\):\n((?:\s+.*?\n)*)', re.DOTALL) @@ -159,7 +157,7 @@ def append_code(original_code, changes): return updated_code def generate_new_code(context_code, changes, call=False): - context_lines = context_code.strip().split('\n') + context_lines = context_code.split('\n') new_code = [] try: @@ -179,25 +177,6 @@ def generate_new_code(context_code, changes, call=False): else: return "\n".join(new_code) - -def generate_patch_file(output_file): - diff = generate_patch_diff() - with open(output_file, 'w') as file: - file.write(diff) - -def generate_patch_diff(): - file_list = [f for f in os.listdir('.') if os.path.isfile(f)] - diff = [] - - for file_name in file_list: - with open(file_name, 'r') as file: - new_code = file.readlines() - with open(f'{file_name}.orig', 'r') as file: - old_code = file.readlines() - file_diff = difflib.unified_diff(old_code, new_code, fromfile=f'{file_name}.orig', tofile=file_name) - diff.extend(file_diff) - - return ''.join(diff) - if __name__ == "__main__": apply_patch() + diff --git a/tests/python/useless-exercise.llmpatch b/tests/python/useless-exercise.llmpatch new file mode 100644 index 0000000..68f2f50 --- /dev/null +++ b/tests/python/useless-exercise.llmpatch @@ -0,0 +1,32 @@ +--- useless-exercise.py +?? def add_numbers +- def add_numbers(a, b): +- return a + b + +?? def subtract_numbers +- def subtract_numbers(a, b): +- return a - b + +?? def multiply_numbers +- def multiply_numbers(a, b): +- return a * b + +?? def divide_numbers +- def divide_numbers(a, b): +- return a / b if b != 0 else None + +?? call add_numbers +- numbers[i] = add_numbers(numbers[i], CONST_ONE) ++ numbers[i] = numbers[i] + CONST_ONE + +?? call subtract_numbers +- numbers[i] = subtract_numbers(numbers[i], CONST_TWO) ++ numbers[i] = numbers[i] - CONST_TWO + +?? call multiply_numbers +- numbers[i] = multiply_numbers(numbers[i], CONST_THREE) ++ numbers[i] = numbers[i] * CONST_THREE + +?? call divide_numbers +- numbers[i] = divide_numbers(numbers[i], CONST_ONE) ++ numbers[i] = numbers[i] / CONST_ONE diff --git a/tests/python/useless-exercise.py b/tests/python/useless-exercise.py new file mode 100644 index 0000000..ac81673 --- /dev/null +++ b/tests/python/useless-exercise.py @@ -0,0 +1,142 @@ +# This is a very long Python program that essentially does nothing of importance. +# It includes a lot of repetitive and unnecessary code to increase its length. + +import sys +import time +import random + +# Define some constants +CONST_ONE = 1 +CONST_TWO = 2 +CONST_THREE = 3 + +# This function does nothing +def do_nothing(): + pass + +# This function adds two numbers and returns the result +def add_numbers(a, b): + return a + b + +# This function subtracts two numbers and returns the result +def subtract_numbers(a, b): + return a - b + +# This function multiplies two numbers and returns the result +def multiply_numbers(a, b): + return a * b + +# This function divides two numbers and returns the result +def divide_numbers(a, b): + return a / b if b != 0 else None + +# This function does nothing useful +def useless_function(): + for i in range(100): + do_nothing() + +# This class represents a useless object +class UselessClass: + def __init__(self): + self.value = 0 + + def increment_value(self): + self.value += 1 + + def decrement_value(self): + self.value -= 1 + + def get_value(self): + return self.value + +# Create a list of numbers +numbers = list(range(1000)) + +# Perform some meaningless operations on the list +for i in range(len(numbers)): + numbers[i] = add_numbers(numbers[i], CONST_ONE) + numbers[i] = subtract_numbers(numbers[i], CONST_TWO) + numbers[i] = multiply_numbers(numbers[i], CONST_THREE) + numbers[i] = divide_numbers(numbers[i], CONST_ONE) + +# Print the list of numbers +for number in numbers: + print(number) + +# Create an instance of the useless class +useless_instance = UselessClass() + +# Perform some operations on the useless instance +for i in range(100): + useless_instance.increment_value() + useless_instance.decrement_value() + print(useless_instance.get_value()) + +# Another useless function +def another_useless_function(): + for i in range(50): + do_nothing() + +# Yet another useless function +def yet_another_useless_function(): + for i in range(200): + do_nothing() + +# More useless code +for i in range(100): + another_useless_function() + yet_another_useless_function() + +# A function that simulates doing something important +def pretend_to_do_something_important(): + print("Pretending to do something important...") + time.sleep(1) + +# Execute the pretend function +for i in range(10): + pretend_to_do_something_important() + +# Another class that does nothing +class AnotherUselessClass: + def __init__(self): + self.data = [] + + def add_data(self, item): + self.data.append(item) + + def remove_data(self, item): + if item in self.data: + self.data.remove(item) + + def get_data(self): + return self.data + +# Create an instance of another useless class +another_useless_instance = AnotherUselessClass() + +# Perform some operations on the useless instance +for i in range(10): + another_useless_instance.add_data(i) + print(another_useless_instance.get_data()) + +for i in range(10): + another_useless_instance.remove_data(i) + print(another_useless_instance.get_data()) + +# Simulate a random meaningless task +def random_task(): + tasks = ["task1", "task2", "task3"] + selected_task = random.choice(tasks) + print(f"Performing {selected_task}...") + +# Execute the random task function multiple times +for i in range(10): + random_task() + +# A final useless function to end the program +def final_useless_function(): + print("This is the end of the useless program.") + +# Execute the final useless function +final_useless_function() + diff --git a/tests/rust/example.rs b/tests/rust/example.rs new file mode 100644 index 0000000..772cf38 --- /dev/null +++ b/tests/rust/example.rs @@ -0,0 +1,11 @@ +// example.rs + +struct MyStruct { + x: i32, + y: i32, +} + +fn my_function(a: i32) -> i32 { + a + 1 +} + diff --git a/tests/rust/example_added.rs b/tests/rust/example_added.rs new file mode 100644 index 0000000..772cf38 --- /dev/null +++ b/tests/rust/example_added.rs @@ -0,0 +1,11 @@ +// example.rs + +struct MyStruct { + x: i32, + y: i32, +} + +fn my_function(a: i32) -> i32 { + a + 1 +} + diff --git a/tests/rust/example_removed.rs b/tests/rust/example_removed.rs new file mode 100644 index 0000000..772cf38 --- /dev/null +++ b/tests/rust/example_removed.rs @@ -0,0 +1,11 @@ +// example.rs + +struct MyStruct { + x: i32, + y: i32, +} + +fn my_function(a: i32) -> i32 { + a + 1 +} + diff --git a/tests/rust/example_replaced.rs b/tests/rust/example_replaced.rs new file mode 100644 index 0000000..772cf38 --- /dev/null +++ b/tests/rust/example_replaced.rs @@ -0,0 +1,11 @@ +// example.rs + +struct MyStruct { + x: i32, + y: i32, +} + +fn my_function(a: i32) -> i32 { + a + 1 +} + diff --git a/tests/rust/patch_add_function.json b/tests/rust/patch_add_function.json new file mode 100644 index 0000000..2c921dd --- /dev/null +++ b/tests/rust/patch_add_function.json @@ -0,0 +1,7 @@ +{ + "file": "example.rs", + "type": "fn", + "name": "new_function", + "code": "fn new_function(b: i32) -> i32 {\n b * b\n}" +} + diff --git a/tests/rust/patch_remove_struct.json b/tests/rust/patch_remove_struct.json new file mode 100644 index 0000000..fc38075 --- /dev/null +++ b/tests/rust/patch_remove_struct.json @@ -0,0 +1,7 @@ +{ + "file": "example.rs", + "type": "struct", + "name": "MyStruct", + "code": null +} + diff --git a/tests/rust/patch_replace_function.json b/tests/rust/patch_replace_function.json new file mode 100644 index 0000000..c5d7a4f --- /dev/null +++ b/tests/rust/patch_replace_function.json @@ -0,0 +1,7 @@ +{ + "file": "example.rs", + "type": "fn", + "name": "my_function", + "code": "fn my_function(a: i32) -> i32 {\n a * 2\n}" +} + diff --git a/tests/test_llama_patch.py b/tests/test_llama_patch.py new file mode 100644 index 0000000..96e0606 --- /dev/null +++ b/tests/test_llama_patch.py @@ -0,0 +1,106 @@ +import pytest +import os +from click.testing import CliRunner +from llama_patch import apply_patch + +@pytest.fixture +def setup_files(): + if not os.path.exists('src'): + os.makedirs('src') + with open('src/main.py', 'w') as f: + f.write('''def target_function(x): + # Original implementation + return x + 1 +class OldClass: + def __init__(self): + pass +''') + + yield + + # Cleanup after tests + if os.path.exists('src/main.py'): + os.remove('src/main.py') + if os.path.exists('src'): + os.rmdir('src') + +@pytest.fixture +def patch_file(): + patch_content = '''--- src/main.py +?? def target_function +def target_function(x): +- # Original implementation ++ # New implementation +- return x + 1 ++ return x * 2 +--- src/main.py +?? class OldClass +- class OldClass: +- def __init__(self): +- pass +''' + with open('patch.diff', 'w') as f: + f.write(patch_content) + yield 'patch.diff' + os.remove('patch.diff') + +def test_apply_patch_function(setup_files, patch_file): + runner = CliRunner() + result = runner.invoke(apply_patch, ['--llmpatch', patch_file, '--cwd', '.']) + + assert result.exit_code == 0 + + with open('src/main.py', 'r') as f: + content = f.read() + assert 'def target_function(x):' in content + assert '# New implementation' in content + assert 'return x * 2' in content + assert 'class OldClass' not in content + +def test_apply_patch_stdin(setup_files): + patch_content = '''--- src/main.py +?? def target_function +def target_function(x): +- # Original implementation ++ # New implementation +- return x + 1 ++ return x * 2 +''' + + runner = CliRunner() + result = runner.invoke(apply_patch, ['--cwd', '.'], input=patch_content) + + assert result.exit_code == 0 + + with open('src/main.py', 'r') as f: + content = f.read() + assert 'def target_function(x):' in content + assert '# New implementation' in content + assert 'return x * 2' in content + +def test_prepend_append(setup_files): + patch_content = '''--- src/main.py +?? << ++import sys +--- src/main.py +?? >> ++print("End of file") +''' + + runner = CliRunner() + result = runner.invoke(apply_patch, ['--llmpatch', '-', '--cwd', '.'], input=patch_content) + + assert result.exit_code == 0 + + with open('src/main.py', 'r') as f: + content = f.read() + assert content.startswith('import sys\n') + assert content.endswith('print("End of file")\n') + +def test_verbose_flag(setup_files, patch_file): + runner = CliRunner() + result = runner.invoke(apply_patch, ['--llmpatch', patch_file, '--cwd', '.', '--verbose']) + + assert result.exit_code == 0 + assert 'DEBUG' in result.output + diff --git a/tests/try_rust.sh b/tests/try_rust.sh new file mode 100755 index 0000000..982fb40 --- /dev/null +++ b/tests/try_rust.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Create a clean environment +mkdir -p test_env +cd test_env +cp ../rust/example.rs . + +# Test case 1: Replace function +echo "Applying patch_replace_function.json..." +../../target/debug/llama-patch ../rust/patch_replace_function.json > patch_replace_function.diff +echo "Resulting diff for replacing function:" +cat patch_replace_function.diff + +# Check the output file +echo "Content of modified example.rs after replacing function:" +cat example.rs +cp example.rs ../rust/example_replaced.rs + +# Reset example.rs for the next test +cp ../rust/example.rs . + +# Test case 2: Remove struct +echo "Applying patch_remove_struct.json..." +../../target/debug/llama-patch ../rust/patch_remove_struct.json > patch_remove_struct.diff +echo "Resulting diff for removing struct:" +cat patch_remove_struct.diff + +# Check the output file +echo "Content of modified example.rs after removing struct:" +cat example.rs +cp example.rs ../rust/example_removed.rs + +# Reset example.rs for the next test +cp ../rust/example.rs . + +# Test case 3: Add new function +echo "Applying patch_add_function.json..." +../../target/debug/llama-patch ../rust/patch_add_function.json > patch_add_function.diff +echo "Resulting diff for adding new function:" +cat patch_add_function.diff + +# Check the output file +echo "Content of modified example.rs after adding new function:" +cat example.rs +cp example.rs ../rust/example_added.rs + +# Clean up +cd .. +rm -rf test_env