1717
1818#pragma once
1919
20+ #include < algorithm>
2021#include < charconv>
2122#include < cinttypes>
2223#include < limits>
@@ -224,82 +225,141 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
224225 PostgresCopyNumericFieldWriter (int32_t precision, int32_t scale)
225226 : precision_{precision}, scale_{scale} {}
226227
228+ // PostgreSQL NUMERIC Binary Format:
229+ // ===================================
230+ // PostgreSQL stores NUMERIC values in a variable-length binary format:
231+ // - ndigits (int16): Number of base-10000 digits stored
232+ // - weight (int16): Position of the first digit group relative to decimal point
233+ // (weight can be negative for small fractional numbers)
234+ // - sign (int16): kNumericPos (0x0000) or kNumericNeg (0x4000)
235+ // - dscale (int16): Number of decimal digits after the decimal point (display scale)
236+ // - digits[]: Array of int16 values, each 0-9999 (base-10000 representation)
237+ //
238+ // Value calculation: sum(digits[i] * 10000^(weight - i)) * 10^(-dscale)
239+ //
240+ // Example 1: 12300 (from Arrow Decimal value=123, scale=-2)
241+ // - Logical representation: "12300"
242+ // - Grouped in base-10000: [1][2300]
243+ // - ndigits=2, weight=1, sign=0x0000, dscale=0, digits=[1, 2300]
244+ // - Calculation: 1*10000^1 + 2300*10000^0 = 10000 + 2300 = 12300
245+ //
246+ // Example 2: 123.45 (from Arrow Decimal value=12345, scale=2)
247+ // - Logical representation: "123.45"
248+ // - Integer part "123", fractional part "45"
249+ // - Grouped in base-10000: [123][4500] (fractional part right-padded)
250+ // - ndigits=2, weight=0, sign=0x0000, dscale=2, digits=[123, 4500]
251+ // - Calculation: 123*10000^0 + 4500*10000^(-1) = 123 + 0.45 = 123.45
252+ //
253+ // Example 3: 0.00123 (from Arrow Decimal value=123, scale=5)
254+ // - Logical representation: "0.00123"
255+ // - Integer part "0", fractional part "00123"
256+ // - Grouped in base-10000: [123] (leading zeros skipped via negative weight)
257+ // - ndigits=1, weight=-1, sign=0x0000, dscale=5, digits=[123]
258+ // - Calculation: 123*10000^(-1) * 10^0 = 0.0123, but dscale=5 means display as
259+ // 0.00123
260+
227261 ArrowErrorCode Write (ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
228262 struct ArrowDecimal decimal;
229263 ArrowDecimalInit (&decimal, bitwidth_, precision_, scale_);
230264 ArrowArrayViewGetDecimalUnsafe (array_view_, index, &decimal);
231265
232266 const int16_t sign = ArrowDecimalSign (&decimal) > 0 ? kNumericPos : kNumericNeg ;
233267
234- // Number of decimal digits per Postgres digit
235- constexpr int kDecDigits = 4 ;
236- std::vector<int16_t > pg_digits;
237- int16_t weight = -(scale_ / kDecDigits );
238- int16_t dscale = scale_;
239- bool seen_decimal = scale_ == 0 ;
240- bool truncating_trailing_zeros = true ;
241-
242- char decimal_string[max_decimal_digits_ + 1 ];
243- int digits_remaining = DecimalToString<bitwidth_>(&decimal, decimal_string);
244- do {
245- const int start_pos =
246- digits_remaining < kDecDigits ? 0 : digits_remaining - kDecDigits ;
247- const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits ;
248- const std::string_view substr{decimal_string + start_pos, len};
249- int16_t val{};
250- std::from_chars (substr.data (), substr.data () + substr.size (), val);
251-
252- if (val == 0 ) {
253- if (!seen_decimal && truncating_trailing_zeros) {
254- dscale -= kDecDigits ;
255- }
256- } else {
257- pg_digits.insert (pg_digits.begin (), val);
258- if (!seen_decimal && truncating_trailing_zeros) {
259- if (val % 1000 == 0 ) {
260- dscale -= 3 ;
261- } else if (val % 100 == 0 ) {
262- dscale -= 2 ;
263- } else if (val % 10 == 0 ) {
264- dscale -= 1 ;
265- }
266- }
267- truncating_trailing_zeros = false ;
268- }
269- digits_remaining -= kDecDigits ;
270- if (digits_remaining <= 0 ) {
271- break ;
272- }
273- weight++;
274-
275- if (start_pos <= static_cast <int >(std::strlen (decimal_string)) - scale_) {
276- seen_decimal = true ;
277- }
278- } while (true );
279-
280- int16_t ndigits = pg_digits.size ();
281- int32_t field_size_bytes = sizeof (ndigits) + sizeof (weight) + sizeof (sign) +
268+ // Convert decimal to string and split into integer/fractional parts
269+ // Example transformation for Arrow Decimal(value=12345, scale=2) representing 123.45:
270+ // Input: decimal.value = 12345, scale_ = 2
271+ // After DecimalToString: raw_decimal_string = "12345", original_digits = 5
272+ // After SplitDecimalParts: parts.integer_part = "123"
273+ // parts.fractional_part = "45"
274+ // parts.effective_scale = 2
275+ char raw_decimal_string[max_decimal_digits_ + 1 ];
276+ int original_digits = DecimalToString<bitwidth_>(&decimal, raw_decimal_string);
277+ DecimalParts parts = SplitDecimalParts (raw_decimal_string, original_digits, scale_);
278+
279+ // Group into PostgreSQL base-10000 representation
280+ // After GroupIntegerDigits: int_digits = [123], weight = 0
281+ // (groups "123" right-to-left: "123" → 123, only 1 group so weight = 0)
282+ auto [int_digits, weight] = GroupIntegerDigits (parts.integer_part );
283+
284+ // After GroupFractionalDigits: frac_digits = [4500], final_weight = 0
285+ // (groups "45" left-to-right with right-padding: "45" → "4500" → 4500)
286+ auto [frac_digits, final_weight] =
287+ GroupFractionalDigits (parts.fractional_part , weight, !parts.integer_part .empty ());
288+
289+ // Combine digit arrays
290+ // After combining: all_digits = [123, 4500]
291+ std::vector<int16_t > all_digits = int_digits;
292+ all_digits.insert (all_digits.end (), frac_digits.begin (), frac_digits.end ());
293+
294+ // Calculate display scale by counting trailing zeros in the DECIMAL STRING
295+ // For our example: frac_part="45" has 0 trailing zeros, effective_scale=2
296+ // So dscale = 2 - 0 = 2 (2 fractional digits to display)
297+ int trailing_zeros = 0 ;
298+ for (int j = parts.fractional_part .length () - 1 ;
299+ j >= 0 && parts.fractional_part [j] == ' 0' ; j--) {
300+ trailing_zeros++;
301+ }
302+ int16_t dscale =
303+ static_cast <int16_t >((std::max)(0 , parts.effective_scale - trailing_zeros));
304+
305+ // Optimize: remove trailing zero digit groups from fractional part
306+ int n_int_digit_groups = int_digits.size ();
307+ while (static_cast <int >(all_digits.size ()) > n_int_digit_groups &&
308+ all_digits.back () == 0 ) {
309+ all_digits.pop_back ();
310+ }
311+
312+ // Handle zero special case
313+ if (all_digits.empty ()) {
314+ final_weight = 0 ;
315+ dscale = 0 ;
316+ } else if (static_cast <int >(all_digits.size ()) <= n_int_digit_groups) {
317+ // All fractional digits were removed
318+ dscale = 0 ;
319+ }
320+
321+ if (dscale < 0 ) dscale = 0 ;
322+
323+ // Write PostgreSQL NUMERIC binary format to buffer
324+ // Final values for our example: ndigits = 2
325+ // final_weight = 0
326+ // sign = 0x0000
327+ // dscale = 2
328+ // digits = [123, 4500]
329+ // Binary output represents: 123 * 10000^0 + 4500 * 10000^(-1) = 123 + 0.45 = 123.45
330+ int16_t ndigits = all_digits.size ();
331+ int32_t field_size_bytes = sizeof (ndigits) + sizeof (final_weight) + sizeof (sign) +
282332 sizeof (dscale) + ndigits * sizeof (int16_t );
283333
284334 NANOARROW_RETURN_NOT_OK (WriteChecked<int32_t >(buffer, field_size_bytes, error));
285335 NANOARROW_RETURN_NOT_OK (WriteChecked<int16_t >(buffer, ndigits, error));
286- NANOARROW_RETURN_NOT_OK (WriteChecked<int16_t >(buffer, weight , error));
336+ NANOARROW_RETURN_NOT_OK (WriteChecked<int16_t >(buffer, final_weight , error));
287337 NANOARROW_RETURN_NOT_OK (WriteChecked<int16_t >(buffer, sign, error));
288338 NANOARROW_RETURN_NOT_OK (WriteChecked<int16_t >(buffer, dscale, error));
289339
290- const size_t pg_digit_bytes = sizeof (int16_t ) * pg_digits .size ();
340+ const size_t pg_digit_bytes = sizeof (int16_t ) * all_digits .size ();
291341 NANOARROW_RETURN_NOT_OK (ArrowBufferReserve (buffer, pg_digit_bytes));
292- for (auto pg_digit : pg_digits ) {
342+ for (auto pg_digit : all_digits ) {
293343 WriteUnsafe<int16_t >(buffer, pg_digit);
294344 }
295345
296346 return ADBC_STATUS_OK;
297347 }
298348
299349 private:
300- // returns the length of the string
350+ // Helper struct for organizing data flow between functions
351+ struct DecimalParts {
352+ std::string integer_part; // e.g., "12300" or "123"
353+ std::string fractional_part; // e.g., "45" or "00123"
354+ int effective_scale; // Scale after handling negative values
355+ };
356+
357+ // Helper function implementations for decimal-to-PostgreSQL NUMERIC conversion
358+
359+ // Convert decimal to string (absolute value, no sign)
360+ // Returns the length of the string
301361 template <int32_t DEC_WIDTH>
302- int DecimalToString (struct ArrowDecimal * decimal, char * out) {
362+ int DecimalToString (struct ArrowDecimal * decimal, char * out) const {
303363 constexpr size_t nwords = (DEC_WIDTH == 128 ) ? 2 : 4 ;
304364 uint8_t tmp[DEC_WIDTH / 8 ];
305365 ArrowDecimalGetBytes (decimal, tmp);
@@ -322,10 +382,9 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
322382 for (size_t i = 0 ; i < DEC_WIDTH; i++) {
323383 int carry;
324384
325- carry = (buf[nwords - 1 ] >= 0x7FFFFFFFFFFFFFFF );
385+ carry = (buf[nwords - 1 ] > 0x7FFFFFFFFFFFFFFF );
326386 for (size_t j = nwords - 1 ; j > 0 ; j--) {
327- buf[j] =
328- ((buf[j] << 1 ) & 0xFFFFFFFFFFFFFFFF ) + (buf[j - 1 ] >= 0x7FFFFFFFFFFFFFFF );
387+ buf[j] = ((buf[j] << 1 ) & 0xFFFFFFFFFFFFFFFF ) + (buf[j - 1 ] > 0x7FFFFFFFFFFFFFFF );
329388 }
330389 buf[0 ] = ((buf[0 ] << 1 ) & 0xFFFFFFFFFFFFFFFF );
331390
@@ -350,6 +409,117 @@ class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
350409 return ndigits;
351410 }
352411
412+ DecimalParts SplitDecimalParts (const char * decimal_digits, int digit_count,
413+ int scale) const {
414+ // Virtual zeros represent the logical zeros appended for negative scale
415+ // Example: value=123, scale=-2 → "123" with 2 virtual zeros = "12300"
416+ const int virtual_zeros = (scale < 0 ) ? -scale : 0 ;
417+ const int effective_scale = (scale < 0 ) ? 0 : scale;
418+ const int total_logical_digits = digit_count + virtual_zeros;
419+
420+ // Calculate split point
421+ const int n_int_digits = total_logical_digits > effective_scale
422+ ? total_logical_digits - effective_scale
423+ : 0 ;
424+ const int n_frac_digits = total_logical_digits - n_int_digits;
425+
426+ DecimalParts parts;
427+ parts.effective_scale = effective_scale;
428+
429+ // Extract integer part
430+ if (n_int_digits > 0 ) {
431+ if (n_int_digits <= digit_count) {
432+ // Integer part is within the original digits
433+ parts.integer_part .assign (decimal_digits, n_int_digits);
434+ } else {
435+ // Integer part includes all original digits + virtual zeros
436+ parts.integer_part .assign (decimal_digits, digit_count);
437+ parts.integer_part .append (virtual_zeros, ' 0' );
438+ }
439+ }
440+
441+ // Extract fractional part (only exists if scale > 0)
442+ if (n_int_digits == 0 && total_logical_digits < effective_scale) {
443+ // Small fractional: 0.00123 needs leading zeros
444+ parts.fractional_part .assign (effective_scale - total_logical_digits, ' 0' );
445+ parts.fractional_part .append (decimal_digits, digit_count);
446+ } else if (n_frac_digits > 0 && n_int_digits < digit_count) {
447+ // Fractional part from remaining digits (virtual zeros don't appear in fractional
448+ // part)
449+ parts.fractional_part .assign (decimal_digits + n_int_digits,
450+ digit_count - n_int_digits);
451+ }
452+
453+ return parts;
454+ }
455+
456+ std::pair<std::vector<int16_t >, int16_t > GroupIntegerDigits (
457+ const std::string& int_part) const {
458+ constexpr int kDecDigits = 4 ;
459+ std::vector<int16_t > digits;
460+
461+ if (int_part.empty ()) {
462+ return {digits, -1 }; // weight = -1 for pure fractional numbers
463+ }
464+
465+ // Calculate weight: ceil(length / 4) - 1
466+ int16_t weight = (int_part.length () + kDecDigits - 1 ) / kDecDigits - 1 ;
467+
468+ // Group right-to-left in chunks of 4
469+ int i = int_part.length ();
470+ while (i > 0 ) {
471+ int chunk_size = (std::min)(i, kDecDigits );
472+ std::string_view chunk =
473+ std::string_view (int_part).substr (i - chunk_size, chunk_size);
474+
475+ int16_t val{};
476+ std::from_chars (chunk.data (), chunk.data () + chunk.size (), val);
477+
478+ // Skip trailing zeros
479+ if (val != 0 || !digits.empty ()) {
480+ digits.insert (digits.begin (), val);
481+ }
482+ i -= chunk_size;
483+ }
484+
485+ return {digits, weight};
486+ }
487+
488+ std::pair<std::vector<int16_t >, int16_t > GroupFractionalDigits (
489+ const std::string& frac_part, int16_t initial_weight, bool has_integer_part) const {
490+ constexpr int kDecDigits = 4 ;
491+ std::vector<int16_t > digits;
492+ int16_t weight = initial_weight;
493+
494+ if (frac_part.empty ()) {
495+ return {digits, weight};
496+ }
497+
498+ bool skip_leading_zeros = !has_integer_part;
499+
500+ // Group left-to-right in chunks of 4, right-padding last chunk
501+ for (size_t i = 0 ; i < frac_part.length (); i += kDecDigits ) {
502+ int chunk_size = (std::min)(kDecDigits , static_cast <int >(frac_part.length () - i));
503+ std::string chunk_str = frac_part.substr (i, chunk_size);
504+
505+ // Right-pad to 4 digits (e.g., "45" → "4500")
506+ chunk_str.resize (kDecDigits , ' 0' );
507+
508+ int16_t val{};
509+ std::from_chars (chunk_str.data (), chunk_str.data () + chunk_str.size (), val);
510+
511+ if (skip_leading_zeros && val == 0 ) {
512+ // Skip leading zero groups in fractional part (e.g., 0.0012 → skip "0012")
513+ weight--;
514+ } else {
515+ digits.push_back (val);
516+ skip_leading_zeros = false ;
517+ }
518+ }
519+
520+ return {digits, weight};
521+ }
522+
353523 static constexpr uint16_t kNumericPos = 0x0000 ;
354524 static constexpr uint16_t kNumericNeg = 0x4000 ;
355525 static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256 ;
0 commit comments