diff --git a/tests/examples/advanced_rewrite.py b/tests/examples/advanced_rewrite.py new file mode 100644 index 0000000..07a1f69 --- /dev/null +++ b/tests/examples/advanced_rewrite.py @@ -0,0 +1,217 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import RewriteRule +from oqd_compiler_infrastructure.walk import Post + +# AST data structures (same as before) +class Expression(TypeReflectBaseModel): + """Base class for arithmetic expressions. + + This class serves as the foundation for all expression types in the AST. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class Variable(Expression): + """Represents a variable in an expression. + + Attributes: + name (str): The name of the variable. + """ + name: str + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator (e.g., '+', '-', '*', '/'). + left (Expression): The left operand. + right (Expression): The right operand. + """ + op: str # '+', '-', '*', '/' + left: Expression + right: Expression + +class AdvancedAlgebraicSimplifier(RewriteRule): + """Applies advanced algebraic simplification rules. + + Rules implemented: + - x - x = 0 + - x + (-x) = 0 + - x * x = x^2 + - (x + y) - y = x + - Distributive law: a * (b + c) = (a * b) + (a * c) + """ + # Implements additional algebraic identities like subtraction and distribution + + def map_BinaryOp(self, model): + """Apply advanced simplification to binary operations. + + Args: + model (BinaryOp): The binary operation to simplify. + + Returns: + Expression: The simplified expression or original if no rule applies. + """ + # x - x = 0 + # Rule: subtracting identical terms yields zero + if model.op == '-' and self._expressions_equal(model.left, model.right): + return Number(value=0) + + # x + (-x) = 0 + # Rule: x + (-1 * x) => 0 + if (model.op == '+' and + isinstance(model.right, BinaryOp) and + model.right.op == '*' and + isinstance(model.right.left, Number) and + model.right.left.value == -1 and + self._expressions_equal(model.left, model.right.right)): + return Number(value=0) + + # Distributive law: a * (b + c) = (a * b) + (a * c) + # Rule: a * (b + c) => (a*b) + (a*c) + if (model.op == '*' and + isinstance(model.right, BinaryOp) and + model.right.op in ['+', '-']): + # a * (b + c) -> (a * b) + (a * c) + return BinaryOp( + op=model.right.op, + left=BinaryOp(op='*', left=model.left, right=model.right.left), + right=BinaryOp(op='*', left=model.left, right=model.right.right) + ) + + # (x + y) - y = x + # Rule: (x + y) - y => x + if (model.op == '-' and + isinstance(model.left, BinaryOp) and + model.left.op == '+' and + self._expressions_equal(model.left.right, model.right)): + return model.left.left + + return model + + def _expressions_equal(self, expr1, expr2): + """Check if two expressions are structurally equal. + + Args: + expr1 (Expression): The first expression to compare. + expr2 (Expression): The second expression to compare. + + Returns: + bool: True if structurally equal, False otherwise. + """ + # Compare types and recursively compare sub-expressions + if type(expr1) != type(expr2): + return False + + if isinstance(expr1, Number): + return expr1.value == expr2.value + + if isinstance(expr1, Variable): + return expr1.name == expr2.name + + if isinstance(expr1, BinaryOp): + return (expr1.op == expr2.op and + self._expressions_equal(expr1.left, expr2.left) and + self._expressions_equal(expr1.right, expr2.right)) + + return False + +def print_expr(expr): + """Convert an expression into a readable string. + + Args: + expr (Expression): The expression to format. + + Returns: + str: A string representation of the expression. + """ + # Convert AST nodes into parenthesized infix notation + if isinstance(expr, Number): + return str(expr.value) + elif isinstance(expr, Variable): + return expr.name + elif isinstance(expr, BinaryOp): + return f"({print_expr(expr.left)} {expr.op} {print_expr(expr.right)})" + return str(expr) + +def main(): + """Main function to demonstrate advanced algebraic simplification.""" + # Prepare test cases and run the AdvancedAlgebraicSimplifier + # Create test expressions + test_cases = [ + # x - x = 0 + BinaryOp( + op='-', + left=Variable(name='x'), + right=Variable(name='x') + ), + + # x + (-1 * x) = 0 + BinaryOp( + op='+', + left=Variable(name='x'), + right=BinaryOp( + op='*', + left=Number(value=-1), + right=Variable(name='x') + ) + ), + + # a * (b + c) -> (a * b) + (a * c) + BinaryOp( + op='*', + left=Variable(name='a'), + right=BinaryOp( + op='+', + left=Variable(name='b'), + right=Variable(name='c') + ) + ), + + # (x + y) - y = x + BinaryOp( + op='-', + left=BinaryOp( + op='+', + left=Variable(name='x'), + right=Variable(name='y') + ), + right=Variable(name='y') + ) + ] + + # Create simplifier with Post traversal + simplifier = Post(AdvancedAlgebraicSimplifier()) + + # Run simplifications + print("Advanced Algebraic Simplifications:") + for i, expr in enumerate(test_cases, 1): + print(f"\nTest Case {i}:") + print(f"Original: {print_expr(expr)}") + result = simplifier(expr) + print(f"Simplified: {print_expr(result)}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/conversion_pass.py b/tests/examples/conversion_pass.py new file mode 100644 index 0000000..bd8f306 --- /dev/null +++ b/tests/examples/conversion_pass.py @@ -0,0 +1,116 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import ConversionRule + +# Same data structures as before +class Expression(TypeReflectBaseModel): + """Base class for arithmetic expressions. + + This class serves as the foundation for all expression types in the AST. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator for the binary operation (e.g., '+', '-', '*', '/'). + left (Expression): The left operand of the operation. + right (Expression): The right operand of the operation. + """ + op: str # '+', '-', '*', '/' + left: Expression + right: Expression + +# Constant folding using ConversionRule +class ConstantFoldingConversion(ConversionRule): + """A conversion pass that folds constant expressions. + + This pass uses ConversionRule dispatch to evaluate binary operations + with constant operands and replace them with computed values. + """ + + def map_Number(self, model, operands=None): + # Numbers are terminal constants; no folding needed + """Returns the number as is since it's already folded. + + Args: + model (Number): The number to process. + operands (dict, optional): The processed operands (unused). + + Returns: + Number: The same number model. + """ + # Numbers are already folded + return model + + def map_BinaryOp(self, model, operands=None): + # Fold binary operations: evaluate when both operands are numeric constants + """Folds binary operations when both operands are numbers. + + Args: + model (BinaryOp): The binary operation model. + operands (dict, optional): Additional processed operands (unused). + + Returns: + Number or BinaryOp: Folded number if both operands numeric, otherwise new BinaryOp. + """ + # Recursively process operands first + left = self(model.left) + right = self(model.right) + + # If both are numbers, fold them + if isinstance(left, Number) and isinstance(right, Number): + if model.op == '+': + return Number(value=left.value + right.value) + elif model.op == '*': + return Number(value=left.value * right.value) + elif model.op == '-': + return Number(value=left.value - right.value) + elif model.op == '/': + if right.value != 0: + return Number(value=left.value / right.value) + + # Otherwise, rebuild the BinaryOp with possibly simplified children + return BinaryOp(op=model.op, left=left, right=right) + +def main(): + """Demonstrates constant folding pass on a sample expression.""" + # Create an expression: (2 + 3) * (4 + 5) + expr = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=Number(value=3)), + right=BinaryOp(op='+', left=Number(value=4), right=Number(value=5)) + ) + + # Create and run our pass + pass_ = ConstantFoldingConversion() + result = pass_(expr) + + print(f"Original expression: {expr}") + print(f"After constant folding: {result}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/dispatch_example.py b/tests/examples/dispatch_example.py new file mode 100644 index 0000000..9327004 --- /dev/null +++ b/tests/examples/dispatch_example.py @@ -0,0 +1,128 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import ConversionRule +from pydantic import model_validator + +# Base class for shapes +class Shape(TypeReflectBaseModel): + # Base class for all shape models used in dispatch demonstration + """Base class for shapes. + + Provides the foundation for all shape types in the system. + """ + pass + +# Different shape types +class Circle(Shape): + """Represents a circle. + + Attributes: + radius (float): The radius of the circle. + """ + # Radius attribute for circle dimensions + radius: float + +class Rectangle(Shape): + """Represents a rectangle. + + Attributes: + width (float): The width of the rectangle. + height (float): The height of the rectangle. + """ + # Width and height attributes for rectangle dimensions + width: float + height: float + +class Square(Rectangle): # Inherits from Rectangle + """Represents a square, a special case of a rectangle. + + Attributes: + side (float): The side length of the square. + """ + # Side length attribute; width/height derived in validator + side: float + + @model_validator(mode='before') + @classmethod + def set_dimensions(cls, data): + """Set width and height fields based on side before validation. + + Ensures squares have equal width and height based on the side attribute. + """ + if isinstance(data, dict) and 'side' in data: + data['width'] = data['side'] + data['height'] = data['side'] + return data + +# Demonstration of dispatch +class ShapePrinter(ConversionRule): + """Demonstrates dispatch behavior for different shapes. + + Provides handlers for various shape types and shows how inheritance affects + method dispatch. + """ + + def map_Shape(self, model, operands=None): + """Handles generic Shape models. + + Returns a default message for unknown shapes. + """ + # Default handler for unknown shape types + return "Unknown shape" + + def map_Circle(self, model, operands=None): + """Handles Circle models. + + Returns a descriptive string including the radius. + """ + # Return descriptive info for circle + return f"Circle with radius {model.radius}" + + def map_Rectangle(self, model, operands=None): + """Handles Rectangle models. + + Returns a descriptive string including width and height. + """ + # Return descriptive info for rectangle dimensions + return f"Rectangle {model.width}x{model.height}" + + # Note: No specific handler for Square - will use Rectangle's handler + +def main(): + # Entry point: create shapes and demonstrate the dispatch mechanism + """Main entrypoint to demonstrate shape dispatch.""" + # Create some shapes + circle = Circle(radius=5.0) + rectangle = Rectangle(width=4.0, height=6.0) + square = Square(side=3.0) + + # Create our printer + printer = ShapePrinter() + + # Demonstrate dispatch + print("Demonstrating method dispatch:") + print(f"Circle: {printer(circle)}") + print(f"Rectangle: {printer(rectangle)}") + print(f"Square: {printer(square)} # Uses Rectangle's handler through inheritance") + + # Show the MRO (Method Resolution Order) for Square + print("\nSquare's Method Resolution Order:") + for cls in Square.__mro__: + print(f"- {cls.__name__}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/multi_target_compiler.py b/tests/examples/multi_target_compiler.py new file mode 100644 index 0000000..48ae837 --- /dev/null +++ b/tests/examples/multi_target_compiler.py @@ -0,0 +1,229 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import ConversionRule + +# Our AST data structures (same as before) +class Expression(TypeReflectBaseModel): + """Base class for arithmetic expressions. + + This class serves as the foundation for all expression types in the AST. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator (e.g., '+', '-', '*', '/'). + left (Expression): The left operand. + right (Expression): The right operand. + """ + op: str # '+', '-', '*', '/' + left: Expression + right: Expression + +# Python Code Generator +class PythonCodeGen(ConversionRule): + """Converts an AST into Python code. + + Uses dispatch to generate Python expression strings from AST nodes. + """ + # Generates Python expression strings by dispatching on node types + + def map_Number(self, model, operands=None): + """Convert a Number model to its string representation. + + Args: + model (Number): The number model to convert. + + Returns: + str: The string representation of the number. + """ + # Return the numeric value as a Python literal string + return str(model.value) + + def map_BinaryOp(self, model, operands=None): + """Convert a BinaryOp model to a Python expression string. + + Args: + model (BinaryOp): The binary operation model to convert. + + Returns: + str: The Python code representation of the binary operation. + """ + # Generate code for left and right operands + left = self(model.left) + right = self(model.right) + # Return the formatted binary expression + return f"({left} {model.op} {right})" + +# LaTeX Math Generator +class LaTeXCodeGen(ConversionRule): + """Converts an AST into LaTeX math expressions. + + Uses dispatch to generate LaTeX code from AST nodes. + """ + # Generates LaTeX code for math expressions + + def map_Number(self, model, operands=None): + """Convert a Number model to its LaTeX string. + + Args: + model (Number): The number model to convert. + + Returns: + str: The LaTeX representation of the number. + """ + # Return the numeric value as a LaTeX literal + return str(model.value) + + def map_BinaryOp(self, model, operands=None): + """Convert a BinaryOp model to its LaTeX representation. + + Args: + model (BinaryOp): The binary operation model to convert. + + Returns: + str: The LaTeX representation of the binary operation. + """ + # Generate LaTeX for operands + left = self(model.left) + right = self(model.right) + + # Special handling for division in LaTeX + if model.op == '/': + return f"\\frac{{{left}}}{{{right}}}" + # Special handling for multiplication in LaTeX + elif model.op == '*': + return f"{left} \\times {right}" + else: + return f"{left} {model.op} {right}" + +# Assembly-like Instructions Generator +class AssemblyCodeGen(ConversionRule): + """Converts an AST into assembly-like instructions. + + Generates a sequence of instructions using temporary registers. + + Attributes: + temp_counter (int): Counter for temporary registers. + instructions (List[str]): List of generated instructions. + """ + # Builds a sequence of pseudo-assembly using temp registers + + def __init__(self): + super().__init__() + self.temp_counter = 0 + self.instructions = [] + + def get_temp(self): + """Generate a new temporary register name. + + Returns: + str: The name of the new temporary register. + """ + # Allocate next temporary register identifier + self.temp_counter += 1 + return f"t{self.temp_counter}" + + def map_Number(self, model, operands=None): + """Load a numeric literal into a temporary register. + + Args: + model (Number): The number model to load. + + Returns: + str: The temporary register holding the number. + """ + # Create a temp, emit a LOAD instruction + temp = self.get_temp() + self.instructions.append(f"LOAD {temp}, {model.value}") + return temp + + def map_BinaryOp(self, model, operands=None): + """Generate instructions for a binary operation. + + Args: + model (BinaryOp): The binary operation model. + + Returns: + str: The temporary register holding the result. + """ + # Recursively compute left and right into temps + left_temp = self(model.left) + right_temp = self(model.right) + result_temp = self.get_temp() + + # Map operators to assembly instructions + op_map = { + '+': 'ADD', + '-': 'SUB', + '*': 'MUL', + '/': 'DIV' + } + + self.instructions.append(f"{op_map[model.op]} {result_temp}, {left_temp}, {right_temp}") + return result_temp + + def get_program(self): + """Return the complete assembly program as a single string. + + Returns: + str: The concatenated assembly instructions. + """ + # Join all instructions into a program listing + return '\n'.join(self.instructions) + +def main(): + """Main function to demonstrate multi-target code generation.""" + # Entry point: demonstrate multi-target code generators + # Create a sample expression: (2 + 3) * (4 / 2) + expr = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=Number(value=3)), + right=BinaryOp(op='/', left=Number(value=4), right=Number(value=2)) + ) + + # Generate different representations + python_gen = PythonCodeGen() + latex_gen = LaTeXCodeGen() + asm_gen = AssemblyCodeGen() + + python_code = python_gen(expr) + latex_code = latex_gen(expr) + asm_gen(expr) # This populates the instructions + asm_code = asm_gen.get_program() + + print("Original Expression:") + print(expr) + print("\nPython Code:") + print(python_code) + print("\nLaTeX Math:") + print(latex_code) + print("\nAssembly-like Code:") + print(asm_code) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/rewrite_example.py b/tests/examples/rewrite_example.py new file mode 100644 index 0000000..b726052 --- /dev/null +++ b/tests/examples/rewrite_example.py @@ -0,0 +1,205 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import RewriteRule +from oqd_compiler_infrastructure.walk import Post + +# AST data structures +class Expression(TypeReflectBaseModel): + """Base class for arithmetic expressions. + + This class serves as the foundation for all expression types in the AST. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator for the binary operation (e.g., '+', '-', '*', '/'). + left (Expression): The left operand of the operation. + right (Expression): The right operand of the operation. + """ + op: str # '+', '-', '*', '/' + left: Expression + right: Expression + +class Variable(Expression): + """Represents a variable in an expression. + + Attributes: + name (str): The name of the variable. + """ + name: str + +class AlgebraicSimplifier(RewriteRule): + """Applies algebraic simplification rules to binary expressions. + + Rules implemented: + - x * 1 = x + - 1 * x = x + - x * 0 = 0 + - 0 * x = 0 + - x + 0 = x + - 0 + x = x + - x - 0 = x + """ + + def map_BinaryOp(self, model): + """Simplify algebraic binary operations according to defined rules. + + Args: + model (BinaryOp): The binary operation to simplify. + + Returns: + Expression: The simplified expression. + """ + # Handle multiplication with 1 or 0 + # Multiplication rules: x*1 -> x, 1*x -> x, x*0 -> 0, 0*x -> 0 + if model.op == '*': + # Check for multiplication by 1 + # x * 1 => x + if isinstance(model.left, Number) and model.left.value == 1: + return model.right + # 1 * x => x + if isinstance(model.right, Number) and model.right.value == 1: + return model.left + + # Check for multiplication by 0 + # x * 0 or 0 * x => 0 + if (isinstance(model.left, Number) and model.left.value == 0) or \ + (isinstance(model.right, Number) and model.right.value == 0): + return Number(value=0) + + # Handle addition/subtraction with 0 + # Addition/Subtraction rules: x+0 -> x, 0+x -> x, x-0 -> x + elif model.op in ['+', '-']: + if isinstance(model.right, Number) and model.right.value == 0: + # x + 0 or x - 0 => x + return model.left + if model.op == '+' and isinstance(model.left, Number) and model.left.value == 0: + # 0 + x => x + return model.right + + return model + + def map_Number(self, model): + """Return numeric literals unchanged. + + Args: + model (Number): The number model. + + Returns: + Number: The same number model. + """ + # Numbers are already in simplest form + return model + + def map_Variable(self, model): + """Return variable models unchanged. + + Args: + model (Variable): The variable model. + + Returns: + Variable: The same variable model. + """ + # Variables cannot be simplified further here + return model + +def print_expr(expr, prefix=""): + """Return a readable string representation of an expression. + + Args: + expr (Expression): The expression to format. + prefix (str): Optional prefix for formatting. + + Returns: + str: The formatted expression string. + """ + # Recursively traverse the expression tree to build a string + if isinstance(expr, Number): + return str(expr.value) + elif isinstance(expr, Variable): + return expr.name + elif isinstance(expr, BinaryOp): + return f"({print_expr(expr.left)} {expr.op} {print_expr(expr.right)})" + return str(expr) + +def main(): + """Main function to demonstrate algebraic simplification examples.""" + # Prepare and run simplifications on several test cases + # Create test expressions + test_cases = [ + # x * 1 + BinaryOp( + op='*', + left=Variable(name='x'), + right=Number(value=1) + ), + # 0 * x + BinaryOp( + op='*', + left=Number(value=0), + right=Variable(name='x') + ), + # (x + 0) * (y - 0) + BinaryOp( + op='*', + left=BinaryOp(op='+', left=Variable(name='x'), right=Number(value=0)), + right=BinaryOp(op='-', left=Variable(name='y'), right=Number(value=0)) + ), + # Complex expression: ((x * 1) + (0 * y)) * (z - 0) + BinaryOp( + op='*', + left=BinaryOp( + op='+', + left=BinaryOp(op='*', left=Variable(name='x'), right=Number(value=1)), + right=BinaryOp(op='*', left=Number(value=0), right=Variable(name='y')) + ), + right=BinaryOp(op='-', left=Variable(name='z'), right=Number(value=0)) + ), + # (x + 3) * (y - 1) + BinaryOp( + op='+', + left=BinaryOp(op='+', left=Variable(name='x'), right=Number(value=3)), + right=BinaryOp(op='-', left=Variable(name='y'), right=Number(value=1)) + ) + + ] + + # Create simplifier wrapped in Post traversal + # Post ensures we simplify from bottom-up + simplifier = Post(AlgebraicSimplifier()) + + # Run simplifications + print("Algebraic Simplifications:") + for i, expr in enumerate(test_cases, 1): + print(f"\nTest Case {i}:") + print(f"Original: {print_expr(expr)}") + result = simplifier(expr) + print(f"Simplified: {print_expr(result)}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/simple_pass.py b/tests/examples/simple_pass.py new file mode 100644 index 0000000..dd57319 --- /dev/null +++ b/tests/examples/simple_pass.py @@ -0,0 +1,161 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, List +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.base import PassBase + +# Define our expression data structures +class Expression(TypeReflectBaseModel): + """Base class for arithmetic expressions. + + This class serves as the foundation for all expression types in the AST. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator (e.g., '+', '-', '*', '/'). + left (Expression): The left operand. + right (Expression): The right operand. + """ + op: str # '+', '-', '*', '/' + left: Expression + right: Expression + +# Define a simple optimization pass +class ConstantFoldingPass(PassBase): + """A pass that folds constant expressions. + + This pass uses PassBase to recursively evaluate binary operations with constant operands + and replace them with their computed values. + """ + + @property + def children(self): + """Returns an empty list since this pass does not have sub-passes.""" + # No child passes to apply; operates directly on nodes + return [] + + def map(self, model): + """Recursively process the model and apply constant folding. + + Args: + model (Expression): The expression to process. + + Returns: + Expression: The processed expression after folding. + """ + # Only handle binary operations; leave other nodes unchanged + if isinstance(model, BinaryOp): + # First recursively process children + # Simplify left and right sub-expressions first + left = self(model.left) + right = self(model.right) + + # If both operands are numeric literals, compute the result + if isinstance(left, Number) and isinstance(right, Number): + if model.op == '+': + return Number(value=left.value + right.value) + elif model.op == '*': + return Number(value=left.value * right.value) + elif model.op == '-': + return Number(value=left.value - right.value) + elif model.op == '/': + if right.value != 0: # Avoid division by zero + return Number(value=left.value / right.value) + + # Cannot fold: return a new BinaryOp with potentially simplified children + return BinaryOp(op=model.op, left=left, right=right) + + # Non-binary nodes are returned unchanged + return model + +def test_basic_folding(): + """Run basic test cases for the ConstantFoldingPass.""" + # Define and run sample expressions to verify folding + + # Test case 1: Simple addition (2 + 3) + expr1 = BinaryOp(op='+', left=Number(value=2), right=Number(value=3)) + + # Test case 2: Nested expression ((2 + 3) * (4 + 5)) + expr2 = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=Number(value=3)), + right=BinaryOp(op='+', left=Number(value=4), right=Number(value=5)) + ) + + # Test case 3: Mixed constants and variables ((2 + x) * 3) + expr3 = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=BinaryOp(op='+', left=Number(value=1), right=Number(value=2))), + right=Number(value=3) + ) + + # Create our pass + pass_ = ConstantFoldingPass() + + # Run the pass on each expression + result1 = pass_(expr1) + result2 = pass_(expr2) + result3 = pass_(expr3) + + print("Test Case 1:") + print(f"Input: {expr1}") + print(f"Output: {result1}") + print() + + print("Test Case 2:") + print(f"Input: {expr2}") + print(f"Output: {result2}") + print() + + print("Test Case 3:") + print(f"Input: {expr3}") + print(f"Output: {result3}") + +# Example usage + +def main(): + """Main function to demonstrate and test constant folding.""" + # Example usage: fold a sample expression and run tests + + # Create an expression: (2 + 3) * (4 + 5) + expr = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=Number(value=3)), + right=BinaryOp(op='+', left=Number(value=4), right=Number(value=5)) + ) + + # Create and run our pass + pass_ = ConstantFoldingPass() + result = pass_(expr) + + print(f"Original expression: {expr}") + print(f"After constant folding: {result}") + test_basic_folding() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/symbolic_diff.py b/tests/examples/symbolic_diff.py new file mode 100644 index 0000000..a1a0122 --- /dev/null +++ b/tests/examples/symbolic_diff.py @@ -0,0 +1,396 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import RewriteRule +from oqd_compiler_infrastructure.walk import Post +from oqd_compiler_infrastructure.rewriter import Chain, FixedPoint + +# AST data structures +class Expression(TypeReflectBaseModel): + """Base class for expressions in symbolic differentiation. + + This class serves as the foundation for all expression types in the AST. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class Variable(Expression): + """Represents a variable in an expression. + + Attributes: + name (str): The name of the variable. + """ + name: str + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator (e.g., '+', '-', '*', '/', '^'). + left (Expression): The left operand. + right (Expression): The right operand. + """ + op: str # '+', '-', '*', '/', '^' + +class Function(Expression): + """Represents a mathematical function. + + Attributes: + name (str): The name of the function (e.g., 'sin', 'cos', 'exp', 'ln'). + arg (Expression): The argument of the function. + """ + name: str # 'sin', 'cos', 'exp', 'ln' + +class Derivative(Expression): + """Represents a derivative operation. + + Attributes: + expr (Expression): The expression to differentiate. + var (str): The variable with respect to which to differentiate. + """ + expr: Expression + var: str # variable to differentiate with respect to + +class SymbolicDifferentiator(RewriteRule): + """Implements symbolic differentiation rules. + + This class defines rules for differentiating constants, variables, + arithmetic operations, exponentials, and common functions. + """ + + def map_Derivative(self, model): + """Differentiate the model expression with respect to the given variable. + + Args: + model (Derivative): The derivative model to process. + + Returns: + Expression: The resulting differentiated expression. + """ + # Apply differentiation rules based on expression type + expr = model.expr + var = model.var + + # Rule: d/dx(c) = 0 for constant + # Constant rule: derivative of constant is zero + if isinstance(expr, Number): + return Number(value=0) + + # Rule: d/dx(x) = 1 for variable x + # Variable rule: d/dx(x)=1, d/dx(y!=x)=0 + if isinstance(expr, Variable): + if expr.name == var: + return Number(value=1) + return Number(value=0) + + # Rule: d/dx(u + v) = d/dx(u) + d/dx(v) + # Sum rule: derivative distributes over addition/subtraction + if isinstance(expr, BinaryOp) and expr.op in ['+', '-']: + left_deriv = Derivative(expr=expr.left, var=var) + right_deriv = Derivative(expr=expr.right, var=var) + return BinaryOp(op=expr.op, left=self(left_deriv), right=self(right_deriv)) + + # Rule: d/dx(u * v) = u * d/dx(v) + v * d/dx(u) + # Product rule: u*v' + v*u' + if isinstance(expr, BinaryOp) and expr.op == '*': + u, v = expr.left, expr.right + du = Derivative(expr=u, var=var) + dv = Derivative(expr=v, var=var) + term1 = BinaryOp(op='*', left=u, right=self(dv)) + term2 = BinaryOp(op='*', left=v, right=self(du)) + return BinaryOp(op='+', left=term1, right=term2) + + # Rule: d/dx(u^n) = n * u^(n-1) * d/dx(u) + # Power rule: d/dx(u^n) = n*u^(n-1)*u' + if isinstance(expr, BinaryOp) and expr.op == '^' and isinstance(expr.right, Number): + u, n = expr.left, expr.right.value + du = Derivative(expr=u, var=var) + power = BinaryOp(op='^', left=u, right=Number(value=n-1)) + return BinaryOp( + op='*', + left=Number(value=n), + right=BinaryOp(op='*', left=power, right=self(du)) + ) + + # Rule: d/dx(sin(u)) = cos(u) * d/dx(u) + # Function rule: sin + if isinstance(expr, Function) and expr.name == 'sin': + du = Derivative(expr=expr.arg, var=var) + return BinaryOp( + op='*', + left=Function(name='cos', arg=expr.arg), + right=self(du) + ) + + # Rule: d/dx(cos(u)) = -sin(u) * d/dx(u) + # Function rule: cos + if isinstance(expr, Function) and expr.name == 'cos': + du = Derivative(expr=expr.arg, var=var) + return BinaryOp( + op='*', + left=BinaryOp( + op='*', + left=Number(value=-1), + right=Function(name='sin', arg=expr.arg) + ), + right=self(du) + ) + + # Rule: d/dx(exp(u)) = exp(u) * d/dx(u) + # Function rule: exp + if isinstance(expr, Function) and expr.name == 'exp': + du = Derivative(expr=expr.arg, var=var) + return BinaryOp(op='*', left=expr, right=self(du)) + + # Rule: d/dx(ln(u)) = d/dx(u) / u + # Function rule: ln + if isinstance(expr, Function) and expr.name == 'ln': + du = Derivative(expr=expr.arg, var=var) + return BinaryOp(op='/', left=self(du), right=expr.arg) + + return model + +class ExpressionSimplifier(RewriteRule): + """Simplifies expressions post-differentiation. + + This class applies algebraic simplification rules to the differentiated AST. + """ + + def map_BinaryOp(self, model): + """Apply simplification rules to binary operations after differentiation. + + Args: + model (BinaryOp): The binary operation to simplify. + + Returns: + Expression: The simplified expression. + """ + # First simplify the operands recursively + model.left = self(model.left) + model.right = self(model.right) + + # Simplify x^1 to x + # Simplify power-of-one: x^1 => x + if model.op == '^' and isinstance(model.right, Number) and model.right.value == 1: + return model.left + + # Simplify multiplication by 1 + # Simplify multiplication by 1 and 0 + if model.op == '*': + if isinstance(model.left, Number) and model.left.value == 1: + return model.right + if isinstance(model.right, Number) and model.right.value == 1: + return model.left + + # Simplify multiplication by 0 + # Zero multiplication: result is zero + if (isinstance(model.left, Number) and model.left.value == 0) or \ + (isinstance(model.right, Number) and model.right.value == 0): + return Number(value=0) + + # Combine numeric coefficients + # Combine nested numeric factors: 2*(3*x) => 6*x + if isinstance(model.left, Number): + if isinstance(model.right, BinaryOp) and model.right.op == '*': + if isinstance(model.right.left, Number): + # 2 * (3 * x) -> 6 * x + return BinaryOp(op='*', + left=Number(value=model.left.value * model.right.left.value), + right=model.right.right) + if isinstance(model.right.right, Number): + # 2 * (x * 3) -> 6 * x + return BinaryOp(op='*', + left=Number(value=model.left.value * model.right.right.value), + right=model.right.left) + + # Move all numbers to the left in multiplication chains + if isinstance(model.right, Number): + return BinaryOp(op='*', left=model.right, right=model.left) + + # Simplify addition/subtraction with 0 + # Simplify adding or subtracting zero: x+0=>x, 0+x=>x + elif model.op in ['+', '-']: + if isinstance(model.right, Number) and model.right.value == 0: + return model.left + if model.op == '+' and isinstance(model.left, Number) and model.left.value == 0: + return model.right + + return model + +def print_expr(expr): + """Convert an expression into a human-readable string. + + Args: + expr (Expression): The expression to format. + + Returns: + str: A human-readable representation of the expression. + """ + if isinstance(expr, Number): + # Format number + return str(expr.value) + elif isinstance(expr, Variable): + return expr.name + elif isinstance(expr, BinaryOp): + left = print_expr(expr.left) + right = print_expr(expr.right) + + # Special handling for multiplication + if expr.op == '*': + # If right operand is a multiplication, merge them + if isinstance(expr.right, BinaryOp) and expr.right.op == '*': + terms = [] + # Add left operand + if isinstance(expr.left, Number): + terms.append(str(expr.left.value)) + else: + terms.append(print_expr(expr.left)) + # Add right operand's parts + if isinstance(expr.right.left, Number): + terms.append(str(expr.right.left.value)) + else: + terms.append(print_expr(expr.right.left)) + terms.append(print_expr(expr.right.right)) + return " * ".join(terms) + + # Simple case: just two terms + if isinstance(expr.left, Number): + return f"{left} * {right}" + if not isinstance(expr.right, BinaryOp) or expr.right.op != '*': + return f"{left} * {right}" + + # Default case: use parentheses + return f"({left} {expr.op} {right})" + elif isinstance(expr, Function): + return f"{expr.name}({print_expr(expr.arg)})" + elif isinstance(expr, Derivative): + return f"d/d{expr.var}({print_expr(expr.expr)})" + return str(expr) + +def main(): + """Main function to demonstrate symbolic differentiation with simplification.""" + # Test cases for differentiation + test_cases = [ + # d/dx(x^2) + Derivative( + expr=BinaryOp( + op='^', + left=Variable(name='x'), + right=Number(value=2) + ), + var='x' + ), + + # d/dx(sin(x)) + Derivative( + expr=Function( + name='sin', + arg=Variable(name='x') + ), + var='x' + ), + + # d/dx(x * sin(x)) + Derivative( + expr=BinaryOp( + op='*', + left=Variable(name='x'), + right=Function(name='sin', arg=Variable(name='x')) + ), + var='x' + ), + + # d/dx(exp(x^2)) + Derivative( + expr=Function( + name='exp', + arg=BinaryOp( + op='^', + left=Variable(name='x'), + right=Number(value=2) + ) + ), + var='x' + ), + + # d/dx(sin(x^2)) + Derivative( + expr=Function( + name='sin', + arg=BinaryOp( + op='^', + left=Variable(name='x'), + right=Number(value=2) + ) + ), + var='x' + ), + + # d/dx(cos(x) * sin(x)) + Derivative( + expr=BinaryOp( + op='*', + left=Function(name='cos', arg=Variable(name='x')), + right=Function(name='sin', arg=Variable(name='x')) + ), + var='x' + ), + + # d/dx(ln(x^2 + 1)) + Derivative( + expr=Function( + name='ln', + arg=BinaryOp( + op='+', + left=BinaryOp( + op='^', + left=Variable(name='x'), + right=Number(value=2) + ), + right=Number(value=1) + ) + ), + var='x' + ) + ] + + # Create a chain of passes: + # 1. Differentiate the expression + # 2. Apply simplification rules repeatedly until no more changes + symbolic_diff_pass = Chain( + Post(SymbolicDifferentiator()), # First apply differentiation rules + FixedPoint(Post(ExpressionSimplifier())), # Apply simplification rules until no changes + ) + + # Run differentiation and simplification + print("Symbolic Differentiation Examples:") + print("=" * 50) + for i, expr in enumerate(test_cases, 1): + print(f"\nTest Case {i}:") + print(f"Expression: {print_expr(expr)}") + result = symbolic_diff_pass(expr) + print(f"Derivative: {print_expr(result)}") + print("-" * 30) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/test_simple_pass.py b/tests/examples/test_simple_pass.py new file mode 100644 index 0000000..f56bcec --- /dev/null +++ b/tests/examples/test_simple_pass.py @@ -0,0 +1,60 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from simple_pass import Number, BinaryOp, ConstantFoldingPass +from oqd_compiler_infrastructure.rewriter import Chain, FixedPoint + +def test_basic_folding(): + """Run basic test cases for the ConstantFoldingPass.""" + # Test case 1: Simple addition (2 + 3) + expr1 = BinaryOp(op='+', left=Number(value=2), right=Number(value=3)) + + # Test case 2: Nested expression ((2 + 3) * (4 + 5)) + expr2 = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=Number(value=3)), + right=BinaryOp(op='+', left=Number(value=4), right=Number(value=5)) + ) + + # Test case 3: Mixed constants and variables ((2 + x) * 3) + expr3 = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=BinaryOp(op='+', left=Number(value=1), right=Number(value=2))), + right=Number(value=3) + ) + + # Create our pass + pass_ = ConstantFoldingPass() + + # Run the pass on each expression + result1 = pass_(expr1) + result2 = pass_(expr2) + result3 = pass_(expr3) + + print("Test Case 1:") + print(f"Input: {expr1}") + print(f"Output: {result1}") + print() + + print("Test Case 2:") + print(f"Input: {expr2}") + print(f"Output: {result2}") + print() + + print("Test Case 3:") + print(f"Input: {expr3}") + print(f"Output: {result3}") + +if __name__ == "__main__": + test_basic_folding() \ No newline at end of file diff --git a/tests/examples/trig_simplifier.py b/tests/examples/trig_simplifier.py new file mode 100644 index 0000000..e9ba58d --- /dev/null +++ b/tests/examples/trig_simplifier.py @@ -0,0 +1,342 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import RewriteRule +from oqd_compiler_infrastructure.walk import Post + +# Use the same Expression classes from symbolic_diff.py +# AST data structures +class Expression(TypeReflectBaseModel): + """Base class for expressions in trigonometric simplification. + + This class serves as the foundation for all expression types in the AST. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class Variable(Expression): + """Represents a variable in an expression. + + Attributes: + name (str): The name of the variable. + """ + name: str + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator for the operation (e.g., '+', '-', '*', '/', '^'). + left (Expression): The left operand. + right (Expression): The right operand. + """ + op: str # '+', '-', '*', '/', '^' + +class Function(Expression): + """Represents a mathematical function. + + Attributes: + name (str): The name of the function (e.g., 'sin', 'cos', 'exp', 'ln'). + arg (Expression): The argument of the function. + """ + name: str # 'sin', 'cos', 'exp', 'ln' + +class Derivative(Expression): + """Represents a derivative operation. + + Attributes: + expr (Expression): The expression to differentiate. + var (str): The variable with respect to which to differentiate. + """ + expr: Expression + var: str # variable to differentiate with respect to + +class TrigSimplifier(RewriteRule): + """Implements trigonometric simplification rules. + + Applies various identities and transformations for trigonometric expressions: + - sin^2(x) + cos^2(x) = 1 + - tan(x) = sin(x)/cos(x) + - sin(-x) = -sin(x) + - cos(-x) = cos(x) + - sin(x)^2 = (1 - cos(2x))/2 (if not part of sin²(x) + cos²(x)) + - cos(x)^2 = (1 + cos(2x))/2 (if not part of sin²(x) + cos²(x)) + """ + + def walk(self, model): + """Override walk to apply Pythagorean identity before traversal. + + Args: + model (Expression): The expression to process. + + Returns: + Expression: Simplified expression if identity matches, otherwise traversal result. + """ + # If pattern matches sin^2(x)+cos^2(x), simplify to 1 immediately + if isinstance(model, BinaryOp) and self._is_pythagorean_identity(model): + return Number(value=1) + return super().walk(model) + + def map_BinaryOp(self, model): + """Apply double-angle transformations for trigonometric power expressions. + + Args: + model (BinaryOp): The binary operation model. + + Returns: + Expression: Transformed expression or original if no rule applies. + """ + # Double-angle: sin^2(x) -> (1 - cos(2x))/2 + if self._is_trig_power(model, 'sin', 2): + arg = self._get_trig_arg(model) + double_arg = BinaryOp(op='*', left=Number(value=2), right=arg) + return BinaryOp( + op='/', + left=BinaryOp( + op='-', + left=Number(value=1), + right=Function(name='cos', arg=double_arg) + ), + right=Number(value=2) + ) + + # Double-angle: cos^2(x) -> (1 + cos(2x))/2 + if self._is_trig_power(model, 'cos', 2): + arg = self._get_trig_arg(model) + double_arg = BinaryOp(op='*', left=Number(value=2), right=arg) + return BinaryOp( + op='/', + left=BinaryOp( + op='+', + left=Number(value=1), + right=Function(name='cos', arg=double_arg) + ), + right=Number(value=2) + ) + + return model + + def _is_pythagorean_identity(self, model): + """Check if expression matches sin^2(x) + cos^2(x) identity. + + Args: + model (BinaryOp): The binary operation model to check. + + Returns: + bool: True if the pattern matches, False otherwise. + """ + # Ensure operator is addition, then check both orderings of sin^2+cos^2 + if model.op != '+': + return False + + # Check both orderings: sin^2(x) + cos^2(x) and cos^2(x) + sin^2(x) + return ( + (self._is_trig_power(model.left, 'sin', 2) and + self._is_trig_power(model.right, 'cos', 2) and + self._args_equal(self._get_trig_arg(model.left), self._get_trig_arg(model.right))) + or + (self._is_trig_power(model.left, 'cos', 2) and + self._is_trig_power(model.right, 'sin', 2) and + self._args_equal(self._get_trig_arg(model.left), self._get_trig_arg(model.right))) + ) + + def map_Function(self, model): + """Simplify function expressions based on sign identities. + + Args: + model (Function): The function model to simplify. + + Returns: + Expression: Simplified expression or original if no rule applies. + """ + # Identity: sin(-x) => -sin(x) + if (model.name == 'sin' and + isinstance(model.arg, BinaryOp) and + model.arg.op == '*' and + isinstance(model.arg.left, Number) and + model.arg.left.value == -1): + return BinaryOp( + op='*', + left=Number(value=-1), + right=Function(name='sin', arg=model.arg.right) + ) + + # Identity: cos(-x) => cos(x) + if (model.name == 'cos' and + isinstance(model.arg, BinaryOp) and + model.arg.op == '*' and + isinstance(model.arg.left, Number) and + model.arg.left.value == -1): + return Function(name='cos', arg=model.arg.right) + + return model + + def _is_trig_power(self, expr, func_name, power): + """Check if expression is a trigonometric function raised to a specified power. + + Args: + expr (Expression): The expression to check. + func_name (str): Name of the trigonometric function. + power (int): The exponent power. + + Returns: + bool: True if the expression matches the pattern. + """ + # Match pattern: Function ^ number(power) + return (isinstance(expr, BinaryOp) and + expr.op == '^' and + isinstance(expr.left, Function) and + expr.left.name == func_name and + isinstance(expr.right, Number) and + expr.right.value == power) + + def _get_trig_arg(self, expr): + """Extract the argument from a trigonometric power expression. + + Args: + expr (BinaryOp): The trigonometric power expression. + + Returns: + Expression: The argument of the function. + """ + # The function node is expr.left; return its argument + return expr.left.arg + + def _args_equal(self, arg1, arg2): + """Check if two expression arguments are structurally equal. + + Args: + arg1 (Expression): The first argument to compare. + arg2 (Expression): The second argument to compare. + + Returns: + bool: True if the arguments are equal, False otherwise. + """ + # Perform recursive structural comparison of arguments + if type(arg1) != type(arg2): + return False + + if isinstance(arg1, Number): + return arg1.value == arg2.value + elif isinstance(arg1, Variable): + return arg1.name == arg2.name + elif isinstance(arg1, BinaryOp): + return (arg1.op == arg2.op and + self._args_equal(arg1.left, arg2.left) and + self._args_equal(arg1.right, arg2.right)) + elif isinstance(arg1, Function): + return (arg1.name == arg2.name and + self._args_equal(arg1.arg, arg2.arg)) + + return False + +def print_expr(expr): + """Convert an expression to a human-readable string. + + Args: + expr (Expression): The expression to print. + + Returns: + str: A readable string representation of the expression. + """ + # Pretty-print literals, binary ops, functions, and derivatives + if isinstance(expr, Number): + return str(expr.value) + elif isinstance(expr, Variable): + return expr.name + elif isinstance(expr, BinaryOp): + return f"({print_expr(expr.left)} {expr.op} {print_expr(expr.right)})" + elif isinstance(expr, Function): + return f"{expr.name}({print_expr(expr.arg)})" + return str(expr) + +def main(): + """Main function to demonstrate trigonometric simplification examples.""" + # Prepare test cases and apply simplifier to each + test_cases = [ + # sin^2(x) + cos^2(x) = 1 + BinaryOp( + op='+', + left=BinaryOp( + op='^', + left=Function(name='sin', arg=Variable(name='x')), + right=Number(value=2) + ), + right=BinaryOp( + op='^', + left=Function(name='cos', arg=Variable(name='x')), + right=Number(value=2) + ) + ), + + # cos^2(x) + sin^2(x) = 1 (reverse order) + BinaryOp( + op='+', + left=BinaryOp( + op='^', + left=Function(name='cos', arg=Variable(name='x')), + right=Number(value=2) + ), + right=BinaryOp( + op='^', + left=Function(name='sin', arg=Variable(name='x')), + right=Number(value=2) + ) + ), + + # sin^2(x) alone should use double angle formula + BinaryOp( + op='^', + left=Function(name='sin', arg=Variable(name='x')), + right=Number(value=2) + ), + + # sin(-x) = -sin(x) + Function( + name='sin', + arg=BinaryOp( + op='*', + left=Number(value=-1), + right=Variable(name='x') + ) + ) + ] + + # Create simplifier that checks for Pythagorean identity first + simplifier = TrigSimplifier() + + # Run simplifications + print("Trigonometric Simplification Examples:") + for i, expr in enumerate(test_cases, 1): + print(f"\nTest Case {i}:") + print(f"Original: {print_expr(expr)}") + # First try to match Pythagorean identity + if isinstance(expr, BinaryOp) and simplifier._is_pythagorean_identity(expr): + result = Number(value=1) + else: + # If not Pythagorean identity, apply other transformations + result = Post(simplifier)(expr) + print(f"Simplified: {print_expr(result)}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/examples/walk_example.py b/tests/examples/walk_example.py new file mode 100644 index 0000000..0d11299 --- /dev/null +++ b/tests/examples/walk_example.py @@ -0,0 +1,180 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from oqd_compiler_infrastructure.interface import TypeReflectBaseModel +from oqd_compiler_infrastructure.rule import ConversionRule +from oqd_compiler_infrastructure.walk import Pre, Post, Level + +# AST data structures +class Expression(TypeReflectBaseModel): + """Base class for arithmetic expressions. + + This class serves as the foundation for all expression types in the abstract syntax tree (AST). + It inherits from TypeReflectBaseModel to provide type reflection capabilities. + """ + pass + +class Number(Expression): + """Represents a numeric literal. + + Attributes: + value (float): The numeric value of the literal. + """ + value: float + +class BinaryOp(Expression): + """Represents a binary operation. + + Attributes: + op (str): The operator for the binary operation (e.g., '+', '-', '*', '/'). + left (Expression): The left operand of the binary operation. + right (Expression): The right operand of the binary operation. + """ + op: str # '+', '-', '*', '/' + left: Expression + right: Expression + +# Simpler constant folding using Post traversal +class SimpleConstantFolder(ConversionRule): + """A simpler constant folding pass that relies on Post traversal. + + This class implements a constant folding optimization that evaluates binary operations + with constant operands and replaces them with their computed value. + + Methods: + map_Number(model): Returns the number as it is, since it's already folded. + map_BinaryOp(model, operands): Evaluates the binary operation if both operands are numbers. + """ + + def map_Number(self, model, operands=None): + """Returns the number as it is. + + Args: + model (Number): The number model to be returned. + + Returns: + Number: The same number model. + """ + return model + + def map_BinaryOp(self, model, operands=None): + """Evaluates the binary operation if both operands are numbers. + + Args: + model (BinaryOp): The binary operation model to be evaluated. + operands (dict, optional): The processed operands from the traversal. + + Returns: + Number or BinaryOp: The result of the operation if both operands are numbers, + otherwise returns a new BinaryOp with processed children. + """ + # Called after children are processed; 'operands' contains folded results + print(f"Visiting BinaryOp({model.op})") + print(f"Operands: {operands}") + + if operands: + # Extract folded children from operands dict + left = operands['left'] + right = operands['right'] + + # If both children are numeric, compute and return new Number + if isinstance(left, Number) and isinstance(right, Number): + if model.op == '+': + return Number(value=left.value + right.value) + elif model.op == '*': + return Number(value=left.value * right.value) + elif model.op == '-': + return Number(value=left.value - right.value) + elif model.op == '/': + if right.value != 0: + return Number(value=left.value / right.value) + + return BinaryOp(op=model.op, left=left, right=right) + # Fallback: return original model if no operands + return model + +# Debug printer to show traversal order +class DebugPrinter(ConversionRule): + """Prints nodes as they're visited to demonstrate traversal order. + + This class is used for debugging purposes to visualize the order in which nodes + are visited during the traversal of the AST. + + Attributes: + prefix (str): A prefix string to format the output. + """ + + def __init__(self, prefix=""): + super().__init__() + self.prefix = prefix + + def map_Number(self, model, operands=None): + """Prints the number being visited. + + Args: + model (Number): The number model being visited. + + Returns: + Number: The same number model. + """ + # Log visiting a Number node + print(f"{self.prefix}Visiting Number({model.value})") + return model + + def map_BinaryOp(self, model, operands=None): + """Prints the binary operation being visited. + + Args: + model (BinaryOp): The binary operation model being visited. + + Returns: + BinaryOp: The same binary operation model. + """ + # Log visiting a BinaryOp node + print(f"{self.prefix}Visiting BinaryOp({model.op})") + return model + +def main(): + """Main function to demonstrate the functionality of the AST and traversal methods.""" + # Build a sample expression tree: (2+3)*(4/2) + expr = BinaryOp( + op='*', + left=BinaryOp(op='+', left=Number(value=2), right=Number(value=3)), + right=BinaryOp(op='/', left=Number(value=4), right=Number(value=2)) + ) + + # Pre-order: root first, then children + print("Pre-order traversal (top-down):") + pre_printer = Pre(DebugPrinter(prefix=" ")) + pre_printer(expr) + + # Post-order: children first, then root + print("\nPost-order traversal (bottom-up):") + post_printer = Post(DebugPrinter(prefix=" ")) + post_printer(expr) + + # Level-order: breadth-first traversal + print("\nLevel-order traversal (breadth-first):") + level_printer = Level(DebugPrinter(prefix=" ")) + level_printer(expr) + + # Apply post-order constant folding + print("\nConstant folding with Post traversal:") + folder = Post(SimpleConstantFolder()) + result = folder(expr) + print(f"Result: {result}") + +if __name__ == "__main__": + main() \ No newline at end of file