Skip to content

Commit 11205f1

Browse files
feat(c/driver/postgresql): improve conversion of decimal to numeric (#3787)
The logic for converting Arrow Decimal type to the PostgreSQL has been refactored to fix the data not being inserted correctly when the scale is not a multiple of 4. Adds new test cases covering various scales and zero padding scenarios. Closes #3485. --------- Co-authored-by: David Li <[email protected]>
1 parent b51a5d1 commit 11205f1

File tree

6 files changed

+872
-109
lines changed

6 files changed

+872
-109
lines changed

c/driver/postgresql/copy/postgres_copy_writer_test.cc

Lines changed: 569 additions & 15 deletions
Large diffs are not rendered by default.

c/driver/postgresql/copy/writer.h

Lines changed: 226 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#pragma once
1919

20+
#include <algorithm>
2021
#include <charconv>
2122
#include <cinttypes>
2223
#include <limits>
@@ -224,82 +225,141 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
224225
PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale)
225226
: precision_{precision}, scale_{scale} {}
226227

228+
// PostgreSQL NUMERIC Binary Format:
229+
// ===================================
230+
// PostgreSQL stores NUMERIC values in a variable-length binary format:
231+
// - ndigits (int16): Number of base-10000 digits stored
232+
// - weight (int16): Position of the first digit group relative to decimal point
233+
// (weight can be negative for small fractional numbers)
234+
// - sign (int16): kNumericPos (0x0000) or kNumericNeg (0x4000)
235+
// - dscale (int16): Number of decimal digits after the decimal point (display scale)
236+
// - digits[]: Array of int16 values, each 0-9999 (base-10000 representation)
237+
//
238+
// Value calculation: sum(digits[i] * 10000^(weight - i)) * 10^(-dscale)
239+
//
240+
// Example 1: 12300 (from Arrow Decimal value=123, scale=-2)
241+
// - Logical representation: "12300"
242+
// - Grouped in base-10000: [1][2300]
243+
// - ndigits=2, weight=1, sign=0x0000, dscale=0, digits=[1, 2300]
244+
// - Calculation: 1*10000^1 + 2300*10000^0 = 10000 + 2300 = 12300
245+
//
246+
// Example 2: 123.45 (from Arrow Decimal value=12345, scale=2)
247+
// - Logical representation: "123.45"
248+
// - Integer part "123", fractional part "45"
249+
// - Grouped in base-10000: [123][4500] (fractional part right-padded)
250+
// - ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500]
251+
// - Calculation: 123*10000^0 + 4500*10000^(-1) = 123 + 0.45 = 123.45
252+
//
253+
// Example 3: 0.00123 (from Arrow Decimal value=123, scale=5)
254+
// - Logical representation: "0.00123"
255+
// - Integer part "0", fractional part "00123"
256+
// - Grouped in base-10000: [123] (leading zeros skipped via negative weight)
257+
// - ndigits=1, weight=-1, sign=0x0000, dscale=5, digits=[123]
258+
// - Calculation: 123*10000^(-1) * 10^0 = 0.0123, but dscale=5 means display as
259+
// 0.00123
260+
227261
ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
228262
struct ArrowDecimal decimal;
229263
ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
230264
ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);
231265

232266
const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : kNumericNeg;
233267

