Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions c/driver/postgresql/copy/postgres_copy_writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,180 @@ TEST_F(PostgresCopyTest, PostgresCopyWriteNumeric) {
}
}

// Regression fixture for the bug where 44.123456 with Decimal(10,6) became
// 4412.345500. Reference query:
// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (44.123456),
// (0.123456), (123.456789)) AS drvd(col)) TO STDOUT WITH (FORMAT binary);
//
// Bytes are grouped by binary COPY structure; each NUMERIC payload is
// ndigits, weight, sign, dscale (int16 each) followed by base-10000 digits.
static uint8_t kTestPgCopyNumericScale6[] = {
    // Header: 11-byte signature, 4-byte flags, 4-byte extension length
    0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00,
    0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00,
    // Row 1 (44.123456): field count 1, field length 14
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0e,
    //   ndigits=3, weight=0, sign=0x0000, dscale=6
    0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06,
    //   digits: 44, 1234, 5600
    0x00, 0x2c, 0x04, 0xd2, 0x15, 0xe0,
    // Row 2 (0.123456): field count 1, field length 12
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0c,
    //   ndigits=2, weight=-1, sign=0x0000, dscale=6
    0x00, 0x02, 0xff, 0xff, 0x00, 0x00, 0x00, 0x06,
    //   digits: 1234, 5600
    0x04, 0xd2, 0x15, 0xe0,
    // Row 3 (123.456789): field count 1, field length 14
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0e,
    //   ndigits=3, weight=0, sign=0x0000, dscale=6
    0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06,
    //   digits: 123, 4567, 8900
    0x00, 0x7b, 0x11, 0xd7, 0x22, 0xc4,
    // Trailer (0xffff)
    0xff, 0xff};

// Writes three Decimal128(38, 6) values through the COPY writer and checks
// the produced buffer byte-for-byte against kTestPgCopyNumericScale6
// (minus its two trailing bytes).
TEST_F(PostgresCopyTest, PostgresCopyWriteNumericScale6) {
  constexpr enum ArrowType kArrowType = NANOARROW_TYPE_DECIMAL128;
  constexpr int32_t kBitWidth = 128;
  constexpr int32_t kPrecision = 38;
  constexpr int32_t kScale = 6;

  adbc_validation::Handle<struct ArrowSchema> schema;
  adbc_validation::Handle<struct ArrowArray> array;
  struct ArrowError na_error;

  // Unscaled integers for 44.123456, 0.123456 and 123.456789 at scale 6.
  constexpr int64_t kUnscaled[] = {44123456, 123456, 123456789};
  struct ArrowDecimal decimals[3];
  for (int k = 0; k < 3; ++k) {
    ArrowDecimalInit(&decimals[k], kBitWidth, kPrecision, kScale);
    ArrowDecimalSetInt(&decimals[k], kUnscaled[k]);
  }
  const std::vector<std::optional<ArrowDecimal*>> column = {
      &decimals[0], &decimals[1], &decimals[2]};

  // Single-column struct schema: col DECIMAL128(38, 6)
  ArrowSchemaInit(&schema.value);
  ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), NANOARROW_OK);
  struct ArrowSchema* col_schema = schema.value.children[0];
  ASSERT_EQ(ArrowSchemaSetTypeDecimal(col_schema, kArrowType, kPrecision, kScale),
            NANOARROW_OK);
  ASSERT_EQ(ArrowSchemaSetName(col_schema, "col"), NANOARROW_OK);
  ASSERT_EQ(adbc_validation::MakeBatch<ArrowDecimal*>(&schema.value, &array.value,
                                                      &na_error, column),
            ADBC_STATUS_OK);

  PostgresCopyStreamWriteTester tester;
  ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK);
  ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

  const struct ArrowBuffer written = tester.WriteBuffer();

  // The last two bytes of the fixture are excluded from the comparison.
  constexpr size_t kExpectedSize = sizeof(kTestPgCopyNumericScale6) - 2;
  ASSERT_EQ(written.size_bytes, static_cast<int64_t>(kExpectedSize));
  for (size_t pos = 0; pos < kExpectedSize; ++pos) {
    ASSERT_EQ(written.data[pos], kTestPgCopyNumericScale6[pos])
        << " at position " << pos;
  }
}

