Skip to content

Commit cfc3353

Browse files
abhinavmuk04 authored and meta-codesync[bot] committed
Add TRANSFORM_WITH_INDEX UDF (#15978)
Summary: Pull Request resolved: #15978 Add a new TRANSFORM_WITH_INDEX function that allows transformation of array elements with access to their index, enabling more flexible feature engineering. The existing TRANSFORM function only passes the element to the lambda, making it difficult to use indices in array transform operations. This new function passes both the element and its 1-based index to the lambda: ```sql TRANSFORM_WITH_INDEX(arr, (elem, index) -> ...) ``` ## Function Signature ``` transform_with_index(array(T), function(T, bigint, U)) -> array(U) ``` ## Examples ```sql SELECT transform_with_index(ARRAY [5, 6, 7], (x, i) -> x * i); -- [5, 12, 21] SELECT transform_with_index(ARRAY ['a', 'b', 'c'], (x, i) -> concat(x, cast(i as varchar))); -- ['a1', 'b2', 'c3'] SELECT transform_with_index(ARRAY [10, 20, 30], (x, i) -> i); -- [1, 2, 3] ``` ## Implementation Details - Uses 1-based indexing for Presto compatibility - Follows the same pattern as the existing TRANSFORM function - Added to fuzzer exclusion lists as this is a Velox-only function not available in Presto Reviewed By: zacw7 Differential Revision: D90478316 fbshipit-source-id: 2689eeeb3ed3f8e6df8d4950735e9e31eef82e81
1 parent b6e5220 commit cfc3353

File tree

7 files changed

+823
-0
lines changed

7 files changed

+823
-0
lines changed

velox/docs/functions/presto/array.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,18 @@ Array Functions
458458
SELECT transform(ARRAY ['x', 'abc', 'z'], x -> x || '0'); -- ['x0', 'abc0', 'z0']
459459
SELECT transform(ARRAY [ARRAY [1, NULL, 2], ARRAY[3, NULL]], a -> filter(a, x -> x IS NOT NULL)); -- [[1, 2], [3]]
460460

461+
.. function:: transform_with_index(array(T), function(T,bigint,U)) -> array(U)
462+
463+
Returns an array that is the result of applying ``function`` to each element of ``array``.
464+
The lambda function receives two arguments: the element and its 1-based index (a ``bigint``) within its own array.
465+
This is useful for transformations that need to know the position of each element::
466+
467+
SELECT transform_with_index(ARRAY [], (x, i) -> x + i); -- []
468+
SELECT transform_with_index(ARRAY [5, 6, 7], (x, i) -> x * i); -- [5, 12, 21]
469+
SELECT transform_with_index(ARRAY ['a', 'b', 'c'], (x, i) -> concat(x, cast(i as varchar))); -- ['a1', 'b2', 'c3']
470+
SELECT transform_with_index(ARRAY [10, 20, 30], (x, i) -> i); -- [1, 2, 3]
471+
SELECT transform_with_index(ARRAY [1, 2, 3], (x, i) -> if(i % 2 = 1, x, x * 2)); -- [1, 4, 3]
472+
461473
.. function:: trim_array(x, n) -> array
462474

463475
Remove n elements from the end of ``array``::

velox/expression/fuzzer/ExpressionFuzzerTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ std::unordered_set<std::string> skipFunctionsSOT = {
299299
// instances
300300
"array_subset", // Velox-only function, not available in Presto
301301
"map_values_in_range", // Velox-only function, not available in Presto
302+
"transform_with_index", // Velox-only function, not available in Presto
302303
"remap_keys", // Velox-only function, not available in Presto
303304
"map_intersect", // Velox-only function, not available in Presto
304305
"map_keys_overlap", // Velox-only function, not available in Presto

velox/functions/prestosql/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ velox_add_library(
5353
Subscript.cpp
5454
ToUtf8.cpp
5555
Transform.cpp
56+
TransformWithIndex.cpp
5657
TransformKeys.cpp
5758
TransformValues.cpp
5859
TypeOf.cpp
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "velox/expression/Expr.h"
17+
#include "velox/expression/VectorFunction.h"
18+
#include "velox/functions/lib/LambdaFunctionUtil.h"
19+
#include "velox/functions/lib/RowsTranslationUtil.h"
20+
#include "velox/vector/FunctionVector.h"
21+
22+
namespace facebook::velox::functions {
23+
namespace {
24+
25+
// transform_with_index(array(T), function(T, bigint, U)) -> array(U)
26+
//
27+
// Transforms each element of an array using the provided function.
28+
// The lambda function receives both the element and its 1-based index.
29+
class TransformWithIndexFunction : public exec::VectorFunction {
30+
public:
31+
void apply(
32+
const SelectivityVector& rows,
33+
std::vector<VectorPtr>& args,
34+
const TypePtr& outputType,
35+
exec::EvalCtx& context,
36+
VectorPtr& result) const override {
37+
VELOX_CHECK_EQ(args.size(), 2);
38+
39+
// Flatten input array.
40+
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
41+
auto& decodedArray = *arrayDecoder.get();
42+
43+
auto flatArray = flattenArray(rows, args[0], decodedArray);
44+
45+
auto newNumElements = flatArray->elements()->size();
46+
47+
// Create indices vector (1-based indexing for Presto compatibility)
48+
auto indices = createIndicesVector(flatArray, rows, context.pool());
49+
50+
std::vector<VectorPtr> lambdaArgs = {flatArray->elements(), indices};
51+
52+
SelectivityVector validRowsInReusedResult =
53+
toElementRows<ArrayVector>(newNumElements, rows, flatArray.get());
54+
55+
VectorPtr newElements;
56+
57+
auto elementToTopLevelRows = getElementToTopLevelRows(
58+
newNumElements, rows, flatArray.get(), context.pool());
59+
60+
// Loop over lambda functions and apply these to elements of the base array;
61+
// in most cases there will be only one function and the loop will run once
62+
auto it = args[1]->asUnchecked<FunctionVector>()->iterator(&rows);
63+
while (auto entry = it.next()) {
64+
auto elementRows = toElementRows<ArrayVector>(
65+
newNumElements, *entry.rows, flatArray.get());
66+
auto wrapCapture = toWrapCapture<ArrayVector>(
67+
newNumElements, entry.callable, *entry.rows, flatArray);
68+
69+
entry.callable->apply(
70+
elementRows,
71+
&validRowsInReusedResult,
72+
wrapCapture,
73+
&context,
74+
lambdaArgs,
75+
elementToTopLevelRows,
76+
&newElements);
77+
}
78+
79+
// Set nulls for rows not present in 'rows'.
80+
BufferPtr newNulls = addNullsForUnselectedRows(flatArray, rows);
81+
82+
VectorPtr localResult = std::make_shared<ArrayVector>(
83+
flatArray->pool(),
84+
outputType,
85+
std::move(newNulls),
86+
rows.end(),
87+
flatArray->offsets(),
88+
flatArray->sizes(),
89+
newElements);
90+
context.moveOrCopyResult(localResult, rows, result);
91+
}
92+
93+
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
94+
// array(T), function(T, bigint, U) -> array(U)
95+
return {exec::FunctionSignatureBuilder()
96+
.typeVariable("T")
97+
.typeVariable("U")
98+
.returnType("array(U)")
99+
.argumentType("array(T)")
100+
.argumentType("function(T, bigint, U)")
101+
.build()};
102+
}
103+
104+
private:
105+
// Creates a vector of 1-based indices for each element in the flattened
106+
// array. For example, if we have arrays [[a, b, c], [d, e]], the indices
107+
// will be [1, 2, 3, 1, 2] (1-based for each array).
108+
static VectorPtr createIndicesVector(
109+
const std::shared_ptr<ArrayVector>& flatArray,
110+
const SelectivityVector& rows,
111+
memory::MemoryPool* pool) {
112+
const auto numElements = flatArray->elements()->size();
113+
auto indicesVector =
114+
BaseVector::create<FlatVector<int64_t>>(BIGINT(), numElements, pool);
115+
auto* rawIndices = indicesVector->mutableRawValues();
116+
117+
const auto* rawOffsets = flatArray->rawOffsets();
118+
const auto* rawSizes = flatArray->rawSizes();
119+
const auto* rawNulls = flatArray->rawNulls();
120+
121+
rows.applyToSelected([&](vector_size_t row) {
122+
if (rawNulls && bits::isBitNull(rawNulls, row)) {
123+
return;
124+
}
125+
const auto offset = rawOffsets[row];
126+
const auto size = rawSizes[row];
127+
for (vector_size_t i = 0; i < size; ++i) {
128+
// Use 1-based indexing for Presto compatibility
129+
rawIndices[offset + i] = i + 1;
130+
}
131+
});
132+
133+
return indicesVector;
134+
}
135+
};
136+
} // namespace
137+
138+
/// Declares the vector function under the tag udf_transform_with_index, with
/// TransformWithIndexFunction supplying both the signatures and the
/// implementation instance.
139+
///
/// transform_with_index is null preserving for the array argument. However,
140+
/// an expr tree with a lambda depends on all named fields, including
141+
/// captures, so a null in a capture does not automatically make a null
/// result. Default null behavior is therefore switched off in the metadata.
142+
143+
VELOX_DECLARE_VECTOR_FUNCTION_WITH_METADATA(
144+
udf_transform_with_index,
145+
TransformWithIndexFunction::signatures(),
146+
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
147+
std::make_unique<TransformWithIndexFunction>());
148+
149+
} // namespace facebook::velox::functions

velox/functions/prestosql/registration/GeneralFunctionsRegistration.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ void registerGeneralFunctions(const std::string& prefix) {
9191
registerElementAtFunction(prefix + "element_at", true);
9292

9393
VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform");
94+
VELOX_REGISTER_VECTOR_FUNCTION(
95+
udf_transform_with_index, prefix + "transform_with_index");
9496
VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "reduce");
9597
registerReduceRewrites(prefix);
9698
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_filter, prefix + "filter");

velox/functions/prestosql/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ add_executable(
126126
TransformKeysTest.cpp
127127
TransformTest.cpp
128128
TransformValuesTest.cpp
129+
TransformWithIndexTest.cpp
129130
TrimFunctionsTest.cpp
130131
TypeOfTest.cpp
131132
TDigestCastTest.cpp

0 commit comments

Comments
 (0)