Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,24 @@ public class SQLQueryUtils {
private static final Logger logger = LogManager.getLogger(SQLQueryUtils.class);

public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
return extractFullyQualifiedTableNamesWithMetadata(sqlQuery).getFullyQualifiedTableNames();
}

public static TableExtractionResult extractFullyQualifiedTableNamesWithMetadata(String sqlQuery) {
SqlBaseParser sqlBaseParser = getBaseParser(sqlQuery);
StatementContext statement = sqlBaseParser.statement();
SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
statement.accept(sparkSqlTableNameVisitor);
return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
SparkSqlTableNameVisitor visitor = new SparkSqlTableNameVisitor();
statement.accept(visitor);

// Remove duplicate table names
List<FullyQualifiedTableName> uniqueFullyQualifiedTableNames = new LinkedList<>();
for (FullyQualifiedTableName fullyQualifiedTableName : visitor.getFullyQualifiedTableNames()) {
if (!uniqueFullyQualifiedTableNames.contains(fullyQualifiedTableName)) {
uniqueFullyQualifiedTableNames.add(fullyQualifiedTableName);
}
}

return new TableExtractionResult(uniqueFullyQualifiedTableNames, visitor.isCreateTable());
}

public static IndexQueryDetails extractIndexDetails(String sqlQuery) {
Expand Down Expand Up @@ -90,7 +100,10 @@ public static SqlBaseParser getBaseParser(String sqlQuery) {

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();
@Getter
private final List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();

@Getter private boolean isCreateTable = false;

public SparkSqlTableNameVisitor() {}

Expand Down Expand Up @@ -130,6 +143,12 @@ public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) {
}
return super.visitCreateTableHeader(ctx);
}

@Override
public Void visitCreateTable(SqlBaseParser.CreateTableContext ctx) {
isCreateTable = true;
return super.visitCreateTable(ctx);
}
}

public static class FlintSQLIndexDetailsVisitor extends FlintSparkSqlExtensionsBaseVisitor<Void> {
Expand Down Expand Up @@ -380,4 +399,15 @@ public String removeUnwantedQuotes(String input) {
return input.replaceAll("^\"|\"$", "");
}
}

public static class TableExtractionResult {
@Getter private final List<FullyQualifiedTableName> fullyQualifiedTableNames;
@Getter private final boolean isCreateTableQuery;

public TableExtractionResult(
List<FullyQualifiedTableName> fullyQualifiedTableNames, boolean isCreateTableQuery) {
this.fullyQualifiedTableNames = fullyQualifiedTableNames;
this.isCreateTableQuery = isCreateTableQuery;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
import org.opensearch.sql.spark.flint.FlintIndexType;
import org.opensearch.sql.spark.utils.SQLQueryUtils.TableExtractionResult;

@ExtendWith(MockitoExtension.class)
public class SQLQueryUtilsTest {
Expand Down Expand Up @@ -444,6 +445,69 @@ void testRecoverIndex() {
assertEquals(IndexQueryActionType.RECOVER, indexDetails.getIndexQueryActionType());
}

@Test
void testExtractFullyQualifiedTableNamesWithMetadata() {
// Test CREATE TABLE queries
String createTableQuery =
"CREATE EXTERNAL TABLE\n"
+ "myS3.default.alb_logs\n"
+ "[ PARTITIONED BY (col_name [, … ] ) ]\n"
+ "[ ROW FORMAT DELIMITED row_format ]\n"
+ "STORED AS file_format\n"
+ "LOCATION { 's3://bucket/folder/' }";

TableExtractionResult result =
SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(createTableQuery);
assertTrue(result.isCreateTableQuery());
assertEquals(1, result.getFullyQualifiedTableNames().size());
assertFullyQualifiedTableName(
"myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0));

String createTableQuery2 =
"CREATE TABLE myS3.default.new_table (id INT, name STRING) USING PARQUET";
result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(createTableQuery2);
assertTrue(result.isCreateTableQuery());
assertEquals(1, result.getFullyQualifiedTableNames().size());
assertFullyQualifiedTableName(
"myS3", "default", "new_table", result.getFullyQualifiedTableNames().get(0));

// Test SELECT queries
String selectQuery = "SELECT * FROM myS3.default.alb_logs";
result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(selectQuery);
assertFalse(result.isCreateTableQuery());
assertEquals(1, result.getFullyQualifiedTableNames().size());
assertFullyQualifiedTableName(
"myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0));

// Test DROP TABLE queries
String dropTableQuery = "DROP TABLE myS3.default.alb_logs";
result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(dropTableQuery);
assertFalse(result.isCreateTableQuery());
assertEquals(1, result.getFullyQualifiedTableNames().size());
assertFullyQualifiedTableName(
"myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0));

// Test DESCRIBE TABLE queries
String describeTableQuery = "DESCRIBE TABLE myS3.default.alb_logs";
result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(describeTableQuery);
assertFalse(result.isCreateTableQuery());
assertEquals(1, result.getFullyQualifiedTableNames().size());
assertFullyQualifiedTableName(
"myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0));

// Test JOIN queries
String joinQuery =
"SELECT * FROM myS3.default.alb_logs JOIN myS3.default.http_logs ON alb_logs.id ="
+ " http_logs.id";
result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(joinQuery);
assertFalse(result.isCreateTableQuery());
assertEquals(2, result.getFullyQualifiedTableNames().size());
assertFullyQualifiedTableName(
"myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0));
assertFullyQualifiedTableName(
"myS3", "default", "http_logs", result.getFullyQualifiedTableNames().get(1));
}

@Getter
protected static class IndexQuery {
private String query;
Expand Down
Loading