@@ -666,6 +666,60 @@ def list_tables():
666
666
"message" : str (e )
667
667
}), 500
668
668
669
+ def assemble_query (aggregate_fields_and_functions , group_fields , columns , table_name ):
670
+ """
671
+ Assembles a SELECT query string based on binning, aggregation, and grouping specifications.
672
+
673
+ Args:
674
+ bin_fields (list): Fields to be binned into ranges
675
+ aggregate_fields_and_functions (list): List of tuples (field, function) for aggregation
676
+ group_fields (list): Fields to group by
677
+ columns (list): All available column names
678
+
679
+ Returns:
680
+ str: The assembled SELECT query projection part
681
+ """
682
+ select_parts = []
683
+ output_column_names = []
684
+
685
+ # Handle aggregate fields and functions
686
+ for field , function in aggregate_fields_and_functions :
687
+ if field is None :
688
+ # Handle count(*) case
689
+ if function .lower () == 'count' :
690
+ select_parts .append ('COUNT(*) as _count' )
691
+ output_column_names .append ('_count' )
692
+ elif field in columns :
693
+ if function .lower () == 'count' :
694
+ alias = f'_count'
695
+ select_parts .append (f'COUNT(*) as { alias } ' )
696
+ output_column_names .append (alias )
697
+ else :
698
+ # Sanitize function name and create alias
699
+ if function in ["avg" , "average" , "mean" ]:
700
+ aggregate_function = "AVG"
701
+ else :
702
+ aggregate_function = function .upper ()
703
+
704
+ alias = f'{ field } _{ function } '
705
+ select_parts .append (f'{ aggregate_function } ("{ field } ") as { alias } ' )
706
+ output_column_names .append (alias )
707
+
708
+ # Handle group fields
709
+ for field in group_fields :
710
+ if field in columns :
711
+ select_parts .append (f'"{ field } "' )
712
+ output_column_names .append (field )
713
+ # If no fields are specified, select all columns
714
+ if not select_parts :
715
+ select_parts = ["*" ]
716
+ output_column_names = columns
717
+
718
+ from_clause = f"FROM { table_name } "
719
+ group_by_clause = f"GROUP BY { ', ' .join (group_fields )} " if len (group_fields ) > 0 and len (aggregate_fields_and_functions ) > 0 else ""
720
+
721
+ query = f"SELECT { ', ' .join (select_parts )} { from_clause } { group_by_clause } "
722
+ return query , output_column_names
669
723
670
724
@app .route ('/api/tables/sample-table' , methods = ['POST' ])
671
725
def sample_table ():
@@ -674,52 +728,58 @@ def sample_table():
674
728
data = request .get_json ()
675
729
table_id = data .get ('table' )
676
730
sample_size = data .get ('size' , 1000 )
677
- projection_fields = data .get ('projection_fields' , []) # if empty, we want to include all fields
731
+ aggregate_fields_and_functions = data .get ('aggregate_fields_and_functions' , []) # each element is a tuple (field, function)
732
+ select_fields = data .get ('select_fields' , []) # if empty, we want to include all fields
678
733
method = data .get ('method' , 'random' ) # one of 'random', 'head', 'bottom'
679
734
order_by_fields = data .get ('order_by_fields' , [])
680
735
681
- print (f"sample_table: { table_id } , { sample_size } , { projection_fields } , { method } , { order_by_fields } " )
736
+ print (f"sample_table: { table_id } , { sample_size } , { aggregate_fields_and_functions } , { select_fields } , { method } , { order_by_fields } " )
682
737
738
+ total_row_count = 0
683
739
# Validate field names against table columns to prevent SQL injection
684
740
with db_manager .connection (session ['session_id' ]) as db :
685
741
# Get valid column names
686
742
columns = [col [0 ] for col in db .execute (f"DESCRIBE { table_id } " ).fetchall ()]
687
743
688
744
# Filter order_by_fields to only include valid column names
689
745
valid_order_by_fields = [field for field in order_by_fields if field in columns ]
690
- valid_projection_fields = [field for field in projection_fields if field in columns ]
746
+ valid_aggregate_fields_and_functions = [
747
+ field_and_function for field_and_function in aggregate_fields_and_functions
748
+ if field_and_function [0 ] is None or field_and_function [0 ] in columns
749
+ ]
750
+ valid_select_fields = [field for field in select_fields if field in columns ]
691
751
692
- if len (valid_projection_fields ) == 0 :
693
- projection_fields_str = "*"
694
- else :
695
- projection_fields_str = ", " .join (valid_projection_fields )
752
+ query , output_column_names = assemble_query (valid_aggregate_fields_and_functions , valid_select_fields , columns , table_id )
696
753
754
+ # Modify the original query to include the count:
755
+ count_query = f"SELECT *, COUNT(*) OVER () as total_count FROM ({ query } ) as subq LIMIT 1"
756
+ result = db .execute (count_query ).fetchone ()
757
+ total_row_count = result [- 1 ] if result else 0
758
+
759
+ # Add ordering and limit to the main query
697
760
if method == 'random' :
698
- result = db . execute ( f"SELECT { projection_fields_str } FROM { table_id } ORDER BY RANDOM() LIMIT { sample_size } "). fetchall ()
761
+ query += f" ORDER BY RANDOM() LIMIT { sample_size } "
699
762
elif method == 'head' :
700
763
if valid_order_by_fields :
701
764
# Build ORDER BY clause with validated fields
702
765
order_by_clause = ", " .join ([f'"{ field } "' for field in valid_order_by_fields ])
703
- result = db . execute ( f"SELECT { projection_fields_str } FROM { table_id } ORDER BY { order_by_clause } LIMIT { sample_size } "). fetchall ()
766
+ query += f" ORDER BY { order_by_clause } LIMIT { sample_size } "
704
767
else :
705
- result = db . execute ( f"SELECT { projection_fields_str } FROM { table_id } LIMIT { sample_size } "). fetchall ()
768
+ query += f" LIMIT { sample_size } "
706
769
elif method == 'bottom' :
707
770
if valid_order_by_fields :
708
771
# Build ORDER BY clause with validated fields in descending order
709
772
order_by_clause = ", " .join ([f'"{ field } " DESC' for field in valid_order_by_fields ])
710
- result = db . execute ( f"SELECT { projection_fields_str } FROM { table_id } ORDER BY { order_by_clause } LIMIT { sample_size } "). fetchall ()
773
+ query += f" ORDER BY { order_by_clause } LIMIT { sample_size } "
711
774
else :
712
- result = db . execute ( f"SELECT { projection_fields_str } FROM { table_id } ORDER BY ROWID DESC LIMIT { sample_size } "). fetchall ()
775
+ query += f" ORDER BY ROWID DESC LIMIT { sample_size } "
713
776
714
- # When using projection_fields, we need to use those as our column names
715
- if len (valid_projection_fields ) > 0 :
716
- column_names = valid_projection_fields
717
- else :
718
- column_names = columns
777
+ result = db .execute (query ).fetchall ()
719
778
720
779
return jsonify ({
721
780
"status" : "success" ,
722
- "rows" : [dict (zip (column_names , row )) for row in result ]
781
+ "rows" : [dict (zip (output_column_names , row )) for row in result ],
782
+ "total_row_count" : total_row_count
723
783
})
724
784
except Exception as e :
725
785
print (e )
0 commit comments