Skip to content

Commit 4cd1e67

Browse files
committed
Alternative approach with output_type as optional in cpp.
1 parent 89896d5 commit 4cd1e67

File tree

6 files changed

+19
-19
lines changed

6 files changed

+19
-19
lines changed

cpp/src/arrow/compute/api_vector.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
257257
constexpr char ListFlattenOptions::kTypeName[];
258258

259259
InversePermutationOptions::InversePermutationOptions(
260-
int64_t max_index, std::shared_ptr<DataType> output_type)
260+
int64_t max_index, std::optional<std::shared_ptr<DataType>> output_type)
261261
: FunctionOptions(internal::kInversePermutationOptionsType),
262262
max_index(max_index),
263263
output_type(std::move(output_type)) {}

cpp/src/arrow/compute/api_vector.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
299299
class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
300300
public:
301301
explicit InversePermutationOptions(int64_t max_index = -1,
302-
std::shared_ptr<DataType> output_type = NULLPTR);
302+
std::optional<std::shared_ptr<DataType>> output_type = std::nullopt);
303303
static constexpr const char kTypeName[] = "InversePermutationOptions";
304304
static InversePermutationOptions Defaults() { return InversePermutationOptions(); }
305305

@@ -308,11 +308,11 @@ class ARROW_EXPORT InversePermutationOptions : public FunctionOptions {
308308
/// of the input indices minus 1 and the length of the function's output will be the
309309
/// length of the input indices.
310310
int64_t max_index = -1;
311-
/// \brief The type of the output inverse permutation. If null, the output will be of
312-
/// the same type as the input indices, otherwise must be signed integer type. An
311+
/// \brief Optional type of the output inverse permutation. Default of `nullopt` will
312+
/// use the same type as the input indices, otherwise must be signed integer type. An
313313
/// invalid error will be reported if this type is not able to store the length of the
314314
/// input indices.
315-
std::shared_ptr<DataType> output_type = NULLPTR;
315+
std::optional<std::shared_ptr<DataType>> output_type;
316316
};
317317

318318
/// \brief Options for scatter function

cpp/src/arrow/compute/function_internal.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,19 @@ static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
347347
static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
348348
const std::shared_ptr<DataType>& value) {
349349
if (!value) {
350-
return std::make_shared<NullScalar>();
350+
return Status::Invalid("shared_ptr<DataType> is nullptr");
351351
}
352352
return MakeNullScalar(value);
353353
}
354354

355+
static inline Result<std::shared_ptr<Scalar>> GenericToScalar(
356+
const std::optional<std::shared_ptr<DataType>>& value) {
357+
if (!value.has_value()) {
358+
return std::make_shared<NullScalar>();
359+
}
360+
return GenericToScalar(value.value());
361+
}
362+
355363
static inline Result<std::shared_ptr<Scalar>> GenericToScalar(const TypeHolder& value) {
356364
return GenericToScalar(value.GetSharedPtr());
357365
}
@@ -448,9 +456,6 @@ static inline enable_if_same_result<T, SortKey> GenericFromScalar(
448456
template <typename T>
449457
static inline enable_if_same_result<T, std::shared_ptr<DataType>> GenericFromScalar(
450458
const std::shared_ptr<Scalar>& value) {
451-
if (value->type->id() == Type::NA) {
452-
return std::shared_ptr<NullType>();
453-
}
454459
return value->type;
455460
}
456461

cpp/src/arrow/compute/kernels/vector_swizzle.cc

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,8 @@ Result<TypeHolder> ResolveInversePermutationOutputType(
5151
DCHECK_EQ(input_types.size(), 1);
5252
DCHECK_NE(input_types[0], nullptr);
5353

54-
std::shared_ptr<DataType> output_type = InversePermutationState::Get(ctx).output_type;
55-
if (!output_type) {
56-
output_type = input_types[0].owned_type;
57-
}
54+
std::shared_ptr<DataType> output_type =
55+
InversePermutationState::Get(ctx).output_type.value_or(input_types[0].owned_type);
5856
if (!is_signed_integer(output_type->id())) {
5957
return Status::TypeError(
6058
"Output type of inverse_permutation must be signed integer, got " +
@@ -78,10 +76,7 @@ struct InversePermutationImpl {
7876

7977
// Apply default options semantics.
8078
int64_t output_length = options.max_index < 0 ? input_length : options.max_index + 1;
81-
std::shared_ptr<DataType> output_type = options.output_type;
82-
if (!output_type) {
83-
output_type = input_type;
84-
}
79+
std::shared_ptr<DataType> output_type = options.output_type.value_or(input_type);
8580

8681
ThisType impl(ctx, indices, input_length, output_length);
8782
RETURN_NOT_OK(VisitTypeInline(*output_type, &impl));

python/pyarrow/_compute.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1467,7 +1467,7 @@ class InversePermutationOptions(_InversePermutationOptions):
14671467
If negative, this value will be set to the length of the input indices
14681468
minus 1 and the length of the function’s output will be the length
14691469
of the input indices.
1470-
output_type : DataType, default None
1470+
output_type : Optional[DataType], default None
14711471
The type of the output inverse permutation.
14721472
If None, the output will be of the same type as the input indices, otherwise
14731473
must be signed integer type. An invalid error will be reported if this type

python/pyarrow/includes/libarrow.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2593,7 +2593,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
25932593
CInversePermutationOptions(int64_t max_index)
25942594
CInversePermutationOptions(int64_t max_index, shared_ptr[CDataType] output_type)
25952595
int64_t max_index
2596-
shared_ptr[CDataType] output_type
2596+
optional[shared_ptr[CDataType]] output_type
25972597

25982598
cdef cppclass CScatterOptions \
25992599
"arrow::compute::ScatterOptions"(CFunctionOptions):

0 commit comments

Comments
 (0)