Skip to content

Commit 33ac70d

Browse files
authored
minor: refactoring of some ScalarValue code (#19439)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

N/A

## Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

Some minor things I noticed in `ScalarValue` that I wanted to refactor.

## What changes are included in this PR?

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

Various refactors.

## Are these changes tested?

<!-- We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

Existing tests.

## Are there any user-facing changes?

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->

No.

<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 677c543 commit 33ac70d

File tree

1 file changed

+58
-101
lines changed
  • datafusion/common/src/scalar

1 file changed

+58
-101
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 58 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,13 @@ use arrow::compute::kernels::numeric::{
7878
add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
7979
};
8080
use arrow::datatypes::{
81-
ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType,
82-
DECIMAL128_MAX_PRECISION, DataType, Date32Type, Decimal32Type, Decimal64Type,
83-
Decimal128Type, Decimal256Type, DecimalType, Field, Float32Type, Int8Type, Int16Type,
84-
Int32Type, Int64Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano,
85-
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit,
86-
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
87-
TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type, UnionFields,
88-
UnionMode, i256, validate_decimal_precision_and_scale,
81+
ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, Date32Type,
82+
Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, Field,
83+
Float32Type, Int8Type, Int16Type, Int32Type, Int64Type, IntervalDayTime,
84+
IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit,
85+
IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
86+
TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type,
87+
UInt64Type, UnionFields, UnionMode, i256, validate_decimal_precision_and_scale,
8988
};
9089
use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string};
9190
use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array};
@@ -1152,13 +1151,8 @@ impl ScalarValue {
11521151

11531152
/// Create a decimal Scalar from value/precision and scale.
11541153
pub fn try_new_decimal128(value: i128, precision: u8, scale: i8) -> Result<Self> {
1155-
// make sure the precision and scale is valid
1156-
if precision <= DECIMAL128_MAX_PRECISION && scale.unsigned_abs() <= precision {
1157-
return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
1158-
}
1159-
_internal_err!(
1160-
"Can not new a decimal type ScalarValue for precision {precision} and scale {scale}"
1161-
)
1154+
Self::validate_decimal_or_internal_err::<Decimal128Type>(precision, scale)?;
1155+
Ok(ScalarValue::Decimal128(Some(value), precision, scale))
11621156
}
11631157

11641158
/// Create a Null instance of ScalarValue for this datatype
@@ -1250,15 +1244,15 @@ impl ScalarValue {
12501244
index_type.clone(),
12511245
Box::new(value_type.as_ref().try_into()?),
12521246
),
1253-
// `ScalaValue::List` contains single element `ListArray`.
1247+
// `ScalarValue::List` contains single element `ListArray`.
12541248
DataType::List(field_ref) => ScalarValue::List(Arc::new(
12551249
GenericListArray::new_null(Arc::clone(field_ref), 1),
12561250
)),
12571251
// `ScalarValue::LargeList` contains single element `LargeListArray`.
12581252
DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new(
12591253
GenericListArray::new_null(Arc::clone(field_ref), 1),
12601254
)),
1261-
// `ScalaValue::FixedSizeList` contains single element `FixedSizeList`.
1255+
// `ScalarValue::FixedSizeList` contains single element `FixedSizeList`.
12621256
DataType::FixedSizeList(field_ref, fixed_length) => {
12631257
ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null(
12641258
Arc::clone(field_ref),
@@ -1338,6 +1332,7 @@ impl ScalarValue {
13381332
/// Returns a [`ScalarValue`] representing PI
13391333
pub fn new_pi(datatype: &DataType) -> Result<ScalarValue> {
13401334
match datatype {
1335+
DataType::Float16 => Ok(ScalarValue::from(f16::PI)),
13411336
DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)),
13421337
DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)),
13431338
_ => _internal_err!("PI is not supported for data type: {}", datatype),
@@ -1347,6 +1342,8 @@ impl ScalarValue {
13471342
/// Returns a [`ScalarValue`] representing PI's upper bound
13481343
pub fn new_pi_upper(datatype: &DataType) -> Result<ScalarValue> {
13491344
match datatype {
1345+
// TODO: half::f16 doesn't seem to have equivalent
1346+
// https://github.com/apache/datafusion/issues/19465
13501347
DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)),
13511348
DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)),
13521349
_ => {
@@ -1358,6 +1355,8 @@ impl ScalarValue {
13581355
/// Returns a [`ScalarValue`] representing -PI's lower bound
13591356
pub fn new_negative_pi_lower(datatype: &DataType) -> Result<ScalarValue> {
13601357
match datatype {
1358+
// TODO: half::f16 doesn't seem to have equivalent
1359+
// https://github.com/apache/datafusion/issues/19465
13611360
DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)),
13621361
DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)),
13631362
_ => {
@@ -1369,6 +1368,8 @@ impl ScalarValue {
13691368
/// Returns a [`ScalarValue`] representing FRAC_PI_2's upper bound
13701369
pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result<ScalarValue> {
13711370
match datatype {
1371+
// TODO: half::f16 doesn't seem to have equivalent
1372+
// https://github.com/apache/datafusion/issues/19465
13721373
DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)),
13731374
DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)),
13741375
_ => {
@@ -1380,6 +1381,8 @@ impl ScalarValue {
13801381
/// Returns a [`ScalarValue`] representing -FRAC_PI_2's lower bound
13811382
pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result<ScalarValue> {
13821383
match datatype {
1384+
// TODO: half::f16 doesn't seem to have equivalent
1385+
// https://github.com/apache/datafusion/issues/19465
13831386
DataType::Float32 => {
13841387
Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32))
13851388
}
@@ -1395,6 +1398,7 @@ impl ScalarValue {
13951398
/// Returns a [`ScalarValue`] representing -PI
13961399
pub fn new_negative_pi(datatype: &DataType) -> Result<ScalarValue> {
13971400
match datatype {
1401+
DataType::Float16 => Ok(ScalarValue::from(-f16::PI)),
13981402
DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)),
13991403
DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)),
14001404
_ => _internal_err!("-PI is not supported for data type: {}", datatype),
@@ -1404,6 +1408,7 @@ impl ScalarValue {
14041408
/// Returns a [`ScalarValue`] representing PI/2
14051409
pub fn new_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
14061410
match datatype {
1411+
DataType::Float16 => Ok(ScalarValue::from(f16::FRAC_PI_2)),
14071412
DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)),
14081413
DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)),
14091414
_ => _internal_err!("PI/2 is not supported for data type: {}", datatype),
@@ -1413,6 +1418,7 @@ impl ScalarValue {
14131418
/// Returns a [`ScalarValue`] representing -PI/2
14141419
pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
14151420
match datatype {
1421+
DataType::Float16 => Ok(ScalarValue::from(-f16::FRAC_PI_2)),
14161422
DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)),
14171423
DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)),
14181424
_ => _internal_err!("-PI/2 is not supported for data type: {}", datatype),
@@ -1422,6 +1428,7 @@ impl ScalarValue {
14221428
/// Returns a [`ScalarValue`] representing infinity
14231429
pub fn new_infinity(datatype: &DataType) -> Result<ScalarValue> {
14241430
match datatype {
1431+
DataType::Float16 => Ok(ScalarValue::from(f16::INFINITY)),
14251432
DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)),
14261433
DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)),
14271434
_ => {
@@ -1433,6 +1440,7 @@ impl ScalarValue {
14331440
/// Returns a [`ScalarValue`] representing negative infinity
14341441
pub fn new_neg_infinity(datatype: &DataType) -> Result<ScalarValue> {
14351442
match datatype {
1443+
DataType::Float16 => Ok(ScalarValue::from(f16::NEG_INFINITY)),
14361444
DataType::Float32 => Ok(ScalarValue::from(f32::NEG_INFINITY)),
14371445
DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)),
14381446
_ => {
@@ -1456,7 +1464,7 @@ impl ScalarValue {
14561464
DataType::UInt16 => ScalarValue::UInt16(Some(0)),
14571465
DataType::UInt32 => ScalarValue::UInt32(Some(0)),
14581466
DataType::UInt64 => ScalarValue::UInt64(Some(0)),
1459-
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))),
1467+
DataType::Float16 => ScalarValue::Float16(Some(f16::ZERO)),
14601468
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
14611469
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
14621470
DataType::Decimal32(precision, scale) => {
@@ -1671,7 +1679,7 @@ impl ScalarValue {
16711679
DataType::UInt16 => ScalarValue::UInt16(Some(1)),
16721680
DataType::UInt32 => ScalarValue::UInt32(Some(1)),
16731681
DataType::UInt64 => ScalarValue::UInt64(Some(1)),
1674-
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))),
1682+
DataType::Float16 => ScalarValue::Float16(Some(f16::ONE)),
16751683
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
16761684
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
16771685
DataType::Decimal32(precision, scale) => {
@@ -1737,7 +1745,7 @@ impl ScalarValue {
17371745
DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)),
17381746
DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)),
17391747
DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)),
1740-
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))),
1748+
DataType::Float16 => ScalarValue::Float16(Some(f16::NEG_ONE)),
17411749
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
17421750
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
17431751
DataType::Decimal32(precision, scale) => {
@@ -1964,9 +1972,7 @@ impl ScalarValue {
19641972
| ScalarValue::Float16(None)
19651973
| ScalarValue::Float32(None)
19661974
| ScalarValue::Float64(None) => Ok(self.clone()),
1967-
ScalarValue::Float16(Some(v)) => {
1968-
Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32()))))
1969-
}
1975+
ScalarValue::Float16(Some(v)) => Ok(ScalarValue::Float16(Some(-v))),
19701976
ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
19711977
ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
19721978
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
@@ -2087,6 +2093,7 @@ impl ScalarValue {
20872093
let r = add_wrapping(&self.to_scalar()?, &other.borrow().to_scalar()?)?;
20882094
Self::try_from_array(r.as_ref(), 0)
20892095
}
2096+
20902097
/// Checked addition of `ScalarValue`
20912098
///
20922099
/// NB: operating on `ScalarValue` directly is not efficient, performance sensitive code
@@ -2719,71 +2726,6 @@ impl ScalarValue {
27192726
Ok(array)
27202727
}
27212728

