diff --git a/pydough/__init__.py b/pydough/__init__.py index 68d9f7618..d38291255 100644 --- a/pydough/__init__.py +++ b/pydough/__init__.py @@ -12,6 +12,7 @@ "get_logger", "init_pydough_context", "parse_json_metadata_from_file", + "range_collection", "to_df", "to_sql", ] @@ -22,6 +23,7 @@ from .logger import get_logger from .metadata import parse_json_metadata_from_file from .unqualified import display_raw, from_string, init_pydough_context +from .user_collections.user_collection_apis import range_collection # Create a default session for the user to interact with. # In most situations users will just use this session and diff --git a/pydough/conversion/agg_removal.py b/pydough/conversion/agg_removal.py index 49c0ee14e..864be2ee0 100644 --- a/pydough/conversion/agg_removal.py +++ b/pydough/conversion/agg_removal.py @@ -12,6 +12,7 @@ CallExpression, EmptySingleton, Filter, + GeneratedTable, Join, JoinType, Limit, @@ -276,7 +277,7 @@ def aggregation_uniqueness_helper( ) return node, final_uniqueness # Empty singletons don't have uniqueness information. - case EmptySingleton(): + case EmptySingleton() | GeneratedTable(): return node, set() case _: raise NotImplementedError( diff --git a/pydough/conversion/filter_pushdown.py b/pydough/conversion/filter_pushdown.py index d50d9da5d..8fb6d63cc 100644 --- a/pydough/conversion/filter_pushdown.py +++ b/pydough/conversion/filter_pushdown.py @@ -10,6 +10,7 @@ ColumnReference, EmptySingleton, Filter, + GeneratedTable, Join, JoinType, Limit, @@ -143,7 +144,7 @@ def push_filters( # be transposed beneath a limit without changing its output. node._input = push_filters(node.input, set()) return build_filter(node, filters) - case EmptySingleton() | Scan(): + case EmptySingleton() | Scan() | GeneratedTable(): # For remaining nodes, materialize all of the remaining filters. 
return build_filter(node, filters) case _: diff --git a/pydough/conversion/hybrid_operations.py b/pydough/conversion/hybrid_operations.py index 620c7f9bb..de62a23ff 100644 --- a/pydough/conversion/hybrid_operations.py +++ b/pydough/conversion/hybrid_operations.py @@ -16,6 +16,7 @@ "HybridPartition", "HybridPartitionChild", "HybridRoot", + "HybridUserGeneratedCollection", ] @@ -27,6 +28,9 @@ ColumnProperty, PyDoughExpressionQDAG, ) +from pydough.qdag.collections.user_collection_qdag import ( + PyDoughUserGeneratedCollectionQDag, +) from .hybrid_connection import HybridConnection from .hybrid_expressions import ( @@ -483,3 +487,34 @@ def __repr__(self): def search_term_definition(self, name: str) -> HybridExpr | None: return self.predecessor.search_term_definition(name) + + +class HybridUserGeneratedCollection(HybridOperation): + """ + Class for HybridOperation corresponding to a user-generated collection. + """ + + def __init__(self, user_collection: PyDoughUserGeneratedCollectionQDag): + """ + Args: + `user_collection`: the QDAG node for the user-generated collection. + """ + self._user_collection: PyDoughUserGeneratedCollectionQDag = user_collection + terms: dict[str, HybridExpr] = {} + for name, typ in user_collection.collection.column_names_and_types: + terms[name] = HybridRefExpr(name, typ) + unique_exprs: list[HybridExpr] = [] + for name in sorted(self.user_collection.unique_terms, key=str): + expr: PyDoughExpressionQDAG = self.user_collection.get_expr(name) + unique_exprs.append(HybridRefExpr(name, expr.pydough_type)) + super().__init__(terms, {}, [], unique_exprs) + + @property + def user_collection(self) -> PyDoughUserGeneratedCollectionQDag: + """ + The user-generated collection that this hybrid operation represents. 
+ """ + return self._user_collection + + def __repr__(self): + return f"USER_GEN_COLLECTION[{self.user_collection.name}]" diff --git a/pydough/conversion/hybrid_translator.py b/pydough/conversion/hybrid_translator.py index b376ce944..42a92e2c6 100644 --- a/pydough/conversion/hybrid_translator.py +++ b/pydough/conversion/hybrid_translator.py @@ -42,6 +42,9 @@ Where, WindowCall, ) +from pydough.qdag.collections.user_collection_qdag import ( + PyDoughUserGeneratedCollectionQDag, +) from pydough.types import BooleanType, NumericType from .hybrid_connection import ConnectionType, HybridConnection @@ -68,6 +71,7 @@ HybridPartition, HybridPartitionChild, HybridRoot, + HybridUserGeneratedCollection, ) from .hybrid_syncretizer import HybridSyncretizer from .hybrid_tree import HybridTree @@ -1339,6 +1343,9 @@ def define_root_link( case HybridRoot(): # A root does not need to be joined to its parent join_keys = [] + case HybridUserGeneratedCollection(): + # A user-generated collection does not need to be joined to its parent + join_keys = [] case _: raise NotImplementedError(f"{operation.__class__.__name__}") if join_keys is not None: @@ -1557,6 +1564,18 @@ def make_hybrid_tree( HybridLimit(hybrid.pipeline[-1], node.records_to_keep) ) return hybrid + case PyDoughUserGeneratedCollectionQDag(): + # A user-generated collection is a special case of a collection + # access that is not a sub-collection, but rather a user-defined + # collection that is defined in the PyDough user collections. + hybrid_collection = HybridUserGeneratedCollection(node) + # Create a new hybrid tree for the user-generated collection. 
+ successor_hybrid = HybridTree(hybrid_collection, node.ancestral_mapping) + hybrid = self.make_hybrid_tree( + node.ancestor_context, parent, is_aggregate + ) + hybrid.add_successor(successor_hybrid) + return successor_hybrid case ChildOperatorChildAccess(): assert parent is not None match node.child_access: @@ -1624,6 +1643,17 @@ def make_hybrid_tree( successor_hybrid = HybridTree( HybridRoot(), node.ancestral_mapping ) + case PyDoughUserGeneratedCollectionQDag(): + # A user-generated collection is a special case of a collection + # access that is not a sub-collection, but rather a user-defined + # collection that is defined in the PyDough user collections. + hybrid_collection = HybridUserGeneratedCollection( + node.child_access + ) + # Create a new hybrid tree for the user-generated collection. + successor_hybrid = HybridTree( + hybrid_collection, node.ancestral_mapping + ) case _: raise NotImplementedError( f"{node.__class__.__name__} (child is {node.child_access.__class__.__name__})" diff --git a/pydough/conversion/hybrid_tree.py b/pydough/conversion/hybrid_tree.py index 9a52f3e50..d0263538a 100644 --- a/pydough/conversion/hybrid_tree.py +++ b/pydough/conversion/hybrid_tree.py @@ -46,6 +46,7 @@ HybridPartition, HybridPartitionChild, HybridRoot, + HybridUserGeneratedCollection, ) @@ -676,6 +677,8 @@ def always_exists(self) -> bool: # Stepping into a partition child always has a matching data # record for each parent, by definition. 
pass + case HybridUserGeneratedCollection(): + return start_operation.user_collection.collection.always_exists() case _: raise NotImplementedError( f"Invalid start of pipeline: {start_operation.__class__.__name__}" @@ -726,6 +729,8 @@ def is_singular(self) -> bool: case HybridChildPullUp(): if not self.children[self.pipeline[0].child_idx].subtree.is_singular(): return False + case HybridUserGeneratedCollection(): + return self.pipeline[0].user_collection.collection.is_singular() case _: return False # The current level is fine, so check any levels above it next. diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 956bb3698..90f3f7d87 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -36,6 +36,7 @@ EmptySingleton, ExpressionSortInfo, Filter, + GeneratedTable, Join, JoinCardinality, JoinType, @@ -49,6 +50,7 @@ WindowCallExpression, ) from pydough.types import BooleanType, NumericType, UnknownType +from pydough.types.pydough_type import PyDoughType from .agg_removal import remove_redundant_aggs from .agg_split import split_partial_aggregates @@ -79,6 +81,7 @@ HybridPartition, HybridPartitionChild, HybridRoot, + HybridUserGeneratedCollection, ) from .hybrid_translator import HybridTranslator from .hybrid_tree import HybridTree @@ -1166,6 +1169,29 @@ def translate_hybridroot(self, context: TranslationOutput) -> TranslationOutput: new_expressions[shifted_expr] = column_ref return TranslationOutput(context.relational_node, new_expressions) + def build_user_generated_table( + self, node: HybridUserGeneratedCollection + ) -> TranslationOutput: + """Builds a user-generated table from the given hybrid user-generated collection. + + Args: + `node`: The user-generated collection node to translate. + + Returns: + The translated output payload. 
+ """ + collection = node._user_collection.collection + out_columns: dict[HybridExpr, ColumnReference] = {} + gen_columns: dict[str, RelationalExpression] = {} + for column_name, column_type in collection.column_names_and_types: + hybrid_ref = HybridRefExpr(column_name, column_type) + col_ref = ColumnReference(column_name, column_type) + out_columns[hybrid_ref] = col_ref + gen_columns[column_name] = col_ref + + answer = GeneratedTable(collection) + return TranslationOutput(answer, out_columns) + def rel_translation( self, hybrid: HybridTree, @@ -1289,6 +1315,18 @@ def rel_translation( case HybridRoot(): assert context is not None, "Malformed HybridTree pattern." result = self.translate_hybridroot(context) + case HybridUserGeneratedCollection(): + assert context is not None, "Malformed HybridTree pattern." + result = self.build_user_generated_table(operation) + result = self.join_outputs( + context, + result, + JoinType.INNER, + JoinCardinality.PLURAL_ACCESS, + [], + None, + None, + ) case _: raise NotImplementedError( f"TODO: support relational conversion on {operation.__class__.__name__}" @@ -1304,16 +1342,27 @@ def preprocess_root( """ Transforms the final PyDough collection by appending it with an extra CALCULATE containing all of the columns that are output. + Args: + `node`: the PyDough QDAG collection node to be translated. + `output_cols`: a list of tuples in the form `(alias, column)` + describing every column that should be in the output, in the order + they should appear, and the alias they should be given. If None, uses + the most recent CALCULATE in the node to determine the columns. + Returns: + The PyDoughCollectionQDAG with an additional CALCULATE at the end + that contains all of the columns that should be in the output. 
""" # Fetch all of the expressions that should be kept in the final output final_terms: list[tuple[str, PyDoughExpressionQDAG]] = [] if output_cols is None: for name in node.calc_terms: - final_terms.append((name, Reference(node, name))) + name_typ: PyDoughType = node.get_expr(name).pydough_type + final_terms.append((name, Reference(node, name, name_typ))) final_terms.sort(key=lambda term: node.get_expression_position(term[0])) else: for _, column in output_cols: - final_terms.append((column, Reference(node, column))) + column_typ: PyDoughType = node.get_expr(column).pydough_type + final_terms.append((column, Reference(node, column, column_typ))) children: list[PyDoughCollectionQDAG] = [] final_calc: Calculate = Calculate(node, children).with_terms(final_terms) return final_calc diff --git a/pydough/qdag/README.md b/pydough/qdag/README.md index 4284e4b1a..73bd1d180 100644 --- a/pydough/qdag/README.md +++ b/pydough/qdag/README.md @@ -65,7 +65,9 @@ table_collection = builder.build_child_access("Nations", global_context_node) # Build a reference node # Equivalent PyDough code: `TPCH.Nations.name` -reference_node = builder.build_reference(table_collection, "name") +ref_name = "name" +pydough_type = table_collection.get_expr(ref_name).pydough_type +reference_node = builder.build_reference(table_collection, ref_name, pydough_type) # Build an expression function call node # Equivalent PyDough code: `LOWER(TPCH.Nations.name)` @@ -99,7 +101,10 @@ regions_collection = builder.build_child_access("Regions", global_context_node) # Access nations sub-collection nations_sub_collection = builder.build_child_access("nations", regions_collection) # Create WHERE(key == 4) condition -key_ref = builder.build_reference(nations_sub_collection, "key") + +ref_name = "key" +pydough_type = nations_sub_collection.get_expr(ref_name).pydough_type +key_ref = builder.build_reference(nations_sub_collection, ref_name, pydough_type) literal_4 = builder.build_literal(4, NumericType()) condition = 
builder.build_expression_function_call("EQU", [key_ref, literal_4]) # Build WHERE node with condition @@ -108,7 +113,9 @@ where_node = where_node.with_condition(condition) # Create SINGULAR node from filtered result singular_node = builder.build_singular(where_node) # Build reference node for name -reference_node = builder.build_reference(singular_node, "name") +ref_name = "name" +pydough_type = singular_node.get_expr(ref_name).pydough_type +reference_node = builder.build_reference(singular_node, ref_name, pydough_type) # Build CALCULATE node with calculated term calculate_node = builder.build_calc(regions_collection, [nations_sub_collection]) calculate_node = calculate_node.with_terms([("n_4_nation", reference_node)]) @@ -130,7 +137,9 @@ top_k_node = top_k_node.with_collation([collation_expression]) # Build a PARTITION BY node # Equivalent PyDough code: `TPCH.PARTITION(Parts, name="p", by=part_type)` part_collection = builder.build_child_access("Parts", global_context_node) -partition_key = builder.build_reference(part_collection, "part_type") +ref_name = "part_type" +pydough_type = part_collection.get_expr(ref_name).pydough_type +partition_key = builder.build_reference(part_collection, ref_name, pydough_type) partition_by_node = builder.build_partition(part_collection, child_collection, "p") partition_by_node = partition_by_node.with_keys([partition_key]) diff --git a/pydough/qdag/collections/__init__.py b/pydough/qdag/collections/__init__.py index 6eaab3bd8..e9c28db18 100644 --- a/pydough/qdag/collections/__init__.py +++ b/pydough/qdag/collections/__init__.py @@ -21,6 +21,7 @@ "TableCollection", "TopK", "Where", + "range_collection", ] from .augmenting_child_operator import AugmentingChildOperator diff --git a/pydough/qdag/collections/augmenting_child_operator.py b/pydough/qdag/collections/augmenting_child_operator.py index c5f783026..f8ff089a4 100644 --- a/pydough/qdag/collections/augmenting_child_operator.py +++ 
b/pydough/qdag/collections/augmenting_child_operator.py @@ -84,7 +84,8 @@ def get_term(self, term_name: str) -> PyDoughQDAG: if isinstance(term, ChildAccess): term = term.clone_with_parent(self) elif isinstance(term, PyDoughExpressionQDAG): - term = Reference(self.preceding_context, term_name) + typ = self.preceding_context.get_expr(term_name).pydough_type + term = Reference(self.preceding_context, term_name, typ) return term @cache diff --git a/pydough/qdag/collections/collection_access.py b/pydough/qdag/collections/collection_access.py index 80ac8f40c..1b5907543 100644 --- a/pydough/qdag/collections/collection_access.py +++ b/pydough/qdag/collections/collection_access.py @@ -129,7 +129,9 @@ def get_term(self, term_name: str) -> PyDoughQDAG: else: assert context.ancestor_context is not None context = context.ancestor_context - return Reference(context, term_name) + return Reference( + context, term_name, context.get_expr(term_name).pydough_type + ) if term_name not in self.all_terms: raise PyDoughQDAGException(self.name_mismatch_error(term_name)) diff --git a/pydough/qdag/collections/partition_child.py b/pydough/qdag/collections/partition_child.py index e0609e658..4442591ae 100644 --- a/pydough/qdag/collections/partition_child.py +++ b/pydough/qdag/collections/partition_child.py @@ -102,7 +102,9 @@ def get_term(self, term_name: str): else: assert context.ancestor_context is not None context = context.ancestor_context - return Reference(context, term_name) + return Reference( + context, term_name, context.get_expr(term_name).pydough_type + ) elif term_name not in self.all_terms: raise PyDoughQDAGException(self.name_mismatch_error(term_name)) diff --git a/pydough/qdag/collections/user_collection_qdag.py b/pydough/qdag/collections/user_collection_qdag.py new file mode 100644 index 000000000..f53ffb428 --- /dev/null +++ b/pydough/qdag/collections/user_collection_qdag.py @@ -0,0 +1,136 @@ +from functools import cache + +from pydough.qdag import PyDoughCollectionQDAG 
+from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG +from pydough.qdag.errors import PyDoughQDAGException +from pydough.qdag.expressions.back_reference_expression import BackReferenceExpression +from pydough.qdag.expressions.reference import Reference +from pydough.types import NumericType +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection + +from .child_access import ChildAccess + + +class PyDoughUserGeneratedCollectionQDag(ChildAccess): + def __init__( + self, + ancestor: PyDoughCollectionQDAG, + collection: PyDoughUserGeneratedCollection, + ): + assert ancestor is not None + super().__init__(ancestor) + self._collection = collection + self._ancestral_mapping: dict[str, int] = { + name: level + 1 for name, level in ancestor.ancestral_mapping.items() + } + + def clone_with_parent( + self, new_ancestor: PyDoughCollectionQDAG + ) -> "PyDoughUserGeneratedCollectionQDag": + """ + Copies `self` but with a new ancestor node that presumably has the + original ancestor in its predecessor chain. + + Args: + `new_ancestor`: the node to use as the new parent of the clone. + + Returns: + The cloned version of `self`. + """ + return PyDoughUserGeneratedCollectionQDag(new_ancestor, self._collection) + + @property + def collection(self) -> PyDoughUserGeneratedCollection: + """ + The user-generated collection that is being referenced by this QDAG + node. 
+ """ + return self._collection + + @property + def name(self) -> str: + return self.collection.name + + @property + def calc_terms(self) -> set[str]: + return set(self.collection.columns) + + @property + def ancestral_mapping(self) -> dict[str, int]: + return self._ancestral_mapping + + @property + def inherited_downstreamed_terms(self) -> set[str]: + return self.ancestor_context.inherited_downstreamed_terms + + @cache + def get_term(self, term_name: str) -> PyDoughQDAG: + # Special handling of terms down-streamed + if term_name in self.ancestral_mapping: + # Verify that the ancestor name is not also a name in the current + # context. + if term_name in self.calc_terms: + raise PyDoughQDAGException( + f"Cannot have term name {term_name!r} used in an ancestor of collection {self!r}" + ) + # Create a back-reference to the ancestor term. + return BackReferenceExpression( + self, term_name, self.ancestral_mapping[term_name] + ) + + if term_name in self.inherited_downstreamed_terms: + context: PyDoughCollectionQDAG = self + while term_name not in context.all_terms: + if context is self: + context = self.ancestor_context + else: + assert context.ancestor_context is not None + context = context.ancestor_context + return Reference( + context, term_name, context.get_expr(term_name).pydough_type + ) + + if term_name not in self.all_terms: + raise PyDoughQDAGException(self.name_mismatch_error(term_name)) + + return Reference(self, term_name, NumericType()) + + @property + def all_terms(self) -> set[str]: + """ + The set of expression/subcollection names accessible by the context. 
+ """ + return self.calc_terms + + def is_singular(self, context: "PyDoughCollectionQDAG") -> bool: + return False + + def get_expression_position(self, expr_name: str) -> int: + if expr_name not in self.calc_terms: + raise PyDoughQDAGException( + f"Unrecognized User Collection term: {expr_name!r}" + ) + return self.collection.get_expression_position(expr_name) + + @property + def unique_terms(self) -> list[str]: + return self.collection.columns + + @property + def standalone_string(self) -> str: + """ + Returns a string representation of the collection in a standalone form. + This is used for debugging and logging purposes. + """ + return self.to_string() + + @property + def key(self) -> str: + return f"USER_GENERATED_COLLECTION-{self.name}" + + def to_string(self) -> str: + return self.collection.to_string() + + @property + def tree_item_string(self) -> str: + return self.to_string() diff --git a/pydough/qdag/expressions/back_reference_expression.py b/pydough/qdag/expressions/back_reference_expression.py index 70498908c..6c21c8317 100644 --- a/pydough/qdag/expressions/back_reference_expression.py +++ b/pydough/qdag/expressions/back_reference_expression.py @@ -38,6 +38,14 @@ def __init__( ) self._ancestor = ancestor self._expression = self._ancestor.get_expr(term_name) + self._term_type = self._expression.pydough_type + + @property + def expression(self) -> PyDoughExpressionQDAG: + """ + The expression that the BackReferenceExpression refers to. 
+ """ + return self._expression @property def back_levels(self) -> int: diff --git a/pydough/qdag/expressions/child_reference_expression.py b/pydough/qdag/expressions/child_reference_expression.py index 5b1b7d15f..c338a16fa 100644 --- a/pydough/qdag/expressions/child_reference_expression.py +++ b/pydough/qdag/expressions/child_reference_expression.py @@ -29,11 +29,19 @@ def __init__( self._child_idx: int = child_idx self._term_name: str = term_name self._expression: PyDoughExpressionQDAG = self._collection.get_expr(term_name) + self._term_type = self._expression.pydough_type if not self.expression.is_singular(collection.starting_predecessor): raise PyDoughQDAGException( f"Cannot reference plural expression {self.expression} from {self.collection}" ) + @property + def expression(self) -> PyDoughExpressionQDAG: + """ + The expression that the ChildReferenceExpression refers to. + """ + return self._expression + @property def child_idx(self) -> int: """ diff --git a/pydough/qdag/expressions/reference.py b/pydough/qdag/expressions/reference.py index 5cc95597f..242a78685 100644 --- a/pydough/qdag/expressions/reference.py +++ b/pydough/qdag/expressions/reference.py @@ -8,7 +8,6 @@ from pydough.qdag.abstract_pydough_qdag import PyDoughQDAG from pydough.qdag.collections.collection_qdag import PyDoughCollectionQDAG -from pydough.qdag.errors import PyDoughQDAGException from pydough.types import PyDoughType from .expression_qdag import PyDoughExpressionQDAG @@ -20,14 +19,12 @@ class Reference(PyDoughExpressionQDAG): a preceding collection. 
""" - def __init__(self, collection: PyDoughCollectionQDAG, term_name: str): + def __init__( + self, collection: PyDoughCollectionQDAG, term_name: str, term_type: PyDoughType + ): self._collection: PyDoughCollectionQDAG = collection self._term_name: str = term_name - self._expression: PyDoughExpressionQDAG = collection.get_expr(term_name) - if not self.expression.is_singular(collection.starting_predecessor): - raise PyDoughQDAGException( - f"Cannot reference plural expression {self.expression} from {self.collection}" - ) + self._term_type: PyDoughType = term_type @property def collection(self) -> PyDoughCollectionQDAG: @@ -43,20 +40,13 @@ def term_name(self) -> str: """ return self._term_name - @property - def expression(self) -> PyDoughExpressionQDAG: - """ - The original expression that the reference refers to. - """ - return self._expression - @property def pydough_type(self) -> PyDoughType: - return self.expression.pydough_type + return self._term_type @property def is_aggregation(self) -> bool: - return self.expression.is_aggregation + return False def is_singular(self, context: PyDoughQDAG) -> bool: # References are already known to be singular via their construction. 
@@ -73,4 +63,5 @@ def equals(self, other: object) -> bool: isinstance(other, Reference) and self.term_name == other.term_name and self.collection.equals(other.collection) + and self.pydough_type == other.pydough_type ) diff --git a/pydough/qdag/node_builder.py b/pydough/qdag/node_builder.py index 3bfd3e829..0abe9115b 100644 --- a/pydough/qdag/node_builder.py +++ b/pydough/qdag/node_builder.py @@ -18,7 +18,11 @@ PyDoughOperator, builtin_registered_operators, ) +from pydough.qdag.collections.user_collection_qdag import ( + PyDoughUserGeneratedCollectionQDag, +) from pydough.types import PyDoughType +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection from .abstract_pydough_qdag import PyDoughQDAG from .collections import ( @@ -156,14 +160,15 @@ def build_window_call( ) def build_reference( - self, collection: PyDoughCollectionQDAG, name: str + self, collection: PyDoughCollectionQDAG, name: str, typ: PyDoughType ) -> Reference: """ - Creates a new reference to an expression from a preceding collection. + Creates a new reference to an expression in the collection. Args: `collection`: the collection that the reference comes from. `name`: the name of the expression being referenced. + `typ`: the PyDough type of the expression being referenced. Returns: The newly created PyDough Reference. @@ -172,7 +177,7 @@ def build_reference( `PyDoughQDAGException`: if `name` does not refer to an expression in the collection. """ - return Reference(collection, name) + return Reference(collection, name, typ) def build_child_reference_expression( self, @@ -392,3 +397,26 @@ def build_singular( The newly created PyDough SINGULAR instance. """ return Singular(preceding_context) + + def build_generated_collection( + self, + preceding_context: PyDoughCollectionQDAG, + user_collection: PyDoughUserGeneratedCollection, + ) -> PyDoughUserGeneratedCollectionQDag: + """ + Creates a new user-defined collection. 
+ + Args: + `preceding_context`: the preceding collection that the + user-defined collection is based on. + `user_collection`: the user-defined collection to be created. + + Returns: + The newly created user-defined collection. + """ + collection_qdag: PyDoughUserGeneratedCollectionQDag = ( + PyDoughUserGeneratedCollectionQDag( + ancestor=preceding_context, collection=user_collection + ) + ) + return collection_qdag diff --git a/pydough/relational/__init__.py b/pydough/relational/__init__.py index ff2cfb653..56cb7fde1 100644 --- a/pydough/relational/__init__.py +++ b/pydough/relational/__init__.py @@ -10,6 +10,7 @@ "EmptySingleton", "ExpressionSortInfo", "Filter", + "GeneratedTable", "Join", "JoinCardinality", "JoinType", @@ -45,6 +46,7 @@ ColumnPruner, EmptySingleton, Filter, + GeneratedTable, Join, JoinCardinality, JoinType, diff --git a/pydough/relational/relational_nodes/__init__.py b/pydough/relational/relational_nodes/__init__.py index 9ecb9a689..991cdfd2d 100644 --- a/pydough/relational/relational_nodes/__init__.py +++ b/pydough/relational/relational_nodes/__init__.py @@ -8,6 +8,7 @@ "ColumnPruner", "EmptySingleton", "Filter", + "GeneratedTable", "Join", "JoinCardinality", "JoinType", @@ -25,6 +26,7 @@ from .column_pruner import ColumnPruner from .empty_singleton import EmptySingleton from .filter import Filter +from .generated_table import GeneratedTable from .join import Join, JoinCardinality, JoinType from .join_type_relational_visitor import JoinTypeRelationalVisitor from .limit import Limit diff --git a/pydough/relational/relational_nodes/generated_table.py b/pydough/relational/relational_nodes/generated_table.py new file mode 100644 index 000000000..6f378a2c0 --- /dev/null +++ b/pydough/relational/relational_nodes/generated_table.py @@ -0,0 +1,63 @@ +""" +This file contains the relational implementation for a "generatedtable" node, +which generally represents a user-generated table. 
+""" + +from pydough.relational.relational_expressions import ( + RelationalExpression, +) +from pydough.relational.relational_expressions.column_reference import ColumnReference +from pydough.user_collections.user_collections import PyDoughUserGeneratedCollection + +from .abstract_node import RelationalNode + + +class GeneratedTable(RelationalNode): + """ + The GeneratedTable node in the relational tree. Represents + a user-generated table stored locally which is assumed to be singular + and always available. + """ + + def __init__( + self, + user_collection: PyDoughUserGeneratedCollection, + ) -> None: + columns: dict[str, RelationalExpression] = { + col_name: ColumnReference(col_name, col_type) + for col_name, col_type in user_collection.column_names_and_types + } + super().__init__(columns) + self._collection = user_collection + + @property + def inputs(self) -> list[RelationalNode]: + return [] + + @property + def name(self) -> str: + """Returns the name of the generated table.""" + return self.collection.name + + @property + def collection(self) -> PyDoughUserGeneratedCollection: + """ + The user-generated collection that this generated table represents. 
+ """ + return self._collection + + def node_equals(self, other: RelationalNode) -> bool: + return isinstance(other, GeneratedTable) and self.name == other.name + + def accept(self, visitor: "RelationalVisitor") -> None: # type: ignore # noqa + visitor.visit_generated_table(self) + + def to_string(self, compact=False) -> str: + return f"GENERATED_TABLE(table={self.name}, columns={self.make_column_string(self.columns, compact)})" + + def node_copy( + self, + columns: dict[str, RelationalExpression], + inputs: list[RelationalNode], + ) -> RelationalNode: + return GeneratedTable(self.collection) diff --git a/pydough/relational/relational_nodes/join_type_relational_visitor.py b/pydough/relational/relational_nodes/join_type_relational_visitor.py index 2f402337c..9ae1ee926 100644 --- a/pydough/relational/relational_nodes/join_type_relational_visitor.py +++ b/pydough/relational/relational_nodes/join_type_relational_visitor.py @@ -35,6 +35,9 @@ def visit_inputs(self, node) -> None: def visit_scan(self, scan: Scan) -> None: pass + def visit_generated_table(self, generated_table) -> None: + pass + def visit_join(self, join: Join) -> None: """ Visit a Join node, collecting join types. 
diff --git a/pydough/relational/relational_nodes/relational_expression_dispatcher.py b/pydough/relational/relational_nodes/relational_expression_dispatcher.py index b296b9869..e5ad70ba8 100644 --- a/pydough/relational/relational_nodes/relational_expression_dispatcher.py +++ b/pydough/relational/relational_nodes/relational_expression_dispatcher.py @@ -77,3 +77,6 @@ def visit_root(self, root: RelationalRoot) -> None: self.visit_common(root) for order in root.orderings: order.expr.accept(self._expr_visitor) + + def visit_generated_table(self, generated_table) -> None: + self.visit_common(generated_table) diff --git a/pydough/relational/relational_nodes/relational_visitor.py b/pydough/relational/relational_nodes/relational_visitor.py index 7f8ebe79d..51fb2804f 100644 --- a/pydough/relational/relational_nodes/relational_visitor.py +++ b/pydough/relational/relational_nodes/relational_visitor.py @@ -118,3 +118,12 @@ def visit_root(self, root: RelationalRoot) -> None: Args: `root`: The root node to visit. """ + + @abstractmethod + def visit_generated_table(self, generated_table) -> None: + """ + Visit a GeneratedTable node. + + Args: + `generated_table`: The generated table node to visit. 
+ """ diff --git a/pydough/relational/relational_nodes/tree_string_visitor.py b/pydough/relational/relational_nodes/tree_string_visitor.py index d47723c80..3dc56eda3 100644 --- a/pydough/relational/relational_nodes/tree_string_visitor.py +++ b/pydough/relational/relational_nodes/tree_string_visitor.py @@ -62,3 +62,6 @@ def visit_empty_singleton(self, empty_singleton) -> None: def visit_root(self, root) -> None: self.visit_node(root) + + def visit_generated_table(self, root) -> None: + self.visit_node(root) diff --git a/pydough/sqlglot/sqlglot_helpers.py b/pydough/sqlglot/sqlglot_helpers.py index c99c53cf5..97c819bc0 100644 --- a/pydough/sqlglot/sqlglot_helpers.py +++ b/pydough/sqlglot/sqlglot_helpers.py @@ -4,6 +4,7 @@ """ from sqlglot.expressions import Alias as SQLGlotAlias +from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression from sqlglot.expressions import Identifier @@ -26,6 +27,8 @@ def get_glot_name(expr: SQLGlotExpression) -> str | None: return expr.alias elif isinstance(expr, Identifier): return expr.this + if isinstance(expr, SQLGlotColumn): + return expr.this else: return None diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index 2905f9cae..4f8ddfb4f 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -9,7 +9,13 @@ from sqlglot.expressions import Alias as SQLGlotAlias from sqlglot.expressions import Column as SQLGlotColumn from sqlglot.expressions import Expression as SQLGlotExpression -from sqlglot.expressions import Identifier, Select, Subquery, TableAlias, values +from sqlglot.expressions import ( + Identifier, + Select, + Subquery, + TableAlias, + values, +) from sqlglot.expressions import Literal as SQLGlotLiteral from sqlglot.expressions import Star as SQLGlotStar from sqlglot.expressions import convert as sqlglot_convert @@ -26,6 +32,7 @@ EmptySingleton, 
def visit_generated_table(self, generated_table: "GeneratedTable") -> None:
    """
    Convert a `GeneratedTable` node into a SQLGlot query and push it onto
    the visitor stack.

    Only the single-column range case is handled: every value of
    `generated_table.collection.data` becomes one row of a VALUES clause,
    and the lone VALUES column is aliased to the first column name of the
    node.

    Args:
        `generated_table`: The generated table node to visit.
    """
    # TODO: match on the type of `generated_table.collection` to determine
    # how to convert it to SQLGlot. For now, assume it is only the range case.

    # Step 1: collect the output column names (a range has exactly one, and
    # only the first name is used below).
    column_names: list[str] = list(generated_table.columns)

    # Each value becomes a 1-tuple, i.e. one single-column row.
    rows = [(i,) for i in generated_table.collection.data]

    if not rows:
        # A VALUES clause needs at least one row, so an empty range is
        # emitted as a typed NULL projection filtered by WHERE FALSE: the
        # result keeps the column but contains zero rows.
        from sqlglot import exp

        query = (
            Select()
            .select(
                SQLGlotAlias(
                    this=exp.Cast(
                        this=exp.Null(), to=exp.DataType.build("INTEGER")
                    ),
                    alias=Identifier(this=column_names[0]),
                )
            )
            .where(exp.false())
        )
    else:
        # Step 2: build the VALUES expression WITHOUT column names.
        values_expr = values(values=rows, alias=generated_table.name)

        # Step 3: select from the VALUES expression and alias its column
        # (named "column1" in SQLite) to the first column name.
        # TODO: Handle other dialects that may not use "column1" as the
        # default name.
        query = (
            Select()
            .from_(values_expr)
            .select(
                SQLGlotAlias(
                    this=SQLGlotColumn(this=Identifier(this="column1")),
                    alias=Identifier(this=column_names[0]),
                )
            )
        )

    # Step 4: hand the finished query to the enclosing traversal.
    self._stack.append(query)
def qualify_generated_collection(
    self,
    unqualified: UnqualifiedGeneratedCollection,
    context: PyDoughCollectionQDAG,
    is_child: bool,
    is_cross: bool,
) -> PyDoughCollectionQDAG:
    """
    Qualifies an `UnqualifiedGeneratedCollection`, producing the
    corresponding PyDoughCollectionQDAG node.

    Args:
        `unqualified`: the UnqualifiedGeneratedCollection instance being
        qualified.
        `context`: the collection QDAG within whose context the collection
        is evaluated.
        `is_child`: whether the collection is qualified as a child of a
        child operator context, such as CALCULATE or PARTITION.
        `is_cross`: whether the collection being qualified is a CROSS JOIN
        operation (not consulted by this method).

    Returns:
        The PyDough QDAG object for the qualified collection node.
    """
    # The parcel is a 1-tuple holding the user-generated collection.
    user_collection = unqualified._parcel[0]
    qualified: PyDoughCollectionQDAG = self.builder.build_generated_collection(
        context,
        user_collection,
    )
    if not is_child:
        return qualified
    # Child accesses are wrapped so they are treated as a child operator's
    # child.
    return ChildOperatorChildAccess(qualified)
n_best={unqualified._parcel[4]}" return result + ")" + case UnqualifiedGeneratedCollection(): + result = "generated_collection(" + result += f"name={unqualified._parcel[0].name!r}, " + result += f"columns=[{', '.join(unqualified._parcel[0].columns)}]," + result += f"data={unqualified._parcel[0].data}" + return result + ")" case _: raise PyDoughUnqualifiedException( f"Unsupported unqualified node: {unqualified.__class__.__name__}" diff --git a/pydough/unqualified/unqualified_transform.py b/pydough/unqualified/unqualified_transform.py index 3cea3782e..cd2ad0c37 100644 --- a/pydough/unqualified/unqualified_transform.py +++ b/pydough/unqualified/unqualified_transform.py @@ -1,6 +1,6 @@ """ Logic for transforming raw Python code into PyDough code by replacing undefined -variables with unqualified nodes by prepending with with `_ROOT.`. +variables with unqualified nodes by prepending it with `_ROOT.`. """ __all__ = ["from_string", "init_pydough_context", "transform_cell", "transform_code"] @@ -364,8 +364,8 @@ def transform_code( source: str, graph_dict: dict[str, GraphMetadata], known_names: set[str] ) -> ast.AST: """ - Transforms the source code into a new Python QDAG that has had the PyDough - decorator removed, had the definition of `_ROOT` injected at the top of the + Transforms the source code into a new Python QDAG that has the PyDough + decorator removed, has the definition of `_ROOT` injected at the top of the function body, and prepend unknown variables with `_ROOT.` Args: diff --git a/pydough/user_collections/__init__.py b/pydough/user_collections/__init__.py new file mode 100644 index 000000000..fa78c0640 --- /dev/null +++ b/pydough/user_collections/__init__.py @@ -0,0 +1,5 @@ +""" +Module of PyDough dealing with APIs used for user generated collections. 
# Declares the public API of the module. The previous spelling `all = [...]`
# only rebound the builtin `all` at module level and exported nothing.
__all__ = ["RangeGeneratedCollection"]


class RangeGeneratedCollection(PyDoughUserGeneratedCollection):
    """
    User-generated collection with a single numeric column whose rows are
    the integers of a Python `range`.
    """

    def __init__(
        self,
        name: str,
        column_name: str,
        range: range,
    ) -> None:
        """
        Args:
            `name`: the name of the collection.
            `column_name`: the name of the single column holding the values.
            `range`: the Python range supplying the values of the column.
        """
        super().__init__(
            name=name,
            columns=[column_name],
            types=[NumericType()],
        )
        self._range = range
        # Cache the three components of the range for the accessors below.
        self._start = self._range.start
        self._end = self._range.stop
        self._step = self._range.step

    @property
    def start(self) -> int:
        """Return the start of the range."""
        return self._start

    @property
    def end(self) -> int:
        """Return the (exclusive) stop value of the range."""
        return self._end

    @property
    def step(self) -> int:
        """Return the step of the range."""
        return self._step

    @property
    def range(self) -> range:
        """Return the range object representing the collection."""
        return self._range

    @property
    def column_names_and_types(self) -> list[tuple[str, PyDoughType]]:
        """Return the single column name paired with a numeric type."""
        return [(self.columns[0], NumericType())]

    @property
    def data(self) -> Any:
        """Return the range as the data of the collection."""
        return self.range

    def __len__(self) -> int:
        """Return the number of integers in the range."""
        return len(self._range)

    def is_singular(self) -> bool:
        """Returns True if the collection is guaranteed to contain at most one row."""
        return len(self) <= 1

    def always_exists(self) -> bool:
        """Check if the range collection is always non-empty."""
        return len(self) > 0

    def to_string(self) -> str:
        """Return a string representation of the range collection."""
        # NOTE(review): the `!r` sits outside the braces, so it is emitted
        # literally (e.g. `RangeCollection(simple_range!r, ...)`) rather than
        # applying repr() to the name. Left unchanged because the expected
        # strings in the qualification tests currently match this output;
        # fix both together.
        return f"RangeCollection({self.name}!r, {self.columns[0]}={self.range})"

    def equals(self, other) -> bool:
        """
        Check equality with another object: it must also be a
        `RangeGeneratedCollection` with the same name, columns, start, end
        and step.
        """
        if not isinstance(other, RangeGeneratedCollection):
            return False
        return (
            self.name == other.name
            and self.columns == other.columns
            and self.start == other.start
            and self.end == other.end
            and self.step == other.step
        )
class PyDoughUserGeneratedCollection(ABC):
    """
    Abstract base class for a user defined table collection.
    This class defines the interface for accessing a user defined table collection
    directly, without any specific implementation details.
    It is intended to be subclassed by specific implementations that provide
    the actual behavior and properties of the collection.
    """

    def __init__(self, name: str, columns: list[str], types: list[PyDoughType]) -> None:
        """
        Args:
            `name`: the name used for the collection.
            `columns`: the column names of the collection.
            `types`: the column types supplied alongside `columns`.
        """
        self._name = name
        self._columns = columns
        self._types = types

    def __eq__(self, other: object) -> bool:
        # Equality is delegated to the subclass-defined `equals`.
        return self.equals(other)

    def __repr__(self) -> str:
        return self.to_string()

    def __hash__(self) -> int:
        # Hash on the string form, which is the same text `__repr__` returns.
        return hash(repr(self))

    def __str__(self) -> str:
        return self.to_string()

    @property
    def name(self) -> str:
        """Return the name used for the collection."""
        return self._name

    @property
    def columns(self) -> list[str]:
        """Return column names."""
        return self._columns

    @property
    def types(self) -> list[PyDoughType]:
        """Return the column types passed to the constructor."""
        return self._types

    @property
    @abstractmethod
    def column_names_and_types(self) -> list[tuple[str, PyDoughType]]:
        """Return column names and their types."""

    @property
    @abstractmethod
    def data(self) -> Any:
        """Return information about data in the collection."""

    @abstractmethod
    def always_exists(self) -> bool:
        """Check if the collection is always non-empty."""

    @abstractmethod
    def is_singular(self) -> bool:
        """Returns True if the collection is guaranteed to contain at most one row."""

    @abstractmethod
    def to_string(self) -> str:
        """Return a string representation of the collection."""

    @abstractmethod
    def equals(self, other) -> bool:
        """
        Check if this collection is equal to another collection.
        The exact criteria are defined by each subclass; `__eq__` returns
        whatever this method returns.
        """

    def get_expression_position(self, expr_name: str) -> int:
        """
        Get the position of an expression in the collection.
        This is used to determine the order of expressions in the collection.

        Args:
            `expr_name`: the column name to look up.

        Returns:
            The index of `expr_name` within `columns`.

        Raises:
            `ValueError`: if `expr_name` is not one of the columns.
        """
        # Single scan of the column list; the lookup failure is re-raised
        # with a message naming the expression and the collection.
        try:
            return self.columns.index(expr_name)
        except ValueError:
            raise ValueError(
                f"Expression {expr_name!r} not found in collection {self.name!r}"
            ) from None
using dtype=object + lambda: pd.DataFrame({"x": pd.Series(range(-1), dtype="object")}), + "simple_range_5", + ), + id="simple_range_5", + ), + pytest.param( + PyDoughPandasTest( + user_range_collection_1, + "TPCH", + lambda: pd.DataFrame( + { + "part_size": [ + 1, + 6, + 11, + 16, + 21, + 26, + 31, + 36, + 41, + 46, + 51, + 56, + 61, + 66, + 71, + 76, + 81, + 86, + 91, + 96, + ], + "n_parts": [ + 228, + 225, + 206, + 234, + 228, + 221, + 231, + 208, + 245, + 226, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + } + ), + "user_range_collection_1", + ), + id="user_range_collection_1", + ), ], ) def tpch_custom_pipeline_test_data(request) -> PyDoughPandasTest: diff --git a/tests/test_plan_refsols/simple_range_1.txt b/tests/test_plan_refsols/simple_range_1.txt new file mode 100644 index 000000000..075221523 --- /dev/null +++ b/tests/test_plan_refsols/simple_range_1.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('value', value)], orderings=[]) + GENERATED_TABLE(table=simple_range, columns={'value': value}) diff --git a/tests/test_plan_refsols/simple_range_2.txt b/tests/test_plan_refsols/simple_range_2.txt new file mode 100644 index 000000000..30cdabc8f --- /dev/null +++ b/tests/test_plan_refsols/simple_range_2.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('value', value)], orderings=[(value):desc_last]) + GENERATED_TABLE(table=simple_range, columns={'value': value}) diff --git a/tests/test_plan_refsols/simple_range_3.txt b/tests/test_plan_refsols/simple_range_3.txt new file mode 100644 index 000000000..e837f9663 --- /dev/null +++ b/tests/test_plan_refsols/simple_range_3.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('foo', foo)], orderings=[(foo):asc_first]) + GENERATED_TABLE(table=T1, columns={'foo': foo}) diff --git a/tests/test_plan_refsols/simple_range_4.txt b/tests/test_plan_refsols/simple_range_4.txt new file mode 100644 index 000000000..7317a3b3c --- /dev/null +++ b/tests/test_plan_refsols/simple_range_4.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('N', N)], orderings=[(N):asc_first]) + 
GENERATED_TABLE(table=T2, columns={'N': N}) diff --git a/tests/test_plan_refsols/simple_range_5.txt b/tests/test_plan_refsols/simple_range_5.txt new file mode 100644 index 000000000..b74e318cb --- /dev/null +++ b/tests/test_plan_refsols/simple_range_5.txt @@ -0,0 +1,2 @@ +ROOT(columns=[('x', x)], orderings=[]) + GENERATED_TABLE(table=T3, columns={'x': x}) diff --git a/tests/test_plan_refsols/user_range_collection_1.txt b/tests/test_plan_refsols/user_range_collection_1.txt new file mode 100644 index 000000000..496eb58f6 --- /dev/null +++ b/tests/test_plan_refsols/user_range_collection_1.txt @@ -0,0 +1,8 @@ +ROOT(columns=[('part_size', part_size), ('n_parts', DEFAULT_TO(n_rows, 0:numeric))], orderings=[]) + JOIN(condition=t0.part_size == t1.part_size, type=LEFT, cardinality=SINGULAR_ACCESS, columns={'n_rows': t1.n_rows, 'part_size': t0.part_size}) + GENERATED_TABLE(table=sizes, columns={'part_size': part_size}) + AGGREGATE(keys={'part_size': part_size}, aggregations={'n_rows': COUNT()}) + JOIN(condition=t1.p_size == t0.part_size, type=INNER, cardinality=PLURAL_FILTER, columns={'part_size': t0.part_size}) + GENERATED_TABLE(table=sizes, columns={'part_size': part_size}) + FILTER(condition=CONTAINS(p_name, 'turquoise':string), columns={'p_size': p_size}) + SCAN(table=tpch.PART, columns={'p_name': p_name, 'p_size': p_size}) diff --git a/tests/test_pydough_functions/all_pydough_functions_dialects.py b/tests/test_pydough_functions/all_pydough_functions_dialects.py index b5fd4dec7..d3f10ef86 100644 --- a/tests/test_pydough_functions/all_pydough_functions_dialects.py +++ b/tests/test_pydough_functions/all_pydough_functions_dialects.py @@ -16,6 +16,7 @@ import pandas as pd import datetime +import pydough def arithmetic_and_binary_operators(): @@ -228,3 +229,15 @@ def casting_functions(): cast_to_integer=INTEGER(total_price), cast_to_float=FLOAT(ship_priority), ) + + +def range_functions(): + """Test all range PyDough functions. 
+ Main purpose to verify these functions are working as expected with + supported SQL dialects. + """ + return pydough.range_collection( + "simple_range", + "value", + 10, # end value + ).ORDER_BY(value.DESC()) # Order by descending value diff --git a/tests/test_pydough_functions/user_collections.py b/tests/test_pydough_functions/user_collections.py new file mode 100644 index 000000000..753d61931 --- /dev/null +++ b/tests/test_pydough_functions/user_collections.py @@ -0,0 +1,109 @@ +""" +Various functions containing user generated collections as +PyDough code snippets for testing purposes. +""" +# ruff: noqa +# mypy: ignore-errors +# ruff & mypy should not try to typecheck or verify any of this + +import pydough + + +def simple_range_1(): + # Generates a table with column named `value` containing integers from 0 to 9. + return pydough.range_collection( + "simple_range", + "value", + 10, # end value + ) + + +def simple_range_2(): + # Generates a table with column named `value` containing integers from 0 to 9, + # ordered in descending order. + return pydough.range_collection( + "simple_range", + "value", + 10, # end value + ).ORDER_BY(value.DESC()) + + +def simple_range_3(): + # Generates a table with column named `foo` containing integers from 15 to + # 20 exclusive, ordered in ascending order. 
def user_range_collection_2():
    # Generate two single-column tables: `a` has a column `x` holding the
    # digits 0-9, and `b` has a column `y` holding every even number from 0 to
    # 1000 (inclusive). For every row of `a`, count how many rows of `b` have
    # `x` as a prefix of `y`, and how many have `x` as a suffix of `y`.
    table_a = pydough.range_collection("a", "x", 10)
    table_b = pydough.range_collection("b", "y", 0, 1001, 2)
    result = (
        table_a.CALCULATE(x)
        .CALCULATE(
            x,
            n_prefix=COUNT(CROSS(table_b).WHERE(STARTSWITH(STRING(y), STRING(x)))),
            n_suffix=COUNT(CROSS(table_b).WHERE(ENDSWITH(STRING(y), STRING(x)))),
        )
        .ORDER_BY(x.ASC())
    )
    return result
def user_range_collection_4():
    # For every value of `part_size` in the `sizes` range collection, find the
    # name & retail price of the cheapest turquoise part of that size.
    # NOTE(review): the earlier comment described part sizes 1-10 and parts
    # that are azure, plated, and in a small drum container, but the code
    # builds the range from the single argument 10 and filters on
    # CONTAINS(name, "turquoise") -- confirm which of the two is intended.
    sizes = pydough.range_collection("sizes", "part_size", 10)
    turquoise_parts = parts.WHERE(CONTAINS(name, "turquoise"))
    return (
        sizes.CALCULATE(part_size)
        .CROSS(turquoise_parts)
        .WHERE(size == part_size)
        .BEST(per="sizes", by=retail_price.ASC())
        .CALCULATE(part_size, name, retail_price)
        .ORDER_BY(part_size.ASC())
    )
id="user_range_collection_2", + ), + pytest.param( + user_range_collection_3, + None, + "user_range_collection_3", + id="user_range_collection_3", + ), + pytest.param( + user_range_collection_4, + None, + "user_range_collection_4", + id="user_range_collection_4", + ), ], ) def test_pydough_to_sql_tpch( diff --git a/tests/test_qualification.py b/tests/test_qualification.py index aa5da57fa..645f90c3c 100644 --- a/tests/test_qualification.py +++ b/tests/test_qualification.py @@ -63,6 +63,10 @@ impl_tpch_q21, impl_tpch_q22, ) +from tests.test_pydough_functions.user_collections import ( + simple_range_1, + simple_range_2, +) from tests.testing_utilities import ( graph_fetcher, ) @@ -942,6 +946,23 @@ """, id="simple_cross_6", ), + pytest.param( + simple_range_1, + """ +──┬─ TPCH + └─── RangeCollection(simple_range!r, value=range(0, 10)) + """, + id="simple_range_1", + ), + pytest.param( + simple_range_2, + """ +──┬─ TPCH + ├─── RangeCollection(simple_range!r, value=range(0, 10)) + └─── OrderBy[value.DESC(na_pos='last')] + """, + id="simple_range_2", + ), ], ) def test_qualify_node_to_ast_string( diff --git a/tests/test_sql_refsols/simple_range_1_ansi.sql b/tests/test_sql_refsols/simple_range_1_ansi.sql new file mode 100644 index 000000000..df860664c --- /dev/null +++ b/tests/test_sql_refsols/simple_range_1_ansi.sql @@ -0,0 +1,13 @@ +SELECT + column1 AS value +FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS simple_range(_col_0) diff --git a/tests/test_sql_refsols/simple_range_1_sqlite.sql b/tests/test_sql_refsols/simple_range_1_sqlite.sql new file mode 100644 index 000000000..6c4a98470 --- /dev/null +++ b/tests/test_sql_refsols/simple_range_1_sqlite.sql @@ -0,0 +1,13 @@ +SELECT + column1 AS value +FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS simple_range diff --git a/tests/test_sql_refsols/simple_range_2_ansi.sql b/tests/test_sql_refsols/simple_range_2_ansi.sql new file mode 100644 index 
000000000..453e9fabb --- /dev/null +++ b/tests/test_sql_refsols/simple_range_2_ansi.sql @@ -0,0 +1,15 @@ +SELECT + column1 AS value +FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS simple_range(_col_0) +ORDER BY + value DESC diff --git a/tests/test_sql_refsols/simple_range_2_sqlite.sql b/tests/test_sql_refsols/simple_range_2_sqlite.sql new file mode 100644 index 000000000..3abbcd413 --- /dev/null +++ b/tests/test_sql_refsols/simple_range_2_sqlite.sql @@ -0,0 +1,15 @@ +SELECT + column1 AS value +FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS simple_range +ORDER BY + value DESC diff --git a/tests/test_sql_refsols/simple_range_3_ansi.sql b/tests/test_sql_refsols/simple_range_3_ansi.sql new file mode 100644 index 000000000..c80b7231c --- /dev/null +++ b/tests/test_sql_refsols/simple_range_3_ansi.sql @@ -0,0 +1,10 @@ +SELECT + column1 AS foo +FROM (VALUES + (15), + (16), + (17), + (18), + (19)) AS t1(_col_0) +ORDER BY + foo diff --git a/tests/test_sql_refsols/simple_range_3_sqlite.sql b/tests/test_sql_refsols/simple_range_3_sqlite.sql new file mode 100644 index 000000000..d1c9bc740 --- /dev/null +++ b/tests/test_sql_refsols/simple_range_3_sqlite.sql @@ -0,0 +1,10 @@ +SELECT + column1 AS foo +FROM (VALUES + (15), + (16), + (17), + (18), + (19)) AS t1 +ORDER BY + foo diff --git a/tests/test_sql_refsols/simple_range_4_ansi.sql b/tests/test_sql_refsols/simple_range_4_ansi.sql new file mode 100644 index 000000000..96b2b5c32 --- /dev/null +++ b/tests/test_sql_refsols/simple_range_4_ansi.sql @@ -0,0 +1,15 @@ +SELECT + column1 AS N +FROM (VALUES + (10), + (9), + (8), + (7), + (6), + (5), + (4), + (3), + (2), + (1)) AS t2(_col_0) +ORDER BY + n diff --git a/tests/test_sql_refsols/simple_range_4_sqlite.sql b/tests/test_sql_refsols/simple_range_4_sqlite.sql new file mode 100644 index 000000000..76addb196 --- /dev/null +++ b/tests/test_sql_refsols/simple_range_4_sqlite.sql @@ -0,0 +1,15 @@ +SELECT + column1 
AS N +FROM (VALUES + (10), + (9), + (8), + (7), + (6), + (5), + (4), + (3), + (2), + (1)) AS t2 +ORDER BY + n diff --git a/tests/test_sql_refsols/simple_range_5_ansi.sql b/tests/test_sql_refsols/simple_range_5_ansi.sql new file mode 100644 index 000000000..e4f7b84dd --- /dev/null +++ b/tests/test_sql_refsols/simple_range_5_ansi.sql @@ -0,0 +1,4 @@ +SELECT + CAST(NULL AS INT) AS x +WHERE + FALSE diff --git a/tests/test_sql_refsols/simple_range_5_sqlite.sql b/tests/test_sql_refsols/simple_range_5_sqlite.sql new file mode 100644 index 000000000..cd81c843d --- /dev/null +++ b/tests/test_sql_refsols/simple_range_5_sqlite.sql @@ -0,0 +1,4 @@ +SELECT + CAST(NULL AS INTEGER) AS x +WHERE + FALSE diff --git a/tests/test_sql_refsols/user_range_collection_1_ansi.sql b/tests/test_sql_refsols/user_range_collection_1_ansi.sql new file mode 100644 index 000000000..5f2bdec2d --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_1_ansi.sql @@ -0,0 +1,40 @@ +WITH _s2 AS ( + SELECT + column1 AS part_size + FROM (VALUES + (1), + (6), + (11), + (16), + (21), + (26), + (31), + (36), + (41), + (46), + (51), + (56), + (61), + (66), + (71), + (76), + (81), + (86), + (91), + (96)) AS sizes(_col_0) +), _s3 AS ( + SELECT + COUNT(*) AS n_rows, + _s0.part_size + FROM _s2 AS _s0 + JOIN tpch.part AS part + ON _s0.part_size = part.p_size AND part.p_name LIKE '%turquoise%' + GROUP BY + _s0.part_size +) +SELECT + _s2.part_size, + COALESCE(_s3.n_rows, 0) AS n_parts +FROM _s2 AS _s2 +LEFT JOIN _s3 AS _s3 + ON _s2.part_size = _s3.part_size diff --git a/tests/test_sql_refsols/user_range_collection_1_sqlite.sql b/tests/test_sql_refsols/user_range_collection_1_sqlite.sql new file mode 100644 index 000000000..2ff0d8eb7 --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_1_sqlite.sql @@ -0,0 +1,40 @@ +WITH _s2 AS ( + SELECT + column1 AS part_size + FROM (VALUES + (1), + (6), + (11), + (16), + (21), + (26), + (31), + (36), + (41), + (46), + (51), + (56), + (61), + (66), + (71), + (76), 
+ (81), + (86), + (91), + (96)) AS sizes +), _s3 AS ( + SELECT + COUNT(*) AS n_rows, + _s0.part_size + FROM _s2 AS _s0 + JOIN tpch.part AS part + ON _s0.part_size = part.p_size AND part.p_name LIKE '%turquoise%' + GROUP BY + _s0.part_size +) +SELECT + _s2.part_size, + COALESCE(_s3.n_rows, 0) AS n_parts +FROM _s2 AS _s2 +LEFT JOIN _s3 AS _s3 + ON _s2.part_size = _s3.part_size diff --git a/tests/test_sql_refsols/user_range_collection_2_ansi.sql b/tests/test_sql_refsols/user_range_collection_2_ansi.sql new file mode 100644 index 000000000..ba503f99b --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_2_ansi.sql @@ -0,0 +1,548 @@ +WITH _s0 AS ( + SELECT + column1 AS x + FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS a(_col_0) +), _s1 AS ( + SELECT + column1 AS y + FROM (VALUES + (0), + (2), + (4), + (6), + (8), + (10), + (12), + (14), + (16), + (18), + (20), + (22), + (24), + (26), + (28), + (30), + (32), + (34), + (36), + (38), + (40), + (42), + (44), + (46), + (48), + (50), + (52), + (54), + (56), + (58), + (60), + (62), + (64), + (66), + (68), + (70), + (72), + (74), + (76), + (78), + (80), + (82), + (84), + (86), + (88), + (90), + (92), + (94), + (96), + (98), + (100), + (102), + (104), + (106), + (108), + (110), + (112), + (114), + (116), + (118), + (120), + (122), + (124), + (126), + (128), + (130), + (132), + (134), + (136), + (138), + (140), + (142), + (144), + (146), + (148), + (150), + (152), + (154), + (156), + (158), + (160), + (162), + (164), + (166), + (168), + (170), + (172), + (174), + (176), + (178), + (180), + (182), + (184), + (186), + (188), + (190), + (192), + (194), + (196), + (198), + (200), + (202), + (204), + (206), + (208), + (210), + (212), + (214), + (216), + (218), + (220), + (222), + (224), + (226), + (228), + (230), + (232), + (234), + (236), + (238), + (240), + (242), + (244), + (246), + (248), + (250), + (252), + (254), + (256), + (258), + (260), + (262), + (264), + (266), + (268), + 
(270), + (272), + (274), + (276), + (278), + (280), + (282), + (284), + (286), + (288), + (290), + (292), + (294), + (296), + (298), + (300), + (302), + (304), + (306), + (308), + (310), + (312), + (314), + (316), + (318), + (320), + (322), + (324), + (326), + (328), + (330), + (332), + (334), + (336), + (338), + (340), + (342), + (344), + (346), + (348), + (350), + (352), + (354), + (356), + (358), + (360), + (362), + (364), + (366), + (368), + (370), + (372), + (374), + (376), + (378), + (380), + (382), + (384), + (386), + (388), + (390), + (392), + (394), + (396), + (398), + (400), + (402), + (404), + (406), + (408), + (410), + (412), + (414), + (416), + (418), + (420), + (422), + (424), + (426), + (428), + (430), + (432), + (434), + (436), + (438), + (440), + (442), + (444), + (446), + (448), + (450), + (452), + (454), + (456), + (458), + (460), + (462), + (464), + (466), + (468), + (470), + (472), + (474), + (476), + (478), + (480), + (482), + (484), + (486), + (488), + (490), + (492), + (494), + (496), + (498), + (500), + (502), + (504), + (506), + (508), + (510), + (512), + (514), + (516), + (518), + (520), + (522), + (524), + (526), + (528), + (530), + (532), + (534), + (536), + (538), + (540), + (542), + (544), + (546), + (548), + (550), + (552), + (554), + (556), + (558), + (560), + (562), + (564), + (566), + (568), + (570), + (572), + (574), + (576), + (578), + (580), + (582), + (584), + (586), + (588), + (590), + (592), + (594), + (596), + (598), + (600), + (602), + (604), + (606), + (608), + (610), + (612), + (614), + (616), + (618), + (620), + (622), + (624), + (626), + (628), + (630), + (632), + (634), + (636), + (638), + (640), + (642), + (644), + (646), + (648), + (650), + (652), + (654), + (656), + (658), + (660), + (662), + (664), + (666), + (668), + (670), + (672), + (674), + (676), + (678), + (680), + (682), + (684), + (686), + (688), + (690), + (692), + (694), + (696), + (698), + (700), + (702), + (704), + (706), + (708), + (710), + (712), + 
(714), + (716), + (718), + (720), + (722), + (724), + (726), + (728), + (730), + (732), + (734), + (736), + (738), + (740), + (742), + (744), + (746), + (748), + (750), + (752), + (754), + (756), + (758), + (760), + (762), + (764), + (766), + (768), + (770), + (772), + (774), + (776), + (778), + (780), + (782), + (784), + (786), + (788), + (790), + (792), + (794), + (796), + (798), + (800), + (802), + (804), + (806), + (808), + (810), + (812), + (814), + (816), + (818), + (820), + (822), + (824), + (826), + (828), + (830), + (832), + (834), + (836), + (838), + (840), + (842), + (844), + (846), + (848), + (850), + (852), + (854), + (856), + (858), + (860), + (862), + (864), + (866), + (868), + (870), + (872), + (874), + (876), + (878), + (880), + (882), + (884), + (886), + (888), + (890), + (892), + (894), + (896), + (898), + (900), + (902), + (904), + (906), + (908), + (910), + (912), + (914), + (916), + (918), + (920), + (922), + (924), + (926), + (928), + (930), + (932), + (934), + (936), + (938), + (940), + (942), + (944), + (946), + (948), + (950), + (952), + (954), + (956), + (958), + (960), + (962), + (964), + (966), + (968), + (970), + (972), + (974), + (976), + (978), + (980), + (982), + (984), + (986), + (988), + (990), + (992), + (994), + (996), + (998), + (1000)) AS b(_col_0) +), _s4 AS ( + SELECT + ANY_VALUE(_s0.x) AS anything_x, + COUNT(*) AS n_prefix, + _s0.x + FROM _s0 AS _s0 + JOIN _s1 AS _s1 + ON CAST(_s1.y AS TEXT) LIKE CONCAT(CAST(_s0.x AS TEXT), '%') + GROUP BY + _s0.x +), _s5 AS ( + SELECT + COUNT(*) AS n_suffix, + _s2.x + FROM _s0 AS _s2 + JOIN _s1 AS _s3 + ON CAST(_s3.y AS TEXT) LIKE CONCAT('%', CAST(_s2.x AS TEXT)) + GROUP BY + _s2.x +) +SELECT + _s4.x, + _s4.n_prefix, + _s5.n_suffix +FROM _s4 AS _s4 +JOIN _s5 AS _s5 + ON _s4.anything_x = _s5.x +ORDER BY + x diff --git a/tests/test_sql_refsols/user_range_collection_2_sqlite.sql b/tests/test_sql_refsols/user_range_collection_2_sqlite.sql new file mode 100644 index 000000000..0e8413411 --- 
/dev/null +++ b/tests/test_sql_refsols/user_range_collection_2_sqlite.sql @@ -0,0 +1,552 @@ +WITH _s0 AS ( + SELECT + column1 AS x + FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS a +), _s1 AS ( + SELECT + column1 AS y + FROM (VALUES + (0), + (2), + (4), + (6), + (8), + (10), + (12), + (14), + (16), + (18), + (20), + (22), + (24), + (26), + (28), + (30), + (32), + (34), + (36), + (38), + (40), + (42), + (44), + (46), + (48), + (50), + (52), + (54), + (56), + (58), + (60), + (62), + (64), + (66), + (68), + (70), + (72), + (74), + (76), + (78), + (80), + (82), + (84), + (86), + (88), + (90), + (92), + (94), + (96), + (98), + (100), + (102), + (104), + (106), + (108), + (110), + (112), + (114), + (116), + (118), + (120), + (122), + (124), + (126), + (128), + (130), + (132), + (134), + (136), + (138), + (140), + (142), + (144), + (146), + (148), + (150), + (152), + (154), + (156), + (158), + (160), + (162), + (164), + (166), + (168), + (170), + (172), + (174), + (176), + (178), + (180), + (182), + (184), + (186), + (188), + (190), + (192), + (194), + (196), + (198), + (200), + (202), + (204), + (206), + (208), + (210), + (212), + (214), + (216), + (218), + (220), + (222), + (224), + (226), + (228), + (230), + (232), + (234), + (236), + (238), + (240), + (242), + (244), + (246), + (248), + (250), + (252), + (254), + (256), + (258), + (260), + (262), + (264), + (266), + (268), + (270), + (272), + (274), + (276), + (278), + (280), + (282), + (284), + (286), + (288), + (290), + (292), + (294), + (296), + (298), + (300), + (302), + (304), + (306), + (308), + (310), + (312), + (314), + (316), + (318), + (320), + (322), + (324), + (326), + (328), + (330), + (332), + (334), + (336), + (338), + (340), + (342), + (344), + (346), + (348), + (350), + (352), + (354), + (356), + (358), + (360), + (362), + (364), + (366), + (368), + (370), + (372), + (374), + (376), + (378), + (380), + (382), + (384), + (386), + (388), + (390), + (392), + (394), 
+ (396), + (398), + (400), + (402), + (404), + (406), + (408), + (410), + (412), + (414), + (416), + (418), + (420), + (422), + (424), + (426), + (428), + (430), + (432), + (434), + (436), + (438), + (440), + (442), + (444), + (446), + (448), + (450), + (452), + (454), + (456), + (458), + (460), + (462), + (464), + (466), + (468), + (470), + (472), + (474), + (476), + (478), + (480), + (482), + (484), + (486), + (488), + (490), + (492), + (494), + (496), + (498), + (500), + (502), + (504), + (506), + (508), + (510), + (512), + (514), + (516), + (518), + (520), + (522), + (524), + (526), + (528), + (530), + (532), + (534), + (536), + (538), + (540), + (542), + (544), + (546), + (548), + (550), + (552), + (554), + (556), + (558), + (560), + (562), + (564), + (566), + (568), + (570), + (572), + (574), + (576), + (578), + (580), + (582), + (584), + (586), + (588), + (590), + (592), + (594), + (596), + (598), + (600), + (602), + (604), + (606), + (608), + (610), + (612), + (614), + (616), + (618), + (620), + (622), + (624), + (626), + (628), + (630), + (632), + (634), + (636), + (638), + (640), + (642), + (644), + (646), + (648), + (650), + (652), + (654), + (656), + (658), + (660), + (662), + (664), + (666), + (668), + (670), + (672), + (674), + (676), + (678), + (680), + (682), + (684), + (686), + (688), + (690), + (692), + (694), + (696), + (698), + (700), + (702), + (704), + (706), + (708), + (710), + (712), + (714), + (716), + (718), + (720), + (722), + (724), + (726), + (728), + (730), + (732), + (734), + (736), + (738), + (740), + (742), + (744), + (746), + (748), + (750), + (752), + (754), + (756), + (758), + (760), + (762), + (764), + (766), + (768), + (770), + (772), + (774), + (776), + (778), + (780), + (782), + (784), + (786), + (788), + (790), + (792), + (794), + (796), + (798), + (800), + (802), + (804), + (806), + (808), + (810), + (812), + (814), + (816), + (818), + (820), + (822), + (824), + (826), + (828), + (830), + (832), + (834), + (836), + (838), + 
(840), + (842), + (844), + (846), + (848), + (850), + (852), + (854), + (856), + (858), + (860), + (862), + (864), + (866), + (868), + (870), + (872), + (874), + (876), + (878), + (880), + (882), + (884), + (886), + (888), + (890), + (892), + (894), + (896), + (898), + (900), + (902), + (904), + (906), + (908), + (910), + (912), + (914), + (916), + (918), + (920), + (922), + (924), + (926), + (928), + (930), + (932), + (934), + (936), + (938), + (940), + (942), + (944), + (946), + (948), + (950), + (952), + (954), + (956), + (958), + (960), + (962), + (964), + (966), + (968), + (970), + (972), + (974), + (976), + (978), + (980), + (982), + (984), + (986), + (988), + (990), + (992), + (994), + (996), + (998), + (1000)) AS b +), _s4 AS ( + SELECT + MAX(_s0.x) AS anything_x, + COUNT(*) AS n_prefix, + _s0.x + FROM _s0 AS _s0 + JOIN _s1 AS _s1 + ON CAST(_s1.y AS TEXT) LIKE ( + CAST(_s0.x AS TEXT) || '%' + ) + GROUP BY + _s0.x +), _s5 AS ( + SELECT + COUNT(*) AS n_suffix, + _s2.x + FROM _s0 AS _s2 + JOIN _s1 AS _s3 + ON CAST(_s3.y AS TEXT) LIKE ( + '%' || CAST(_s2.x AS TEXT) + ) + GROUP BY + _s2.x +) +SELECT + _s4.x, + _s4.n_prefix, + _s5.n_suffix +FROM _s4 AS _s4 +JOIN _s5 AS _s5 + ON _s4.anything_x = _s5.x +ORDER BY + x diff --git a/tests/test_sql_refsols/user_range_collection_3_ansi.sql b/tests/test_sql_refsols/user_range_collection_3_ansi.sql new file mode 100644 index 000000000..ba503f99b --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_3_ansi.sql @@ -0,0 +1,548 @@ +WITH _s0 AS ( + SELECT + column1 AS x + FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS a(_col_0) +), _s1 AS ( + SELECT + column1 AS y + FROM (VALUES + (0), + (2), + (4), + (6), + (8), + (10), + (12), + (14), + (16), + (18), + (20), + (22), + (24), + (26), + (28), + (30), + (32), + (34), + (36), + (38), + (40), + (42), + (44), + (46), + (48), + (50), + (52), + (54), + (56), + (58), + (60), + (62), + (64), + (66), + (68), + (70), + (72), + (74), + 
(76), + (78), + (80), + (82), + (84), + (86), + (88), + (90), + (92), + (94), + (96), + (98), + (100), + (102), + (104), + (106), + (108), + (110), + (112), + (114), + (116), + (118), + (120), + (122), + (124), + (126), + (128), + (130), + (132), + (134), + (136), + (138), + (140), + (142), + (144), + (146), + (148), + (150), + (152), + (154), + (156), + (158), + (160), + (162), + (164), + (166), + (168), + (170), + (172), + (174), + (176), + (178), + (180), + (182), + (184), + (186), + (188), + (190), + (192), + (194), + (196), + (198), + (200), + (202), + (204), + (206), + (208), + (210), + (212), + (214), + (216), + (218), + (220), + (222), + (224), + (226), + (228), + (230), + (232), + (234), + (236), + (238), + (240), + (242), + (244), + (246), + (248), + (250), + (252), + (254), + (256), + (258), + (260), + (262), + (264), + (266), + (268), + (270), + (272), + (274), + (276), + (278), + (280), + (282), + (284), + (286), + (288), + (290), + (292), + (294), + (296), + (298), + (300), + (302), + (304), + (306), + (308), + (310), + (312), + (314), + (316), + (318), + (320), + (322), + (324), + (326), + (328), + (330), + (332), + (334), + (336), + (338), + (340), + (342), + (344), + (346), + (348), + (350), + (352), + (354), + (356), + (358), + (360), + (362), + (364), + (366), + (368), + (370), + (372), + (374), + (376), + (378), + (380), + (382), + (384), + (386), + (388), + (390), + (392), + (394), + (396), + (398), + (400), + (402), + (404), + (406), + (408), + (410), + (412), + (414), + (416), + (418), + (420), + (422), + (424), + (426), + (428), + (430), + (432), + (434), + (436), + (438), + (440), + (442), + (444), + (446), + (448), + (450), + (452), + (454), + (456), + (458), + (460), + (462), + (464), + (466), + (468), + (470), + (472), + (474), + (476), + (478), + (480), + (482), + (484), + (486), + (488), + (490), + (492), + (494), + (496), + (498), + (500), + (502), + (504), + (506), + (508), + (510), + (512), + (514), + (516), + (518), + (520), + 
(522), + (524), + (526), + (528), + (530), + (532), + (534), + (536), + (538), + (540), + (542), + (544), + (546), + (548), + (550), + (552), + (554), + (556), + (558), + (560), + (562), + (564), + (566), + (568), + (570), + (572), + (574), + (576), + (578), + (580), + (582), + (584), + (586), + (588), + (590), + (592), + (594), + (596), + (598), + (600), + (602), + (604), + (606), + (608), + (610), + (612), + (614), + (616), + (618), + (620), + (622), + (624), + (626), + (628), + (630), + (632), + (634), + (636), + (638), + (640), + (642), + (644), + (646), + (648), + (650), + (652), + (654), + (656), + (658), + (660), + (662), + (664), + (666), + (668), + (670), + (672), + (674), + (676), + (678), + (680), + (682), + (684), + (686), + (688), + (690), + (692), + (694), + (696), + (698), + (700), + (702), + (704), + (706), + (708), + (710), + (712), + (714), + (716), + (718), + (720), + (722), + (724), + (726), + (728), + (730), + (732), + (734), + (736), + (738), + (740), + (742), + (744), + (746), + (748), + (750), + (752), + (754), + (756), + (758), + (760), + (762), + (764), + (766), + (768), + (770), + (772), + (774), + (776), + (778), + (780), + (782), + (784), + (786), + (788), + (790), + (792), + (794), + (796), + (798), + (800), + (802), + (804), + (806), + (808), + (810), + (812), + (814), + (816), + (818), + (820), + (822), + (824), + (826), + (828), + (830), + (832), + (834), + (836), + (838), + (840), + (842), + (844), + (846), + (848), + (850), + (852), + (854), + (856), + (858), + (860), + (862), + (864), + (866), + (868), + (870), + (872), + (874), + (876), + (878), + (880), + (882), + (884), + (886), + (888), + (890), + (892), + (894), + (896), + (898), + (900), + (902), + (904), + (906), + (908), + (910), + (912), + (914), + (916), + (918), + (920), + (922), + (924), + (926), + (928), + (930), + (932), + (934), + (936), + (938), + (940), + (942), + (944), + (946), + (948), + (950), + (952), + (954), + (956), + (958), + (960), + (962), + (964), + 
(966), + (968), + (970), + (972), + (974), + (976), + (978), + (980), + (982), + (984), + (986), + (988), + (990), + (992), + (994), + (996), + (998), + (1000)) AS b(_col_0) +), _s4 AS ( + SELECT + ANY_VALUE(_s0.x) AS anything_x, + COUNT(*) AS n_prefix, + _s0.x + FROM _s0 AS _s0 + JOIN _s1 AS _s1 + ON CAST(_s1.y AS TEXT) LIKE CONCAT(CAST(_s0.x AS TEXT), '%') + GROUP BY + _s0.x +), _s5 AS ( + SELECT + COUNT(*) AS n_suffix, + _s2.x + FROM _s0 AS _s2 + JOIN _s1 AS _s3 + ON CAST(_s3.y AS TEXT) LIKE CONCAT('%', CAST(_s2.x AS TEXT)) + GROUP BY + _s2.x +) +SELECT + _s4.x, + _s4.n_prefix, + _s5.n_suffix +FROM _s4 AS _s4 +JOIN _s5 AS _s5 + ON _s4.anything_x = _s5.x +ORDER BY + x diff --git a/tests/test_sql_refsols/user_range_collection_3_sqlite.sql b/tests/test_sql_refsols/user_range_collection_3_sqlite.sql new file mode 100644 index 000000000..0e8413411 --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_3_sqlite.sql @@ -0,0 +1,552 @@ +WITH _s0 AS ( + SELECT + column1 AS x + FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS a +), _s1 AS ( + SELECT + column1 AS y + FROM (VALUES + (0), + (2), + (4), + (6), + (8), + (10), + (12), + (14), + (16), + (18), + (20), + (22), + (24), + (26), + (28), + (30), + (32), + (34), + (36), + (38), + (40), + (42), + (44), + (46), + (48), + (50), + (52), + (54), + (56), + (58), + (60), + (62), + (64), + (66), + (68), + (70), + (72), + (74), + (76), + (78), + (80), + (82), + (84), + (86), + (88), + (90), + (92), + (94), + (96), + (98), + (100), + (102), + (104), + (106), + (108), + (110), + (112), + (114), + (116), + (118), + (120), + (122), + (124), + (126), + (128), + (130), + (132), + (134), + (136), + (138), + (140), + (142), + (144), + (146), + (148), + (150), + (152), + (154), + (156), + (158), + (160), + (162), + (164), + (166), + (168), + (170), + (172), + (174), + (176), + (178), + (180), + (182), + (184), + (186), + (188), + (190), + (192), + (194), + (196), + (198), + (200), + (202), + 
(204), + (206), + (208), + (210), + (212), + (214), + (216), + (218), + (220), + (222), + (224), + (226), + (228), + (230), + (232), + (234), + (236), + (238), + (240), + (242), + (244), + (246), + (248), + (250), + (252), + (254), + (256), + (258), + (260), + (262), + (264), + (266), + (268), + (270), + (272), + (274), + (276), + (278), + (280), + (282), + (284), + (286), + (288), + (290), + (292), + (294), + (296), + (298), + (300), + (302), + (304), + (306), + (308), + (310), + (312), + (314), + (316), + (318), + (320), + (322), + (324), + (326), + (328), + (330), + (332), + (334), + (336), + (338), + (340), + (342), + (344), + (346), + (348), + (350), + (352), + (354), + (356), + (358), + (360), + (362), + (364), + (366), + (368), + (370), + (372), + (374), + (376), + (378), + (380), + (382), + (384), + (386), + (388), + (390), + (392), + (394), + (396), + (398), + (400), + (402), + (404), + (406), + (408), + (410), + (412), + (414), + (416), + (418), + (420), + (422), + (424), + (426), + (428), + (430), + (432), + (434), + (436), + (438), + (440), + (442), + (444), + (446), + (448), + (450), + (452), + (454), + (456), + (458), + (460), + (462), + (464), + (466), + (468), + (470), + (472), + (474), + (476), + (478), + (480), + (482), + (484), + (486), + (488), + (490), + (492), + (494), + (496), + (498), + (500), + (502), + (504), + (506), + (508), + (510), + (512), + (514), + (516), + (518), + (520), + (522), + (524), + (526), + (528), + (530), + (532), + (534), + (536), + (538), + (540), + (542), + (544), + (546), + (548), + (550), + (552), + (554), + (556), + (558), + (560), + (562), + (564), + (566), + (568), + (570), + (572), + (574), + (576), + (578), + (580), + (582), + (584), + (586), + (588), + (590), + (592), + (594), + (596), + (598), + (600), + (602), + (604), + (606), + (608), + (610), + (612), + (614), + (616), + (618), + (620), + (622), + (624), + (626), + (628), + (630), + (632), + (634), + (636), + (638), + (640), + (642), + (644), + (646), + 
(648), + (650), + (652), + (654), + (656), + (658), + (660), + (662), + (664), + (666), + (668), + (670), + (672), + (674), + (676), + (678), + (680), + (682), + (684), + (686), + (688), + (690), + (692), + (694), + (696), + (698), + (700), + (702), + (704), + (706), + (708), + (710), + (712), + (714), + (716), + (718), + (720), + (722), + (724), + (726), + (728), + (730), + (732), + (734), + (736), + (738), + (740), + (742), + (744), + (746), + (748), + (750), + (752), + (754), + (756), + (758), + (760), + (762), + (764), + (766), + (768), + (770), + (772), + (774), + (776), + (778), + (780), + (782), + (784), + (786), + (788), + (790), + (792), + (794), + (796), + (798), + (800), + (802), + (804), + (806), + (808), + (810), + (812), + (814), + (816), + (818), + (820), + (822), + (824), + (826), + (828), + (830), + (832), + (834), + (836), + (838), + (840), + (842), + (844), + (846), + (848), + (850), + (852), + (854), + (856), + (858), + (860), + (862), + (864), + (866), + (868), + (870), + (872), + (874), + (876), + (878), + (880), + (882), + (884), + (886), + (888), + (890), + (892), + (894), + (896), + (898), + (900), + (902), + (904), + (906), + (908), + (910), + (912), + (914), + (916), + (918), + (920), + (922), + (924), + (926), + (928), + (930), + (932), + (934), + (936), + (938), + (940), + (942), + (944), + (946), + (948), + (950), + (952), + (954), + (956), + (958), + (960), + (962), + (964), + (966), + (968), + (970), + (972), + (974), + (976), + (978), + (980), + (982), + (984), + (986), + (988), + (990), + (992), + (994), + (996), + (998), + (1000)) AS b +), _s4 AS ( + SELECT + MAX(_s0.x) AS anything_x, + COUNT(*) AS n_prefix, + _s0.x + FROM _s0 AS _s0 + JOIN _s1 AS _s1 + ON CAST(_s1.y AS TEXT) LIKE ( + CAST(_s0.x AS TEXT) || '%' + ) + GROUP BY + _s0.x +), _s5 AS ( + SELECT + COUNT(*) AS n_suffix, + _s2.x + FROM _s0 AS _s2 + JOIN _s1 AS _s3 + ON CAST(_s3.y AS TEXT) LIKE ( + '%' || CAST(_s2.x AS TEXT) + ) + GROUP BY + _s2.x +) +SELECT + _s4.x, + 
_s4.n_prefix, + _s5.n_suffix +FROM _s4 AS _s4 +JOIN _s5 AS _s5 + ON _s4.anything_x = _s5.x +ORDER BY + x diff --git a/tests/test_sql_refsols/user_range_collection_4_ansi.sql b/tests/test_sql_refsols/user_range_collection_4_ansi.sql new file mode 100644 index 000000000..75428595f --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_4_ansi.sql @@ -0,0 +1,28 @@ +WITH _t0 AS ( + SELECT + part.p_name, + part.p_retailprice, + column1 AS part_size + FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS sizes(_col_0) + JOIN tpch.part AS part + ON column1 = part.p_size AND part.p_name LIKE '%turquoise%' + QUALIFY + ROW_NUMBER() OVER (PARTITION BY column1 ORDER BY part.p_retailprice NULLS LAST) = 1 +) +SELECT + part_size, + p_name AS name, + p_retailprice AS retail_price +FROM _t0 +ORDER BY + part_size diff --git a/tests/test_sql_refsols/user_range_collection_4_sqlite.sql b/tests/test_sql_refsols/user_range_collection_4_sqlite.sql new file mode 100644 index 000000000..8a1cf1c9c --- /dev/null +++ b/tests/test_sql_refsols/user_range_collection_4_sqlite.sql @@ -0,0 +1,29 @@ +WITH _t AS ( + SELECT + part.p_name, + part.p_retailprice, + column1 AS part_size, + ROW_NUMBER() OVER (PARTITION BY column1 ORDER BY part.p_retailprice) AS _w + FROM (VALUES + (0), + (1), + (2), + (3), + (4), + (5), + (6), + (7), + (8), + (9)) AS sizes + JOIN tpch.part AS part + ON column1 = part.p_size AND part.p_name LIKE '%turquoise%' +) +SELECT + part_size, + p_name AS name, + p_retailprice AS retail_price +FROM _t +WHERE + _w = 1 +ORDER BY + part_size diff --git a/tests/testing_utilities.py b/tests/testing_utilities.py index 8c55689b2..7c7837f71 100644 --- a/tests/testing_utilities.py +++ b/tests/testing_utilities.py @@ -292,7 +292,8 @@ def build( assert context is not None, ( "Cannot call .build() on ReferenceInfo without providing a context" ) - return builder.build_reference(context, self.name) + typ: PyDoughType = context.get_expr(self.name).pydough_type 
+ return builder.build_reference(context, self.name, typ) class BackReferenceExpressionInfo(AstNodeTestInfo):