// Fixture for scale=5 (scale % 4 == 1). Reference query:
// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (12.34567),
// (-9.87654), (0.00123)) AS drvd(col)) TO STDOUT WITH (FORMAT binary);
//
// Bytes are grouped by binary COPY structure; each NUMERIC payload is
// ndigits, weight, sign, dscale (int16 each) followed by base-10000 digits.
static uint8_t kTestPgCopyNumericScale5[] = {
    // Header: 11-byte signature, 4-byte flags, 4-byte extension length
    0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00,
    0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00,
    // Row 1 (12.34567): field count 1, field length 14
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0e,
    //   ndigits=3, weight=0, sign=0x0000, dscale=5
    0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
    //   digits: 12, 3456, 7000
    0x00, 0x0c, 0x0d, 0x80, 0x1b, 0x58,
    // Row 2 (-9.87654): field count 1, field length 14
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0e,
    //   ndigits=3, weight=0, sign=0x4000 (negative), dscale=5
    0x00, 0x03, 0x00, 0x00, 0x40, 0x00, 0x00, 0x05,
    //   digits: 9, 8765, 4000
    0x00, 0x09, 0x22, 0x3d, 0x0f, 0xa0,
    // Row 3 (0.00123): field count 1, field length 12
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0c,
    //   ndigits=2, weight=-1, sign=0x0000, dscale=5
    0x00, 0x02, 0xff, 0xff, 0x00, 0x00, 0x00, 0x05,
    //   digits: 12, 3000
    0x00, 0x0c, 0x0b, 0xb8,
    // Trailer (0xffff)
    0xff, 0xff};

// Writes three Decimal128(38, 5) values (one of them negative) through the
// COPY writer and checks the produced buffer byte-for-byte against
// kTestPgCopyNumericScale5 (minus its two trailing bytes).
TEST_F(PostgresCopyTest, PostgresCopyWriteNumericScale5) {
  constexpr enum ArrowType kArrowType = NANOARROW_TYPE_DECIMAL128;
  constexpr int32_t kBitWidth = 128;
  constexpr int32_t kPrecision = 38;
  constexpr int32_t kScale = 5;

  adbc_validation::Handle<struct ArrowSchema> schema;
  adbc_validation::Handle<struct ArrowArray> array;
  struct ArrowError na_error;

  // Unscaled integers for 12.34567, -9.87654 and 0.00123 at scale 5.
  constexpr int64_t kUnscaled[] = {1234567, -987654, 123};
  struct ArrowDecimal decimals[3];
  for (int k = 0; k < 3; ++k) {
    ArrowDecimalInit(&decimals[k], kBitWidth, kPrecision, kScale);
    ArrowDecimalSetInt(&decimals[k], kUnscaled[k]);
  }
  const std::vector<std::optional<ArrowDecimal*>> column = {
      &decimals[0], &decimals[1], &decimals[2]};

  // Single-column struct schema: col DECIMAL128(38, 5)
  ArrowSchemaInit(&schema.value);
  ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), NANOARROW_OK);
  struct ArrowSchema* col_schema = schema.value.children[0];
  ASSERT_EQ(ArrowSchemaSetTypeDecimal(col_schema, kArrowType, kPrecision, kScale),
            NANOARROW_OK);
  ASSERT_EQ(ArrowSchemaSetName(col_schema, "col"), NANOARROW_OK);
  ASSERT_EQ(adbc_validation::MakeBatch<ArrowDecimal*>(&schema.value, &array.value,
                                                      &na_error, column),
            ADBC_STATUS_OK);

  PostgresCopyStreamWriteTester tester;
  ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK);
  ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

  const struct ArrowBuffer written = tester.WriteBuffer();

  // The last two bytes of the fixture are excluded from the comparison.
  constexpr size_t kExpectedSize = sizeof(kTestPgCopyNumericScale5) - 2;
  ASSERT_EQ(written.size_bytes, static_cast<int64_t>(kExpectedSize));
  for (size_t pos = 0; pos < kExpectedSize; ++pos) {
    ASSERT_EQ(written.data[pos], kTestPgCopyNumericScale5[pos])
        << " at position " << pos;
  }
}

