Skip to content

Commit 899438b

Browse files
authored
Merge pull request #138 from AbrJA/main
Bug fixs for Athena
2 parents cc90b11 + be31c2e commit 899438b

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

ext/AWSExt.jl

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,35 @@ using AWS, HTTP, JSON3
66
__init__() = println("Extension was loaded!")
77

88

9-
109
function collect_athena(result)
1110
# Extract column names and types from the result set metadata
1211
column_names = [col["Label"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
1312
column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
14-
13+
1514
# Process data rows, starting from the second row to skip header information
1615
data_rows = result["ResultSet"]["Rows"]
1716
filtered_column_names = filter(c -> !isempty(c), column_names)
1817
num_columns = length(filtered_column_names)
19-
18+
2019
data_for_df = [
2120
[get(col, "VarCharValue", missing) for col in row["Data"]] for row in data_rows[2:end]
2221
]
23-
22+
2423
# Ensure each row has the correct number of elements
2524
adjusted_data_for_df = [
2625
length(row) == num_columns ? row : resize!(copy(row), num_columns) for row in data_for_df
2726
]
28-
27+
2928
# Pad rows with missing values if they are shorter than expected
3029
for row in adjusted_data_for_df
3130
if length(row) < num_columns
3231
append!(row, fill(missing, num_columns - length(row)))
3332
end
3433
end
35-
34+
3635
# Transpose the data to match DataFrame constructor requirements
3736
data_transposed = permutedims(hcat(adjusted_data_for_df...))
38-
37+
3938
# Create the DataFrame
4039
df = DataFrame(data_transposed, Symbol.(filtered_column_names))
4140
TidierDB.parse_athena_df(df, column_types)
@@ -45,13 +44,13 @@ end
4544

4645
@service Athena
4746

48-
function TidierDB.get_table_metadata(AWS_GLOBAL_CONFIG, table_name::String, athena_params)
47+
function TidierDB.get_table_metadata(AWS_GLOBAL_CONFIG, table_name::String; athena_params)
4948
schema, table = split(table_name, '.') # Ensure this correctly parses your input
5049
query = """SELECT * FROM $schema.$table limit 0;"""
5150
# println(query)
5251
# try
5352
exe_query = Athena.start_query_execution(query, athena_params; aws_config = AWS_GLOBAL_CONFIG)
54-
53+
5554
# Poll Athena to check if the query has completed
5655
status = "RUNNING"
5756
while status in ["RUNNING", "QUEUED"]
@@ -64,10 +63,10 @@ function TidierDB.get_table_metadata(AWS_GLOBAL_CONFIG, table_name::String, athe
6463
error("Query was cancelled.")
6564
end
6665
end
67-
66+
6867
# Fetch the results once the query completes
6968
result = Athena.get_query_results(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG)
70-
69+
7170
column_names = [col["Label"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
7271
column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
7372
df = DataFrame(name = column_names, type = column_types)
@@ -92,17 +91,17 @@ function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:athena})
9291
error("Query was cancelled.")
9392
end
9493
end
95-
result = Athena.get_query_results(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db)
96-
return collect_athena(result)
94+
dfs = []
95+
next = true
96+
params = sqlquery.athena_params
97+
while next
98+
result = Athena.get_query_results(exe_query["QueryExecutionId"], params; aws_config = sqlquery.db)
99+
next = haskey(result, "NextToken")
100+
params = Dict{String, Any}(mergewith(_merge, next ? Dict("NextToken" => result["NextToken"]) : Dict(), sqlquery.athena_params))
101+
push!(dfs, collect_athena(result))
102+
end
103+
return vcat(dfs...)
97104
end
98105

99106

100107
end
101-
102-
103-
104-
105-
106-
107-
108-

0 commit comments

Comments
 (0)