Skip to content

Commit ed77af1

Browse files
authored
Merge pull request #151 from AbrJA/main
2 parents 8f6d7c4 + b1e41ab commit ed77af1

File tree

1 file changed

+37
-40
lines changed

1 file changed

+37
-40
lines changed

ext/AWSExt.jl

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
"""
    collect_athena(result, has_header = true)

Convert an Athena `GetQueryResults` response into a `DataFrame` and pass it
through `TidierDB.parse_athena_df` to coerce columns to their Athena-declared
types.

# Arguments
- `result`: parsed response from `Athena.get_query_results`; expected to hold
  `"ResultSet" => Dict("ResultSetMetadata" => Dict("ColumnInfo" => ...), "Rows" => ...)`.
- `has_header`: when `true` (default), the first row is the header row Athena
  echoes back with the data and is skipped.

# Returns
The result of `TidierDB.parse_athena_df(df, column_types)`.
"""
function collect_athena(result, has_header = true)
    # Extract column names and types from the result set metadata.
    column_metadata = result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
    column_names = [col["Label"] for col in column_metadata]
    column_types = [col["Type"] for col in column_metadata]
    num_columns = length(column_names)

    # Skip the header row that Athena returns as the first data row, if present.
    start = has_header ? 2 : 1
    data_rows = result["ResultSet"]["Rows"][start:end]

    # No data rows: return an empty but correctly-named/typed DataFrame.
    if isempty(data_rows)
        df = DataFrame([name => String[] for name in Symbol.(column_names)])
        return TidierDB.parse_athena_df(df, column_types)
    end

    # Extract cell values row by row. Athena omits the "VarCharValue" key for
    # NULL cells and may return rows with fewer entries than the column count,
    # so both cases are filled with `missing` (indexing row_data[col_idx]
    # unguarded would raise BoundsError on a short row).
    data_matrix = Matrix{Union{String, Missing}}(undef, length(data_rows), num_columns)
    for (row_idx, row) in enumerate(data_rows)
        row_data = row["Data"]
        for col_idx in 1:num_columns
            data_matrix[row_idx, col_idx] = col_idx <= length(row_data) ?
                get(row_data[col_idx], "VarCharValue", missing) : missing
        end
    end

    # Build the DataFrame and coerce column types per the Athena metadata.
    df = DataFrame(data_matrix, Symbol.(column_names))
    return TidierDB.parse_athena_df(df, column_types)
end
4439

4540
@service Athena
4641

4742
function TidierDB.get_table_metadata(AWS_GLOBAL_CONFIG, table_name::String; athena_params)
4843
schema, table = split(table_name, '.') # Ensure this correctly parses your input
4944
query = """SELECT * FROM $schema.$table limit 0;"""
50-
# println(query)
51-
# try
52-
exe_query = Athena.start_query_execution(query, athena_params; aws_config = AWS_GLOBAL_CONFIG)
45+
exe_query = Athena.start_query_execution(query, athena_params; aws_config = AWS_GLOBAL_CONFIG)
5346

54-
# Poll Athena to check if the query has completed
55-
status = "RUNNING"
56-
while status in ["RUNNING", "QUEUED"]
57-
sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API
58-
query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG)
59-
status = query_status["QueryExecution"]["Status"]["State"]
60-
if status == "FAILED"
61-
error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"])
62-
elseif status == "CANCELLED"
63-
error("Query was cancelled.")
64-
end
47+
# Poll Athena to check if the query has completed
48+
wait_time = 1.0
49+
status = "RUNNING"
50+
while status in ["RUNNING", "QUEUED"]
51+
sleep(round(wait_time)) # Wait for wait_time second before checking the status again to avoid flooding the API
52+
query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG)
53+
status = query_status["QueryExecution"]["Status"]["State"]
54+
if status == "FAILED"
55+
error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"])
56+
elseif status == "CANCELLED"
57+
error("Query was cancelled.")
6558
end
59+
wait_time = min(wait_time * 1.2, 10.0) # Exponential backoff, max wait time of 10 seconds
60+
end
6661

67-
# Fetch the results once the query completes
68-
result = Athena.get_query_results(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG)
62+
# Fetch the results once the query completes
63+
result = Athena.get_query_results(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG)
6964

7065
column_names = [col["Label"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
7166
column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
@@ -79,16 +74,18 @@ end
7974
function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:athena})
8075
final_query = TidierDB.finalize_query(sqlquery)
8176
exe_query = Athena.start_query_execution(final_query, sqlquery.athena_params; aws_config = sqlquery.db)
77+
wait_time = 1.0
8278
status = "RUNNING"
8379
while status in ["RUNNING", "QUEUED"]
84-
sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API
80+
sleep(round(wait_time)) # Wait for wait_time seconds before checking the status again to avoid flooding the API
8581
query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db)
8682
status = query_status["QueryExecution"]["Status"]["State"]
8783
if status == "FAILED"
8884
error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"])
8985
elseif status == "CANCELLED"
9086
error("Query was cancelled.")
9187
end
88+
wait_time = min(wait_time * 1.2, 10.0) # Exponential backoff, max wait time of 10 seconds
9289
end
9390
dfs = []
9491
next = true

0 commit comments

Comments
 (0)