Skip to content
Open
338 changes: 217 additions & 121 deletions c/driver/postgresql/copy/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,155 +225,140 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
// Stores the Arrow decimal precision and scale for this column. Write() hands
// both to ArrowDecimalInit for every value, and uses scale_ to decide where the
// decimal point falls in the digit string.
PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale)
: precision_{precision}, scale_{scale} {}

// PostgreSQL NUMERIC Binary Format:
// ===================================
// PostgreSQL stores NUMERIC values in a variable-length binary format:
// - ndigits (int16): Number of base-10000 digits stored
// - weight (int16): Position of the first digit group relative to decimal point
// (weight can be negative for small fractional numbers)
// - sign (int16): kNumericPos (0x0000) or kNumericNeg (0x4000)
// - dscale (int16): Number of decimal digits after the decimal point (display scale)
// - digits[]: Array of int16 values, each 0-9999 (base-10000 representation)
//
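// Value calculation: sign * sum(digits[i] * 10000^(weight - i))
// (dscale does not scale the value; it only records how many decimal digits
// are displayed after the decimal point)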
//
// Example 1: 12300 (from Arrow Decimal value=123, scale=-2)
// - Logical representation: "12300"
// - Grouped in base-10000: [1][2300]
// - ndigits=2, weight=1, sign=0x0000, dscale=0, digits=[1, 2300]
// - Calculation: 1*10000^1 + 2300*10000^0 = 10000 + 2300 = 12300
//
// Example 2: 123.45 (from Arrow Decimal value=12345, scale=2)
// - Logical representation: "123.45"
// - Integer part "123", fractional part "45"
// - Grouped in base-10000: [123][4500] (fractional part right-padded)
// - ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500]
// - Calculation: 123*10000^0 + 4500*10000^(-1) = 123 + 0.45 = 123.45
//
// Example 3: 0.00123 (from Arrow Decimal value=123, scale=5)
// - Logical representation: "0.00123"
// - Integer part empty, fractional part "00123"
// - Grouped in base-10000: [0012][3000] (last group right-padded)
// - ndigits=2, weight=-1, sign=0x0000, dscale=5, digits=[12, 3000]
// - Calculation: 12*10000^(-1) + 3000*10000^(-2) = 0.0012 + 0.00003 = 0.00123

// Writes the decimal at `index` as one PostgreSQL binary-COPY NUMERIC field:
// an int32 field length in bytes, then int16 ndigits, weight, sign and dscale,
// then `ndigits` base-10000 digits (int16 each).
//
// Worked example, Arrow decimal value=12345, scale=2 (i.e. 123.45):
//   DecimalToString       -> "12345"
//   SplitDecimalParts     -> integer "123", fractional "45", effective_scale 2
//   GroupIntegerDigits    -> [123], weight 0
//   GroupFractionalDigits -> [4500] ("45" right-padded to "4500")
//   written: ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500]
//
// Returns ADBC_STATUS_OK, or the first error reported by a checked write or by
// ArrowBufferReserve.
ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
  struct ArrowDecimal decimal;
  ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
  ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);

  const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : kNumericNeg;

  // Render the unscaled digits, then split them at the decimal point.
  char raw_decimal_string[max_decimal_digits_ + 1];
  const int original_digits = DecimalToString<bitwidth_>(&decimal, raw_decimal_string);
  const DecimalParts parts =
      SplitDecimalParts(raw_decimal_string, original_digits, scale_);
  const bool has_integer_part = !parts.integer_part.empty();

  // Group both sides into base-10000 digits. With an integer part present the
  // fractional grouping leaves the weight untouched; without one it lowers the
  // weight once per leading all-zero group.
  auto [int_digits, weight] = GroupIntegerDigits(parts.integer_part);
  auto [frac_digits, final_weight] =
      GroupFractionalDigits(parts.fractional_part, weight, has_integer_part);

  // Trailing all-zero fractional groups carry no information.
  while (!frac_digits.empty() && frac_digits.back() == 0) {
    frac_digits.pop_back();
  }

  std::vector<int16_t> all_digits = int_digits;
  if (!frac_digits.empty()) {
    if (has_integer_part) {
      // GroupIntegerDigits omits trailing zero groups of the integer part.
      // The integer part always spans weight + 1 groups, so restore the
      // omitted zeros before appending; otherwise the fractional digits would
      // land in integer positions (10000.5 would be sent as 15000).
      all_digits.resize(static_cast<size_t>(weight) + 1, 0);
    }
    all_digits.insert(all_digits.end(), frac_digits.begin(), frac_digits.end());
  }

  // Display scale: fractional decimal digits minus trailing zeros of the
  // fractional string; 0 when no fractional digit group survives.
  int16_t dscale = 0;
  if (!frac_digits.empty()) {
    int trailing_zeros = 0;
    for (auto it = parts.fractional_part.rbegin();
         it != parts.fractional_part.rend() && *it == '0'; ++it) {
      ++trailing_zeros;
    }
    dscale = static_cast<int16_t>(std::max(0, parts.effective_scale - trailing_zeros));
  }

  // Zero has no digits at all; send it with weight 0.
  if (all_digits.empty()) {
    final_weight = 0;
  }

  const int16_t ndigits = static_cast<int16_t>(all_digits.size());
  const int32_t field_size_bytes = static_cast<int32_t>(
      sizeof(ndigits) + sizeof(final_weight) + sizeof(sign) + sizeof(dscale) +
      all_digits.size() * sizeof(int16_t));

  NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, final_weight, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
  NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));

  // Reserve once so the per-digit writes below can skip capacity checks.
  const size_t pg_digit_bytes = sizeof(int16_t) * all_digits.size();
  NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, pg_digit_bytes));
  for (const int16_t pg_digit : all_digits) {
    WriteUnsafe<int16_t>(buffer, pg_digit);
  }

  return ADBC_STATUS_OK;
}

private:
// returns the length of the string
// A decimal digit string split at the decimal point, as passed from
// SplitDecimalParts to the digit-grouping helpers.
struct DecimalParts {
  std::string integer_part;     // digits left of the point, e.g. "12300" or "123"
  std::string fractional_part;  // digits right of the point, e.g. "45" or "00123"
  // Scale after a negative scale has been folded into the integer part.
  // Initialized so a default-constructed value is never indeterminate.
  int effective_scale{0};
};

// Helper function implementations for decimal-to-PostgreSQL NUMERIC conversion

// Convert decimal to string (absolute value, no sign)
// Returns the length of the string
template <int32_t DEC_WIDTH>
int DecimalToString(struct ArrowDecimal* decimal, char* out) {
int DecimalToString(struct ArrowDecimal* decimal, char* out) const {
constexpr size_t nwords = (DEC_WIDTH == 128) ? 2 : 4;
uint8_t tmp[DEC_WIDTH / 8];
ArrowDecimalGetBytes(decimal, tmp);
Expand Down Expand Up @@ -423,6 +408,117 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
return ndigits;
}

// Splits the unscaled digit string of a decimal at its decimal point.
//
// decimal_digits: digits of the absolute unscaled value (digit_count chars;
//                 a terminator is not required)
// digit_count:    number of digits in decimal_digits
// scale:          Arrow decimal scale; may be negative
//
// Returns the integer digits, the fractional digits and the effective scale:
//   scale < 0            "123", -2 -> integer "12300", fractional "",      scale 0
//   digit_count > scale  "12345", 2 -> integer "123",  fractional "45",    scale 2
//   digit_count <= scale "123", 5   -> integer "",     fractional "00123", scale 5
//
// Static because it reads no instance state.
static DecimalParts SplitDecimalParts(const char* decimal_digits, int digit_count,
                                      int scale) {
  DecimalParts parts;

  if (scale < 0) {
    // A negative scale multiplies by 10^(-scale): every digit is an integer
    // digit and -scale zeros follow.
    parts.effective_scale = 0;
    parts.integer_part.assign(decimal_digits, static_cast<size_t>(digit_count));
    parts.integer_part.append(static_cast<size_t>(-scale), '0');
    return parts;
  }

  parts.effective_scale = scale;
  if (digit_count > scale) {
    // The last `scale` digits are fractional, the rest are integer digits.
    const int n_int_digits = digit_count - scale;
    parts.integer_part.assign(decimal_digits, static_cast<size_t>(n_int_digits));
    parts.fractional_part.assign(decimal_digits + n_int_digits,
                                 static_cast<size_t>(scale));
  } else {
    // Pure fraction: left-pad with zeros up to `scale` fractional digits.
    parts.fractional_part.assign(static_cast<size_t>(scale - digit_count), '0');
    parts.fractional_part.append(decimal_digits, static_cast<size_t>(digit_count));
  }
  return parts;
}

// Groups the integer digits of a decimal into PostgreSQL base-10000 digits.
//
// int_part: decimal digits left of the decimal point (may be empty)
//
// Returns {digits, weight}:
//   - weight is ceil(length / 4) - 1, or -1 when int_part is empty.
//   - digits are the 4-digit groups, aligned to the decimal point and ordered
//     most significant first, with trailing all-zero groups omitted
//     ("10000" -> digits [1], weight 1). A caller that appends fractional
//     groups after these must first re-pad to weight + 1 groups.
//
// Static because it reads no instance state.
static std::pair<std::vector<int16_t>, int16_t> GroupIntegerDigits(
    const std::string& int_part) {
  constexpr size_t kDecDigits = 4;
  std::vector<int16_t> digits;

  if (int_part.empty()) {
    return {std::move(digits), -1};  // weight = -1 for pure fractional numbers
  }

  const size_t n_groups = (int_part.size() + kDecDigits - 1) / kDecDigits;
  const int16_t weight = static_cast<int16_t>(n_groups - 1);
  digits.reserve(n_groups);

  // Walk left to right. The first group takes the leftover digits so every
  // later group is a full 4 digits ending on the decimal point. This avoids
  // inserting at the front of the vector for every group.
  size_t len = int_part.size() % kDecDigits;
  if (len == 0) {
    len = kDecDigits;
  }
  for (size_t pos = 0; pos < int_part.size(); pos += len, len = kDecDigits) {
    int16_t val{};
    std::from_chars(int_part.data() + pos, int_part.data() + pos + len, val);
    digits.push_back(val);
  }

  // Trailing zero groups are implied by weight and need not be stored.
  while (!digits.empty() && digits.back() == 0) {
    digits.pop_back();
  }

  return {std::move(digits), weight};
}

// Groups the fractional digits of a decimal into PostgreSQL base-10000 digits.
//
// frac_part:        decimal digits right of the decimal point (may be empty)
// initial_weight:   weight computed for the integer part
// has_integer_part: whether any integer digits precede the fraction
//
// Returns {digits, weight}. Groups are taken left to right, 4 digits at a
// time, and the last group is right-padded with zeros ("45" -> 4500). When
// there is no integer part, each leading all-zero group is dropped and the
// weight lowered by one instead ("000012" with weight -1 -> [1200], -2).
// With an integer part the weight comes back unchanged and zero groups are
// kept.
//
// Static because it reads no instance state.
static std::pair<std::vector<int16_t>, int16_t> GroupFractionalDigits(
    const std::string& frac_part, int16_t initial_weight, bool has_integer_part) {
  constexpr size_t kDecDigits = 4;
  std::vector<int16_t> digits;
  digits.reserve((frac_part.size() + kDecDigits - 1) / kDecDigits);
  int16_t weight = initial_weight;

  bool skip_leading_zeros = !has_integer_part;

  for (size_t pos = 0; pos < frac_part.size(); pos += kDecDigits) {
    // Pre-filled with '0' so a short final group comes out right-padded.
    char group[kDecDigits] = {'0', '0', '0', '0'};
    frac_part.copy(group, kDecDigits, pos);  // copies at most what remains

    int16_t val{};
    std::from_chars(group, group + kDecDigits, val);

    if (skip_leading_zeros && val == 0) {
      // Leading zero group of a pure fraction: shift the weight instead.
      --weight;
      continue;
    }
    skip_leading_zeros = false;
    digits.push_back(val);
  }

  return {std::move(digits), weight};
}

// Values of the `sign` field in the NUMERIC binary header: non-negative and
// negative respectively.
static constexpr uint16_t kNumericPos = 0x0000;
static constexpr uint16_t kNumericNeg = 0x4000;
// Bit width of the Arrow decimal being written: 128 when T is
// NANOARROW_TYPE_DECIMAL128, otherwise 256.
static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
# under the License.


skip = "decimal ingest code has a bug and fix has not been merged in yet. https://github.com/apache/arrow-adbc/pull/3787"
skip = "AssertionError: Field types do not match: assert Decimal128Type(decimal128(10, 2)) == OpaqueType(extension<arrow.opaque[storage_type=string, type_name=numeric, vendor_name=PostgreSQL]>)"
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
# under the License.


skip = "decimal ingest code has a bug and fix has not been merged in yet. https://github.com/apache/arrow-adbc/pull/3787"
skip = "AssertionError: Field types do not match: assert Decimal128Type(decimal128(10, 2)) == OpaqueType(extension<arrow.opaque[storage_type=string, type_name=numeric, vendor_name=PostgreSQL]>)"
Loading