234-
// Number of decimal digits per Postgres digit
235-
constexpr int kDecDigits = 4;
236-
std::vector<int16_t> pg_digits;
237-
int16_t weight = -(scale_ / kDecDigits);
238-
int16_t dscale = scale_;
239-
bool seen_decimal = scale_ == 0;
240-
bool truncating_trailing_zeros = true;
241-
242-
char decimal_string[max_decimal_digits_ + 1];
243-
int digits_remaining = DecimalToString<bitwidth_>(&decimal, decimal_string);
244-
do {
245-
const int start_pos =
246-
digits_remaining < kDecDigits ? 0 : digits_remaining - kDecDigits;
247-
const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits;
248-
const std::string_view substr{decimal_string + start_pos, len};
249-
int16_t val{};
250-
std::from_chars(substr.data(), substr.data() + substr.size(), val);
251-
252-
if (val == 0) {
253-
if (!seen_decimal && truncating_trailing_zeros) {
254-
dscale -= kDecDigits;
255-
}
256-
} else {
257-
pg_digits.insert(pg_digits.begin(), val);
258-
if (!seen_decimal && truncating_trailing_zeros) {
259-
if (val % 1000 == 0) {
260-
dscale -= 3;
261-
} else if (val % 100 == 0) {
262-
dscale -= 2;
263-
} else if (val % 10 == 0) {
264-
dscale -= 1;
265-
}
266-
}
267-
truncating_trailing_zeros = false;
268-
}
269-
digits_remaining -= kDecDigits;
270-
if (digits_remaining <= 0) {
271-
break;
272-
}
273-
weight++;
274-
275-
if (start_pos <= static_cast<int>(std::strlen(decimal_string)) - scale_) {
276-
seen_decimal = true;
277-
}
278-
} while (true);
279-
280-
int16_t ndigits = pg_digits.size();
281-
int32_t field_size_bytes = sizeof(ndigits) + sizeof(weight) + sizeof(sign) +
268+
// Convert decimal to string and split into integer/fractional parts
269+
// Example transformation for Arrow Decimal(value=12345, scale=2) representing 123.45:
270+
// Input: decimal.value = 12345, scale_ = 2
271+
// After DecimalToString: raw_decimal_string = "12345", original_digits = 5
272+
// After SplitDecimalParts: parts.integer_part = "123"
273+
// parts.fractional_part = "45"
274+
// parts.effective_scale = 2
275+
char raw_decimal_string[max_decimal_digits_ + 1];
276+
int original_digits = DecimalToString<bitwidth_>(&decimal, raw_decimal_string);
277+
DecimalParts parts = SplitDecimalParts(raw_decimal_string, original_digits, scale_);
278+
279+
// Group into PostgreSQL base-10000 representation
280+
// After GroupIntegerDigits: int_digits = [123], weight = 0
281+
// (groups "123" right-to-left: "123" → 123, only 1 group so weight = 0)
282+
auto [int_digits, weight] = GroupIntegerDigits(parts.integer_part);
283+
284+
// After GroupFractionalDigits: frac_digits = [4500], final_weight = 0
285+
// (groups "45" left-to-right with right-padding: "45" → "4500" → 4500)
286+
auto [frac_digits, final_weight] =
287+
GroupFractionalDigits(parts.fractional_part, weight, !parts.integer_part.empty());
288+
289+
// Combine digit arrays
290+
// After combining: all_digits = [123, 4500]
291+
std::vector<int16_t> all_digits = int_digits;
292+
all_digits.insert(all_digits.end(), frac_digits.begin(), frac_digits.end());
293+
294+
// Calculate display scale by counting trailing zeros in the DECIMAL STRING
295+
// For our example: frac_part="45" has 0 trailing zeros, effective_scale=2
296+
// So dscale = 2 - 0 = 2 (2 fractional digits to display)
297+
int trailing_zeros = 0;
298+
for (int j = parts.fractional_part.length() - 1;
299+
j >= 0 && parts.fractional_part[j] == '0'; j--) {
300+
trailing_zeros++;
301+
}
302+
int16_t dscale =
303+
static_cast<int16_t>((std::max)(0, parts.effective_scale - trailing_zeros));
304+
305+
// Optimize: remove trailing zero digit groups from fractional part
306+
int n_int_digit_groups = int_digits.size();
307+
while (static_cast<int>(all_digits.size()) > n_int_digit_groups &&
308+
all_digits.back() == 0) {
309+
all_digits.pop_back();
310+
}
311+
312+
// Handle zero special case
313+
if (all_digits.empty()) {
314+
final_weight = 0;
315+
dscale = 0;
316+
} else if (static_cast<int>(all_digits.size()) <= n_int_digit_groups) {
317+
// All fractional digits were removed
318+
dscale = 0;
319+
}
320+
321+
if (dscale < 0) dscale = 0;
322+
323+
// Write PostgreSQL NUMERIC binary format to buffer
324+
// Final values for our example: ndigits = 2
325+
// final_weight = 0
326+
// sign = 0x0000
327+
// dscale = 2
328+
// digits = [123, 4500]
329+
// Binary output represents: 123 * 10000^0 + 4500 * 10000^(-1) = 123 + 0.45 = 123.45
330+
int16_t ndigits = all_digits.size();
331+
int32_t field_size_bytes = sizeof(ndigits) + sizeof(final_weight) + sizeof(sign) +
282332
sizeof(dscale) + ndigits * sizeof(int16_t);
283333