2722-
fn build_decimal32_array(
2723-
value: Option<i32>,
2724-
precision: u8,
2725-
scale: i8,
2726-
size: usize,
2727-
) -> Result<Decimal32Array> {
2728-
Ok(match value {
2729-
Some(val) => Decimal32Array::from(vec![val; size])
2730-
.with_precision_and_scale(precision, scale)?,
2731-
None => {
2732-
let mut builder = Decimal32Array::builder(size)
2733-
.with_precision_and_scale(precision, scale)?;
2734-
builder.append_nulls(size);
2735-
builder.finish()
2736-
}
2737-
})
2738-
}
2739-
2740-
fn build_decimal64_array(
2741-
value: Option<i64>,
2742-
precision: u8,
2743-
scale: i8,
2744-
size: usize,
2745-
) -> Result<Decimal64Array> {
2746-
Ok(match value {
2747-
Some(val) => Decimal64Array::from(vec![val; size])
2748-
.with_precision_and_scale(precision, scale)?,
2749-
None => {
2750-
let mut builder = Decimal64Array::builder(size)
2751-
.with_precision_and_scale(precision, scale)?;
2752-
builder.append_nulls(size);
2753-
builder.finish()
2754-
}
2755-
})
2756-
}
2757-
2758-
fn build_decimal128_array(
2759-
value: Option<i128>,
2760-
precision: u8,
2761-
scale: i8,
2762-
size: usize,
2763-
) -> Result<Decimal128Array> {
2764-
Ok(match value {
2765-
Some(val) => Decimal128Array::from(vec![val; size])
2766-
.with_precision_and_scale(precision, scale)?,
2767-
None => {
2768-
let mut builder = Decimal128Array::builder(size)
2769-
.with_precision_and_scale(precision, scale)?;
2770-
builder.append_nulls(size);
2771-
builder.finish()
2772-
}
2773-
})
2774-
}
2775-
2776-
fn build_decimal256_array(
2777-
value: Option<i256>,
2778-
precision: u8,
2779-
scale: i8,
2780-
size: usize,
2781-
) -> Result<Decimal256Array> {
2782-
Ok(repeat_n(value, size)
2783-
.collect::<Decimal256Array>()
2784-
.with_precision_and_scale(precision, scale)?)
2785-
}
2786-
27872729
/// Converts `Vec<ScalarValue>` where each element has type corresponding to
27882730
/// `data_type`, to a single element [`ListArray`].
27892731
///
@@ -2939,18 +2881,35 @@ impl ScalarValue {
29392881
/// - a `Dictionary` that fails to be converted to a dictionary array of size
29402882
pub fn to_array_of_size(&self, size: usize) -> Result<ArrayRef> {
29412883
Ok(match self {
2942-
ScalarValue::Decimal32(e, precision, scale) => Arc::new(
2943-
ScalarValue::build_decimal32_array(*e, *precision, *scale, size)?,
2884+
ScalarValue::Decimal32(Some(e), precision, scale) => Arc::new(
2885+
Decimal32Array::from_value(*e, size)
2886+
.with_precision_and_scale(*precision, *scale)?,
29442887
),
2945-
ScalarValue::Decimal64(e, precision, scale) => Arc::new(
2946-
ScalarValue::build_decimal64_array(*e, *precision, *scale, size)?,
2888+
ScalarValue::Decimal32(None, precision, scale) => {
2889+
new_null_array(&DataType::Decimal32(*precision, *scale), size)
2890+
}
2891+
ScalarValue::Decimal64(Some(e), precision, scale) => Arc::new(
2892+
Decimal64Array::from_value(*e, size)
2893+
.with_precision_and_scale(*precision, *scale)?,
29472894
),
2948-
ScalarValue::Decimal128(e, precision, scale) => Arc::new(
2949-
ScalarValue::build_decimal128_array(*e, *precision, *scale, size)?,
2895+
ScalarValue::Decimal64(None, precision, scale) => {
2896+
new_null_array(&DataType::Decimal64(*precision, *scale), size)
2897+
}
2898+
ScalarValue::Decimal128(Some(e), precision, scale) => Arc::new(
2899+
Decimal128Array::from_value(*e, size)
2900+
.with_precision_and_scale(*precision, *scale)?,
29502901
),
2951-
ScalarValue::Decimal256(e, precision, scale) => Arc::new(
2952-
ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?,
2902+
ScalarValue::Decimal128(None, precision, scale) => {
2903+
new_null_array(&DataType::Decimal128(*precision, *scale), size)
2904+
}
2905+
ScalarValue::Decimal256(Some(e), precision, scale) => Arc::new(
2906+
Decimal256Array::from_value(*e, size)
2907+
.with_precision_and_scale(*precision, *scale)?,
29532908
),
2909+
ScalarValue::Decimal256(None, precision, scale) => {
2910+
new_null_array(&DataType::Decimal256(*precision, *scale), size)
2911+
}
2912+
29542913
ScalarValue::Boolean(e) => match e {
29552914
None => new_null_array(&DataType::Boolean, size),
29562915
Some(true) => {
@@ -3239,10 +3198,7 @@ impl ScalarValue {
32393198
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
32403199
Arc::new(ar)
32413200
}
3242-
None => {
3243-
let dt = self.data_type();
3244-
new_null_array(&dt, size)
3245-
}
3201+
None => new_null_array(&DataType::Union(fields.clone(), *mode), size),
32463202
},
32473203
ScalarValue::Dictionary(key_type, v) => {
32483204
// values array is one element long (the value)
@@ -5123,7 +5079,8 @@ mod tests {
51235079
use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer};
51245080
use arrow::compute::{is_null, kernels};
51255081
use arrow::datatypes::{
5126-
ArrowNumericType, DECIMAL256_MAX_PRECISION, Fields, Float64Type, TimeUnit,
5082+
ArrowNumericType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, Fields,
5083+
Float64Type, TimeUnit,
51275084
};
51285085
use arrow::error::ArrowError;
51295086
use arrow::util::pretty::pretty_format_columns;

0 commit comments

Comments
 (0)