Skip to content

Commit 2ae1ba1

Browse files
committed
add Logic Types
1 parent 6d25e2a commit 2ae1ba1

File tree

3 files changed

+43
-2
lines changed

3 files changed

+43
-2
lines changed

src/finchlite/finch_logic/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,19 @@
2020
TableValueFType,
2121
Value,
2222
)
23+
from .stages import LogicEvaluator, LogicLowerer, LogicTransform
2324

2425
__all__ = [
2526
"Aggregate",
2627
"Alias",
2728
"Field",
2829
"FinchLogicInterpreter",
2930
"Literal",
31+
"LogicEvaluator",
3032
"LogicExpression",
33+
"LogicLowerer",
3134
"LogicNode",
35+
"LogicTransform",
3236
"LogicTree",
3337
"MapJoin",
3438
"Plan",

src/finchlite/finch_logic/interpreter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,18 @@
1919
TableValue,
2020
Value,
2121
)
22+
from .stages import LogicEvaluator
2223

2324

24-
class FinchLogicInterpreter:
25+
class FinchLogicInterpreter(LogicEvaluator):
2526
def __init__(self, *, make_tensor=np.full, verbose=False):
2627
self.verbose = verbose
2728
self.bindings = {}
2829
self.make_tensor = make_tensor # Added make_tensor argument
2930

30-
def __call__(self, node):
31+
def __call__(self, node, bindings=None):
32+
if bindings is not None:
33+
self.bindings = bindings.copy()
3134
# Example implementation for evaluating an expression
3235
if self.verbose:
3336
print(f"Evaluating: {node}")
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from abc import ABC, abstractmethod
2+
3+
from finchlite.algebra.tensor import TensorFType
4+
5+
from .. import finch_notation as ntn
6+
from ..symbolic import Stage
7+
from . import nodes as lgc
8+
9+
10+
class LogicEvaluator(Stage):
11+
@abstractmethod
12+
def __call__(self, term: lgc.LogicNode, bindings: dict[lgc.Alias, lgc.TableValue]|None=None) -> lgc.TableValue | tuple[lgc.TableValue]:
13+
"""
14+
Evaluate the given logic.
15+
"""
16+
17+
18+
class LogicLowerer(ABC):
19+
@abstractmethod
20+
def __call__(
21+
self, term: lgc.LogicNode, bindings: dict[lgc.Alias, TensorFType]
22+
) -> tuple[ntn.Module, dict[lgc.Alias, TensorFType]]:
23+
"""
24+
Generate Finch Notation from the given logic and input types,
25+
types for all aliases.
26+
"""
27+
28+
29+
class LogicTransform(ABC):
30+
@abstractmethod
31+
def __call__(self, term: lgc.LogicNode, bindings: dict[lgc.Alias, TensorFType]) -> tuple[lgc.LogicNode, dict[lgc.Alias, TensorFType]]:
32+
"""
33+
Transform the given logic term into another logic term.
34+
"""

0 commit comments

Comments
 (0)