284334
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
285335
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
286-
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, weight, error));
336+
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, final_weight, error));
287337
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
288338
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));
289339

290-
const size_t pg_digit_bytes = sizeof(int16_t) * pg_digits.size();
340+
const size_t pg_digit_bytes = sizeof(int16_t) * all_digits.size();
291341
NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, pg_digit_bytes));
292-
for (auto pg_digit : pg_digits) {
342+
for (auto pg_digit : all_digits) {
293343
WriteUnsafe<int16_t>(buffer, pg_digit);
294344
}
295345

296346
return ADBC_STATUS_OK;
297347
}
298348

299349
private:
300-
// returns the length of the string
350+
// Helper struct for organizing data flow between functions
351+
struct DecimalParts {
352+
std::string integer_part; // e.g., "12300" or "123"
353+
std::string fractional_part; // e.g., "45" or "00123"
354+
int effective_scale; // Scale after handling negative values
355+
};
356+
357+
// Helper function implementations for decimal-to-PostgreSQL NUMERIC conversion
358+
359+
// Convert decimal to string (absolute value, no sign)
360+
// Returns the length of the string
301361
template <int32_t DEC_WIDTH>
302-
int DecimalToString(struct ArrowDecimal* decimal, char* out) {
362+
int DecimalToString(struct ArrowDecimal* decimal, char* out) const {
303363
constexpr size_t nwords = (DEC_WIDTH == 128) ? 2 : 4;
304364
uint8_t tmp[DEC_WIDTH / 8];
305365
ArrowDecimalGetBytes(decimal, tmp);
@@ -322,10 +382,9 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
322382
for (size_t i = 0; i < DEC_WIDTH; i++) {
323383
int carry;
324384

325-
carry = (buf[nwords - 1] >= 0x7FFFFFFFFFFFFFFF);
385+
carry = (buf[nwords - 1] > 0x7FFFFFFFFFFFFFFF);
326386
for (size_t j = nwords - 1; j > 0; j--) {
327-
buf[j] =
328-
((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j - 1] >= 0x7FFFFFFFFFFFFFFF);
387+
buf[j] = ((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j - 1] > 0x7FFFFFFFFFFFFFFF);
329388
}
330389
buf[0] = ((buf[0] << 1) & 0xFFFFFFFFFFFFFFFF);
331390

@@ -350,6 +409,117 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
350409
return ndigits;
351410
}
352411

412+
DecimalParts SplitDecimalParts(const char* decimal_digits, int digit_count,
413+
int scale) const {
414+
// Virtual zeros represent the logical zeros appended for negative scale
415+
// Example: value=123, scale=-2 → "123" with 2 virtual zeros = "12300"
416+
const int virtual_zeros = (scale < 0) ? -scale : 0;
417+
const int effective_scale = (scale < 0) ? 0 : scale;
418+
const int total_logical_digits = digit_count + virtual_zeros;
419+
420+
// Calculate split point
421+
const int n_int_digits = total_logical_digits > effective_scale
422+
? total_logical_digits - effective_scale
423+
: 0;
424+
const int n_frac_digits = total_logical_digits - n_int_digits;
425+
426+
DecimalParts parts;
427+
parts.effective_scale = effective_scale;
428+
429+
// Extract integer part
430+
if (n_int_digits > 0) {
431+
if (n_int_digits <= digit_count) {
432+
// Integer part is within the original digits
433+
parts.integer_part.assign(decimal_digits, n_int_digits);
434+
} else {
435+
// Integer part includes all original digits + virtual zeros
436+
parts.integer_part.assign(decimal_digits, digit_count);
437+
parts.integer_part.append(virtual_zeros, '0');
438+
}
439+
}
440+
441+
// Extract fractional part (only exists if scale > 0)
442+
if (n_int_digits == 0 && total_logical_digits < effective_scale) {
443+
// Small fractional: 0.00123 needs leading zeros
444+
parts.fractional_part.assign(effective_scale - total_logical_digits, '0');
445+
parts.fractional_part.append(decimal_digits, digit_count);
446+
} else if (n_frac_digits > 0 && n_int_digits < digit_count) {
447+
// Fractional part from remaining digits (virtual zeros don't appear in fractional
448+
// part)
449+
parts.fractional_part.assign(decimal_digits + n_int_digits,
450+
digit_count - n_int_digits);
451+
}
452+
453+
return parts;
454+
}
455+
456+
std::pair<std::vector<int16_t>, int16_t> GroupIntegerDigits(
457+
const std::string& int_part) const {
458+
constexpr int kDecDigits = 4;
459+
std::vector<int16_t> digits;
460+
461+
if (int_part.empty()) {
462+
return {digits, -1}; // weight = -1 for pure fractional numbers
463+
}
464+
465+
// Calculate weight: ceil(length / 4) - 1
466+
int16_t weight = (int_part.length() + kDecDigits - 1) / kDecDigits - 1;
467+
468+
// Group right-to-left in chunks of 4
469+
int i = int_part.length();
470+
while (i > 0) {
471+
int chunk_size = (std::min)(i, kDecDigits);
472+
std::string_view chunk =
473+
std::string_view(int_part).substr(i - chunk_size, chunk_size);
474+
475+
int16_t val{};
476+
std::from_chars(chunk.data(), chunk.data() + chunk.size(), val);
477+
478+
// Skip trailing zeros
479+
if (val != 0 || !digits.empty()) {
480+
digits.insert(digits.begin(), val);
481+
}
482+
i -= chunk_size;
483+
}
484+
485+
return {digits, weight};
486+
}
487+
488+
std::pair<std::vector<int16_t>, int16_t> GroupFractionalDigits(
489+
const std::string& frac_part, int16_t initial_weight, bool has_integer_part) const {
490+
constexpr int kDecDigits = 4;
491+
std::vector<int16_t> digits;
492+
int16_t weight = initial_weight;
493+
494+
if (frac_part.empty()) {
495+
return {digits, weight};
496+
}
497+
498+
bool skip_leading_zeros = !has_integer_part;
499+
500+
// Group left-to-right in chunks of 4, right-padding last chunk
501+
for (size_t i = 0; i < frac_part.length(); i += kDecDigits) {
502+
int chunk_size = (std::min)(kDecDigits, static_cast<int>(frac_part.length() - i));
503+
std::string chunk_str = frac_part.substr(i, chunk_size);
504+
505+
// Right-pad to 4 digits (e.g., "45" → "4500")
506+
chunk_str.resize(kDecDigits, '0');
507+
508+
int16_t val{};
509+
std::from_chars(chunk_str.data(), chunk_str.data() + chunk_str.size(), val);
510+
511+
if (skip_leading_zeros && val == 0) {
512+
// Skip leading zero groups in fractional part (e.g., 0.0012 → skip "0012")
513+
weight--;
514+
} else {
515+
digits.push_back(val);
516+
skip_leading_zeros = false;
517+
}
518+
}
519+
520+
return {digits, weight};
521+
}
522+
353523
static constexpr uint16_t kNumericPos = 0x0000;
354524
static constexpr uint16_t kNumericNeg = 0x4000;
355525
static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256;

c/driver/postgresql/validation/queries/ingest/decimal.toml

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
// part: expected_schema
19+
{
20+
"format": "+s",
21+
"children": [
22+
{
23+
"name": "idx",
24+
"format": "l",
25+
"flags": ["nullable"]
26+
},
27+
{
28+
"name": "value",
29+
"format": "u",
30+
"flags": ["nullable"],
31+
"metadata": {
32+
"ARROW:extension:name": "arrow.opaque",
33+
"ARROW:extension:metadata": "{\"type_name\": \"numeric\", \"vendor_name\": \"PostgreSQL\"}"
34+
}
35+
}
36+
]
37+
}
38+
39+
// part: expected
40+
41+
{"idx": 0, "value": "0"}
42+
{"idx": 1, "value": "123.45"}
43+
{"idx": 2, "value": "-123.45"}
44+
{"idx": 3, "value": "9999999.99"}
45+
{"idx": 4, "value": "-9999999.99"}

0 commit comments

Comments
 (0)