// Fixture for scale=7 (scale % 4 == 3). Reference query:
// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (5.1234567),
// (-123.456789), (0.0000001)) AS drvd(col)) TO STDOUT WITH (FORMAT binary);
//
// Bytes are grouped by binary COPY structure; each NUMERIC payload is
// ndigits, weight, sign, dscale (int16 each) followed by base-10000 digits.
static uint8_t kTestPgCopyNumericScale7[] = {
    // Header: 11-byte signature, 4-byte flags, 4-byte extension length
    0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00,
    0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00,
    // Row 1 (5.1234567): field count 1, field length 14
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0e,
    //   ndigits=3, weight=0, sign=0x0000, dscale=7
    0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07,
    //   digits: 5, 1234, 5670
    0x00, 0x05, 0x04, 0xd2, 0x16, 0x26,
    // Row 2 (-123.456789): field count 1, field length 14
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0e,
    //   ndigits=3, weight=0, sign=0x4000 (negative), dscale=6
    0x00, 0x03, 0x00, 0x00, 0x40, 0x00, 0x00, 0x06,
    //   digits: 123, 4567, 8900
    0x00, 0x7b, 0x11, 0xd7, 0x22, 0xc4,
    // Row 3 (0.0000001): field count 1, field length 10
    0x00, 0x01, 0x00, 0x00, 0x00, 0x0a,
    //   ndigits=1, weight=-2, sign=0x0000, dscale=7
    0x00, 0x01, 0xff, 0xfe, 0x00, 0x00, 0x00, 0x07,
    //   digits: 10
    0x00, 0x0a,
    // Trailer (0xffff)
    0xff, 0xff};

// Writes three Decimal128(38, 7) values through the COPY writer and checks
// the produced buffer byte-for-byte against kTestPgCopyNumericScale7
// (minus its two trailing bytes).
TEST_F(PostgresCopyTest, PostgresCopyWriteNumericScale7) {
  constexpr enum ArrowType kArrowType = NANOARROW_TYPE_DECIMAL128;
  constexpr int32_t kBitWidth = 128;
  constexpr int32_t kPrecision = 38;
  constexpr int32_t kScale = 7;

  adbc_validation::Handle<struct ArrowSchema> schema;
  adbc_validation::Handle<struct ArrowArray> array;
  struct ArrowError na_error;

  // Unscaled integers at scale 7:
  //   51234567    -> 5.1234567
  //   -1234567890 -> -123.4567890, i.e. -123.456789 with one trailing zero;
  //                  the expected bytes encode this value with dscale=6
  //   1           -> 0.0000001
  constexpr int64_t kUnscaled[] = {51234567, -1234567890, 1};
  struct ArrowDecimal decimals[3];
  for (int k = 0; k < 3; ++k) {
    ArrowDecimalInit(&decimals[k], kBitWidth, kPrecision, kScale);
    ArrowDecimalSetInt(&decimals[k], kUnscaled[k]);
  }
  const std::vector<std::optional<ArrowDecimal*>> column = {
      &decimals[0], &decimals[1], &decimals[2]};

  // Single-column struct schema: col DECIMAL128(38, 7)
  ArrowSchemaInit(&schema.value);
  ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), NANOARROW_OK);
  struct ArrowSchema* col_schema = schema.value.children[0];
  ASSERT_EQ(ArrowSchemaSetTypeDecimal(col_schema, kArrowType, kPrecision, kScale),
            NANOARROW_OK);
  ASSERT_EQ(ArrowSchemaSetName(col_schema, "col"), NANOARROW_OK);
  ASSERT_EQ(adbc_validation::MakeBatch<ArrowDecimal*>(&schema.value, &array.value,
                                                      &na_error, column),
            ADBC_STATUS_OK);

  PostgresCopyStreamWriteTester tester;
  ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK);
  ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

  const struct ArrowBuffer written = tester.WriteBuffer();

  // The last two bytes of the fixture are excluded from the comparison.
  constexpr size_t kExpectedSize = sizeof(kTestPgCopyNumericScale7) - 2;
  ASSERT_EQ(written.size_bytes, static_cast<int64_t>(kExpectedSize));
  for (size_t pos = 0; pos < kExpectedSize; ++pos) {
    ASSERT_EQ(written.data[pos], kTestPgCopyNumericScale7[pos])
        << " at position " << pos;
  }
}

// Parameter tuple for timestamp tests: an Arrow time unit, a C string, and a
// column of optional 64-bit integer values.
// NOTE(review): the const char* is presumably a timezone name and the
// optionals presumably model nulls -- confirm against the tests using it.
using TimestampTestParamType =
std::tuple<enum ArrowTimeUnit, const char*, std::vector<std::optional<int64_t>>>;

Expand Down
124 changes: 89 additions & 35 deletions c/driver/postgresql/copy/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <algorithm>
#include <charconv>
#include <cinttypes>
#include <limits>
Expand Down Expand Up @@ -234,48 +235,101 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
// Number of decimal digits per Postgres digit
constexpr int kDecDigits = 4;
std::vector<int16_t> pg_digits;
int16_t weight = -(scale_ / kDecDigits);
int16_t dscale = scale_;
bool seen_decimal = scale_ == 0;
bool truncating_trailing_zeros = true;
int16_t weight;
int16_t dscale;

