@@ -503,20 +503,20 @@ func middlewareFileQuery(queryData *QueryData) bool {
503
503
}
504
504
505
505
// Walk the AST and replace the table functions
506
- sqlparser .Rewrite (stmt , nil , func (cursor * sqlparser.Cursor ) bool {
507
- // Get the table function
506
+ // with a CREATE VIRTUAL TABLE statement
507
+ rewrite := func (cursor * sqlparser.Cursor ) bool {
508
+ // Get the table name
508
509
tableFunction , ok := cursor .Node ().(sqlparser.TableName )
509
510
if ! ok {
510
511
return true
511
512
}
512
513
513
514
loweredName := strings .ToLower (tableFunction .Name .String ())
514
- // Check if the table function is a file module
515
+
515
516
if ! strings .HasPrefix (loweredName , "read_" ) {
516
517
return true
517
518
}
518
519
519
- // Replace the table function with a random one
520
520
tableName := generateRandomString (16 )
521
521
preExecBuilder := strings.Builder {}
522
522
preExecBuilder .WriteString ("CREATE VIRTUAL TABLE " )
@@ -525,11 +525,9 @@ func middlewareFileQuery(queryData *QueryData) bool {
525
525
if reader , ok := supportedTableFunctions [loweredName ]; ok {
526
526
preExecBuilder .WriteString (reader )
527
527
} else {
528
- // If the user writes read_foo, and we don't have a reader for foo
529
- // we skip the table function
530
528
return true
531
529
}
532
- // Add the arguments if any
530
+
533
531
if len (tableFunction .Args ) > 0 {
534
532
preExecBuilder .WriteString ("(" )
535
533
for i , arg := range tableFunction .Args {
@@ -543,17 +541,22 @@ func middlewareFileQuery(queryData *QueryData) bool {
543
541
544
542
preExecBuilder .WriteString (";" )
545
543
546
- // Add the pre-execution statement
547
544
queryData .PreExec = append (queryData .PreExec , preExecBuilder .String ())
548
545
549
- // Add a post-execution statement to drop the table
550
546
queryData .PostExec = append (queryData .PostExec , "DROP TABLE " + tableName + ";" )
551
547
552
- // Replace the table function with the new table name
553
548
cursor .Replace (sqlparser .NewTableName (tableName ))
554
549
555
550
return true
556
- })
551
+ }
552
+ sqlparser .Rewrite (stmt , nil , rewrite )
553
+
554
+ // In case it is a CREATE TABLE statement, we need to rewrite the select statement
555
+ if createTable , ok := stmt .(* sqlparser.CreateTable ); ok {
556
+ if createTable .Select != nil {
557
+ sqlparser .Rewrite (createTable .Select , nil , rewrite )
558
+ }
559
+ }
557
560
558
561
// Deparse the query
559
562
queryData .SQLQuery = sqlparser .String (stmt )
0 commit comments