Skip to content

Commit 3016156

Browse files
committed
Add df date and tz type columns back into the same location after type conversion
1 parent 135f09c commit 3016156

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

src/duckdb_py/include/duckdb_python/pyresult.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct DuckDBPyResult {
6666

6767
PandasDataFrame FrameFromNumpy(bool date_as_object, const py::handle &o);
6868

69-
void ChangeToTZType(PandasDataFrame &df);
69+
void ConvertDateTimeTypes(PandasDataFrame &df, bool date_as_object) const;
7070
unique_ptr<DataChunk> FetchNext(QueryResult &result);
7171
unique_ptr<DataChunk> FetchNextRaw(QueryResult &result);
7272
unique_ptr<NumpyResultConversion> InitializeNumpyConversion(bool pandas = false);

src/duckdb_py/pyresult.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,13 @@ py::dict DuckDBPyResult::FetchNumpyInternal(bool stream, idx_t vectors_per_chunk
287287
return res;
288288
}
289289

290+
static void ReplaceDFColumn(PandasDataFrame &df, const char *col_name, idx_t idx, const py::handle &new_value) {
291+
df.attr("drop")("columns"_a = col_name, "inplace"_a = true);
292+
df.attr("insert")(idx, col_name, new_value, "allow_duplicates"_a = false);
293+
}
294+
290295
// TODO: unify these with an enum/flag to indicate which conversions to do
291-
void DuckDBPyResult::ChangeToTZType(PandasDataFrame &df) {
296+
void DuckDBPyResult::ConvertDateTimeTypes(PandasDataFrame &df, bool date_as_object) const {
292297
auto names = df.attr("columns").cast<vector<string>>();
293298

294299
for (idx_t i = 0; i < result->ColumnCount(); i++) {
@@ -297,8 +302,10 @@ void DuckDBPyResult::ChangeToTZType(PandasDataFrame &df) {
297302
auto utc_local = df[names[i].c_str()].attr("dt").attr("tz_localize")("UTC");
298303
auto new_value = utc_local.attr("dt").attr("tz_convert")(result->client_properties.time_zone);
299304
// We need to create the column anew because the exact dt changed to a new timezone
300-
df.attr("drop")("columns"_a = names[i].c_str(), "inplace"_a = true);
301-
df.attr("__setitem__")(names[i].c_str(), new_value);
305+
ReplaceDFColumn(df, names[i].c_str(), i, new_value);
306+
} else if (date_as_object && result->types[i] == LogicalType::DATE) {
307+
auto new_value = df[names[i].c_str()].attr("dt").attr("date");
308+
ReplaceDFColumn(df, names[i].c_str(), i, new_value);
302309
}
303310
}
304311
}
@@ -374,20 +381,11 @@ PandasDataFrame DuckDBPyResult::FrameFromNumpy(bool date_as_object, const py::ha
374381
}
375382

376383
PandasDataFrame df = py::cast<PandasDataFrame>(pandas.attr("DataFrame").attr("from_dict")(o));
377-
// Unfortunately we have to do a type change here for timezones since these types are not supported by numpy
378-
ChangeToTZType(df);
384+
// Convert TZ and (optionally) Date types
385+
ConvertDateTimeTypes(df, date_as_object);
379386

380387
auto names = df.attr("columns").cast<vector<string>>();
381388
D_ASSERT(result->ColumnCount() == names.size());
382-
if (date_as_object) {
383-
for (idx_t i = 0; i < result->ColumnCount(); i++) {
384-
if (result->types[i] == LogicalType::DATE) {
385-
auto new_value = df[names[i].c_str()].attr("dt").attr("date");
386-
df.attr("drop")("columns"_a = names[i].c_str(), "inplace"_a = true);
387-
df.attr("__setitem__")(names[i].c_str(), new_value);
388-
}
389-
}
390-
}
391389
return df;
392390
}
393391

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
import duckdb


class TestColumnOrder:
    def test_column_order(self, duckdb_cursor):
        # Regression test: columns whose types need post-conversion
        # (TIMESTAMPTZ, DATE) must stay at their original positions in the
        # fetched DataFrame instead of being dropped and re-appended.
        query = """
        CREATE OR REPLACE TABLE t1 AS (
            SELECT NULL AS col1,
            NULL::TIMESTAMPTZ AS timepoint,
            NULL::DATE AS date,
        );
        SELECT timepoint, date, col1 FROM t1;
        """
        frame = duckdb.execute(query).fetchdf()
        # Column order must match the SELECT list exactly.
        assert list(frame.columns) == ["timepoint", "date", "col1"]

0 commit comments

Comments
 (0)