char decimal_string[max_decimal_digits_ + 1];
int digits_remaining = DecimalToString<bitwidth_>(&decimal, decimal_string);
do {
const int start_pos =
digits_remaining < kDecDigits ? 0 : digits_remaining - kDecDigits;
const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits;
const std::string_view substr{decimal_string + start_pos, len};
int16_t val{};
std::from_chars(substr.data(), substr.data() + substr.size(), val);
int total_digits = DecimalToString<bitwidth_>(&decimal, decimal_string);

if (val == 0) {
if (!seen_decimal && truncating_trailing_zeros) {
dscale -= kDecDigits;
}
} else {
pg_digits.insert(pg_digits.begin(), val);
if (!seen_decimal && truncating_trailing_zeros) {
if (val % 1000 == 0) {
dscale -= 3;
} else if (val % 100 == 0) {
dscale -= 2;
} else if (val % 10 == 0) {
dscale -= 1;
}
const int n_int_digits = total_digits > scale_ ? total_digits - scale_ : 0;
int n_frac_digits = total_digits > n_int_digits ? total_digits - n_int_digits : 0;

std::string_view decimal_string_view(decimal_string, total_digits);
std::string_view int_part = decimal_string_view.substr(0, n_int_digits);

std::string frac_part_str;
if (n_int_digits == 0 && total_digits < scale_) {
frac_part_str.assign(scale_ - total_digits, '0');
frac_part_str.append(decimal_string, total_digits);
n_frac_digits = scale_;
} else {
frac_part_str.assign(decimal_string_view.substr(n_int_digits, n_frac_digits));
}
std::string_view frac_part(frac_part_str);

// Count trailing zeros in the fractional part to minimize dscale
int actual_trailing_zeros = 0;
for (int j = frac_part.length() - 1; j >= 0 && frac_part[j] == '0'; j--) {
actual_trailing_zeros++;
}

// Group integer part
int i = int_part.length();
std::vector<int16_t> int_digits;
int n_int_digit_groups = 0;
if (i > 0) {
// Calculate weight based on original integer length
weight = (i + kDecDigits - 1) / kDecDigits - 1;

while (i > 0) {
int chunk_size = std::min(i, kDecDigits);
std::string_view chunk = int_part.substr(i - chunk_size, chunk_size);
int16_t val{};
std::from_chars(chunk.data(), chunk.data() + chunk.size(), val);
// Skip trailing zeros in integer part (which appear first when processing
// right-to-left)
if (val != 0 || !int_digits.empty()) {
int_digits.insert(int_digits.begin(), val);
}
truncating_trailing_zeros = false;
i -= chunk_size;
}
digits_remaining -= kDecDigits;
if (digits_remaining <= 0) {
break;
}
weight++;
n_int_digit_groups = int_digits.size();
pg_digits.insert(pg_digits.end(), int_digits.begin(), int_digits.end());
} else {
weight = -1;
n_int_digit_groups = 0;
}

// Group fractional part
// Chunk in 4-digit groups, padding the LAST group on the right if needed
i = 0;
bool skip_leading_zeros = (n_int_digits == 0);

while (i < (int)frac_part.length()) {
int chunk_size = std::min((int)frac_part.length() - i, kDecDigits);
std::string chunk_str(frac_part.substr(i, chunk_size));

if (start_pos <= static_cast<int>(std::strlen(decimal_string)) - scale_) {
seen_decimal = true;
// Pad the last group on the RIGHT if it's less than 4 digits
chunk_str.resize(kDecDigits, '0');

int16_t val{};
std::from_chars(chunk_str.data(), chunk_str.data() + chunk_str.size(), val);

if (skip_leading_zeros && val == 0) {
weight--;
} else {
pg_digits.push_back(val);
skip_leading_zeros = false;
}
} while (true);
i += chunk_size;
}

// Calculate dscale by removing trailing zeros
dscale = scale_ - actual_trailing_zeros;

// Trim trailing full zero digit groups from fractional part
// (these zeros are already accounted for in actual_trailing_zeros)
while (static_cast<int64_t>(pg_digits.size()) > n_int_digit_groups &&
pg_digits.back() == 0) {
pg_digits.pop_back();
}

// If all fractional digits were removed, dscale should be 0
if (static_cast<int64_t>(pg_digits.size()) <= n_int_digit_groups) {
dscale = 0;
}

if (dscale < 0) dscale = 0;

int16_t ndigits = pg_digits.size();
int32_t field_size_bytes = sizeof(ndigits) + sizeof(weight) + sizeof(sign) +
Expand Down
Loading