Skip to content

Commit 41c92a6

Browse files
committed
add XYZTransform::Make
1 parent f820aa2 commit 41c92a6

File tree

4 files changed

+140
-36
lines changed

4 files changed

+140
-36
lines changed

src/iceberg/transform.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,34 +115,34 @@ Result<std::unique_ptr<TransformFunction>> Transform::Bind(
115115

116116
switch (transform_type_) {
117117
case TransformType::kIdentity:
118-
return std::make_unique<IdentityTransform>(source_type);
118+
return IdentityTransform::Make(source_type);
119119

120120
case TransformType::kBucket: {
121121
if (auto param = std::get_if<int32_t>(&param_)) {
122-
return std::make_unique<BucketTransform>(source_type, *param);
122+
return BucketTransform::Make(source_type, *param);
123123
}
124124
return InvalidArgument("Bucket requires int32 param, none found in transform '{}'",
125125
type_str);
126126
}
127127

128128
case TransformType::kTruncate: {
129129
if (auto param = std::get_if<int32_t>(&param_)) {
130-
return std::make_unique<TruncateTransform>(source_type, *param);
130+
return TruncateTransform::Make(source_type, *param);
131131
}
132132
return InvalidArgument(
133133
"Truncate requires int32 param, none found in transform '{}'", type_str);
134134
}
135135

136136
case TransformType::kYear:
137-
return std::make_unique<YearTransform>(source_type);
137+
return YearTransform::Make(source_type);
138138
case TransformType::kMonth:
139-
return std::make_unique<MonthTransform>(source_type);
139+
return MonthTransform::Make(source_type);
140140
case TransformType::kDay:
141-
return std::make_unique<DayTransform>(source_type);
141+
return DayTransform::Make(source_type);
142142
case TransformType::kHour:
143-
return std::make_unique<HourTransform>(source_type);
143+
return HourTransform::Make(source_type);
144144
case TransformType::kVoid:
145-
return std::make_unique<VoidTransform>(source_type);
145+
return VoidTransform::Make(source_type);
146146

147147
default:
148148
return NotSupported("Unsupported transform type: '{}'", type_str);

src/iceberg/transform_function.cc

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,16 @@ Result<ArrowArray> IdentityTransform::Transform(const ArrowArray& input) {
3131
}
3232

3333
Result<std::shared_ptr<Type>> IdentityTransform::ResultType() const {
34-
auto src_type = source_type();
35-
if (!src_type || !src_type->is_primitive()) {
34+
return source_type();
35+
}
36+
37+
Result<std::unique_ptr<TransformFunction>> IdentityTransform::Make(
38+
std::shared_ptr<Type> const& source_type) {
39+
if (!source_type || !source_type->is_primitive()) {
3640
return NotSupported("{} is not a valid input type for identity transform",
37-
src_type ? src_type->ToString() : "null");
41+
source_type ? source_type->ToString() : "null");
3842
}
39-
return src_type;
43+
return std::make_unique<IdentityTransform>(source_type);
4044
}
4145

4246
BucketTransform::BucketTransform(std::shared_ptr<Type> const& source_type,
@@ -49,9 +53,6 @@ Result<ArrowArray> BucketTransform::Transform(const ArrowArray& input) {
4953

5054
Result<std::shared_ptr<Type>> BucketTransform::ResultType() const {
5155
auto src_type = source_type();
52-
if (!src_type) {
53-
return NotSupported("null is not a valid input type for bucket transform");
54-
}
5556
switch (src_type->type_id()) {
5657
case TypeId::kInt:
5758
case TypeId::kLong:
@@ -71,6 +72,14 @@ Result<std::shared_ptr<Type>> BucketTransform::ResultType() const {
7172
}
7273
}
7374

75+
Result<std::unique_ptr<TransformFunction>> BucketTransform::Make(
76+
std::shared_ptr<Type> const& source_type, int32_t num_buckets) {
77+
if (!source_type) {
78+
return NotSupported("null is not a valid input type for bucket transform");
79+
}
80+
return std::make_unique<BucketTransform>(source_type, num_buckets);
81+
}
82+
7483
TruncateTransform::TruncateTransform(std::shared_ptr<Type> const& source_type,
7584
int32_t width)
7685
: TransformFunction(TransformType::kTruncate, source_type), width_(width) {}
@@ -81,9 +90,6 @@ Result<ArrowArray> TruncateTransform::Transform(const ArrowArray& input) {
8190

8291
Result<std::shared_ptr<Type>> TruncateTransform::ResultType() const {
8392
auto src_type = source_type();
84-
if (!src_type) {
85-
return NotSupported("null is not a valid input type for truncate transform");
86-
}
8793
switch (src_type->type_id()) {
8894
case TypeId::kInt:
8995
case TypeId::kLong:
@@ -97,6 +103,14 @@ Result<std::shared_ptr<Type>> TruncateTransform::ResultType() const {
97103
}
98104
}
99105

106+
Result<std::unique_ptr<TransformFunction>> TruncateTransform::Make(
107+
std::shared_ptr<Type> const& source_type, int32_t width) {
108+
if (!source_type) {
109+
return NotSupported("null is not a valid input type for truncate transform");
110+
}
111+
return std::make_unique<TruncateTransform>(source_type, width);
112+
}
113+
100114
YearTransform::YearTransform(std::shared_ptr<Type> const& source_type)
101115
: TransformFunction(TransformType::kTruncate, source_type) {}
102116

@@ -106,9 +120,6 @@ Result<ArrowArray> YearTransform::Transform(const ArrowArray& input) {
106120

107121
Result<std::shared_ptr<Type>> YearTransform::ResultType() const {
108122
auto src_type = source_type();
109-
if (!src_type) {
110-
return NotSupported("null is not a valid input type for year transform");
111-
}
112123
switch (src_type->type_id()) {
113124
case TypeId::kDate:
114125
case TypeId::kTimestamp:
@@ -120,6 +131,14 @@ Result<std::shared_ptr<Type>> YearTransform::ResultType() const {
120131
}
121132
}
122133

134+
Result<std::unique_ptr<TransformFunction>> YearTransform::Make(
135+
std::shared_ptr<Type> const& source_type) {
136+
if (!source_type) {
137+
return NotSupported("null is not a valid input type for year transform");
138+
}
139+
return std::make_unique<YearTransform>(source_type);
140+
}
141+
123142
MonthTransform::MonthTransform(std::shared_ptr<Type> const& source_type)
124143
: TransformFunction(TransformType::kMonth, source_type) {}
125144

@@ -129,9 +148,6 @@ Result<ArrowArray> MonthTransform::Transform(const ArrowArray& input) {
129148

130149
Result<std::shared_ptr<Type>> MonthTransform::ResultType() const {
131150
auto src_type = source_type();
132-
if (!src_type) {
133-
return NotSupported("null is not a valid input type for month transform");
134-
}
135151
switch (src_type->type_id()) {
136152
case TypeId::kDate:
137153
case TypeId::kTimestamp:
@@ -143,6 +159,14 @@ Result<std::shared_ptr<Type>> MonthTransform::ResultType() const {
143159
}
144160
}
145161

162+
Result<std::unique_ptr<TransformFunction>> MonthTransform::Make(
163+
std::shared_ptr<Type> const& source_type) {
164+
if (!source_type) {
165+
return NotSupported("null is not a valid input type for month transform");
166+
}
167+
return std::make_unique<MonthTransform>(source_type);
168+
}
169+
146170
DayTransform::DayTransform(std::shared_ptr<Type> const& source_type)
147171
: TransformFunction(TransformType::kDay, source_type) {}
148172

@@ -152,9 +176,6 @@ Result<ArrowArray> DayTransform::Transform(const ArrowArray& input) {
152176

153177
Result<std::shared_ptr<Type>> DayTransform::ResultType() const {
154178
auto src_type = source_type();
155-
if (!src_type) {
156-
return NotSupported("null is not a valid input type for day transform");
157-
}
158179
switch (src_type->type_id()) {
159180
case TypeId::kDate:
160181
case TypeId::kTimestamp:
@@ -166,6 +187,14 @@ Result<std::shared_ptr<Type>> DayTransform::ResultType() const {
166187
}
167188
}
168189

190+
Result<std::unique_ptr<TransformFunction>> DayTransform::Make(
191+
std::shared_ptr<Type> const& source_type) {
192+
if (!source_type) {
193+
return NotSupported("null is not a valid input type for day transform");
194+
}
195+
return std::make_unique<DayTransform>(source_type);
196+
}
197+
169198
HourTransform::HourTransform(std::shared_ptr<Type> const& source_type)
170199
: TransformFunction(TransformType::kHour, source_type) {}
171200

@@ -175,9 +204,6 @@ Result<ArrowArray> HourTransform::Transform(const ArrowArray& input) {
175204

176205
Result<std::shared_ptr<Type>> HourTransform::ResultType() const {
177206
auto src_type = source_type();
178-
if (!src_type) {
179-
return NotSupported("null is not a valid input type for hour transform");
180-
}
181207
switch (src_type->type_id()) {
182208
case TypeId::kTimestamp:
183209
case TypeId::kTimestampTz:
@@ -188,19 +214,29 @@ Result<std::shared_ptr<Type>> HourTransform::ResultType() const {
188214
}
189215
}
190216

217+
Result<std::unique_ptr<TransformFunction>> HourTransform::Make(
218+
std::shared_ptr<Type> const& source_type) {
219+
if (!source_type) {
220+
return NotSupported("null is not a valid input type for hour transform");
221+
}
222+
return std::make_unique<HourTransform>(source_type);
223+
}
224+
191225
VoidTransform::VoidTransform(std::shared_ptr<Type> const& source_type)
192226
: TransformFunction(TransformType::kVoid, source_type) {}
193227

194228
Result<ArrowArray> VoidTransform::Transform(const ArrowArray& input) {
195229
return NotImplemented("VoidTransform::Transform");
196230
}
197231

198-
Result<std::shared_ptr<Type>> VoidTransform::ResultType() const {
199-
auto src_type = source_type();
200-
if (!src_type) {
232+
Result<std::shared_ptr<Type>> VoidTransform::ResultType() const { return source_type(); }
233+
234+
Result<std::unique_ptr<TransformFunction>> VoidTransform::Make(
235+
std::shared_ptr<Type> const& source_type) {
236+
if (!source_type) {
201237
return NotSupported("null is not a valid input type for void transform");
202238
}
203-
return src_type;
239+
return std::make_unique<VoidTransform>(source_type);
204240
}
205241

206242
} // namespace iceberg

src/iceberg/transform_function.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ class IdentityTransform : public TransformFunction {
3535

3636
/// \brief Returns the same type as the source type if it is valid.
3737
Result<std::shared_ptr<Type>> ResultType() const override;
38+
39+
/// \brief Create an IdentityTransform.
40+
/// \param source_type Type of the input data.
41+
/// \return A Result containing the IdentityTransform or an error.
42+
static Result<std::unique_ptr<TransformFunction>> Make(
43+
std::shared_ptr<Type> const& source_type);
3844
};
3945

4046
/// \brief Bucket transform that hashes input values into N buckets.
@@ -50,6 +56,13 @@ class BucketTransform : public TransformFunction {
5056
/// \brief Returns INT32 as the output type.
5157
Result<std::shared_ptr<Type>> ResultType() const override;
5258

59+
/// \brief Create a BucketTransform.
60+
/// \param source_type Type of the input data.
61+
/// \param num_buckets Number of buckets to hash into.
62+
/// \return A Result containing the BucketTransform or an error.
63+
static Result<std::unique_ptr<TransformFunction>> Make(
64+
std::shared_ptr<Type> const& source_type, int32_t num_buckets);
65+
5366
private:
5467
int32_t num_buckets_;
5568
};
@@ -67,6 +80,13 @@ class TruncateTransform : public TransformFunction {
6780
/// \brief Returns the same type as source_type.
6881
Result<std::shared_ptr<Type>> ResultType() const override;
6982

83+
/// \brief Create a TruncateTransform.
84+
/// \param source_type Type of the input data.
85+
/// \param width The width to truncate to.
86+
/// \return A Result containing the TruncateTransform or an error.
87+
static Result<std::unique_ptr<TransformFunction>> Make(
88+
std::shared_ptr<Type> const& source_type, int32_t width);
89+
7090
private:
7191
int32_t width_;
7292
};
@@ -82,6 +102,12 @@ class YearTransform : public TransformFunction {
82102

83103
/// \brief Returns INT32 as the output type.
84104
Result<std::shared_ptr<Type>> ResultType() const override;
105+
106+
/// \brief Create a YearTransform.
107+
/// \param source_type Type of the input data.
108+
/// \return A Result containing the YearTransform or an error.
109+
static Result<std::unique_ptr<TransformFunction>> Make(
110+
std::shared_ptr<Type> const& source_type);
85111
};
86112

87113
/// \brief Month transform that extracts the month component from timestamp inputs.
@@ -95,6 +121,12 @@ class MonthTransform : public TransformFunction {
95121

96122
/// \brief Returns INT32 as the output type.
97123
Result<std::shared_ptr<Type>> ResultType() const override;
124+
125+
/// \brief Create a MonthTransform.
126+
/// \param source_type Type of the input data.
127+
/// \return A Result containing the MonthTransform or an error.
128+
static Result<std::unique_ptr<TransformFunction>> Make(
129+
std::shared_ptr<Type> const& source_type);
98130
};
99131

100132
/// \brief Day transform that extracts the day of the month from timestamp inputs.
@@ -108,6 +140,12 @@ class DayTransform : public TransformFunction {
108140

109141
/// \brief Returns INT32 as the output type.
110142
Result<std::shared_ptr<Type>> ResultType() const override;
143+
144+
/// \brief Create a DayTransform.
145+
/// \param source_type Type of the input data.
146+
/// \return A Result containing the DayTransform or an error.
147+
static Result<std::unique_ptr<TransformFunction>> Make(
148+
std::shared_ptr<Type> const& source_type);
111149
};
112150

113151
/// \brief Hour transform that extracts the hour component from timestamp inputs.
@@ -121,6 +159,12 @@ class HourTransform : public TransformFunction {
121159

122160
/// \brief Returns INT32 as the output type.
123161
Result<std::shared_ptr<Type>> ResultType() const override;
162+
163+
/// \brief Create a HourTransform.
164+
/// \param source_type Type of the input data.
165+
/// \return A Result containing the HourTransform or an error.
166+
static Result<std::unique_ptr<TransformFunction>> Make(
167+
std::shared_ptr<Type> const& source_type);
124168
};
125169

126170
/// \brief Void transform that discards the input and always returns null.
@@ -134,6 +178,12 @@ class VoidTransform : public TransformFunction {
134178

135179
/// \brief Returns null type or a sentinel type indicating void.
136180
Result<std::shared_ptr<Type>> ResultType() const override;
181+
182+
/// \brief Create a VoidTransform.
183+
/// \param source_type Input type (ignored).
184+
/// \return A Result containing the VoidTransform or an error.
185+
static Result<std::unique_ptr<TransformFunction>> Make(
186+
std::shared_ptr<Type> const& source_type);
137187
};
138188

139189
} // namespace iceberg

test/transform_test.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,22 @@ TEST(TransformResultTypeTest, NegativeCases) {
173173
};
174174

175175
const std::vector<Case> cases = {
176-
{.str = "identity", .source_type = nullptr},
177176
{.str = "year", .source_type = iceberg::string()},
178177
{.str = "month", .source_type = iceberg::string()},
179178
{.str = "day", .source_type = iceberg::string()},
180179
{.str = "hour", .source_type = iceberg::string()},
181-
{.str = "void", .source_type = nullptr},
182180
{.str = "bucket[16]", .source_type = iceberg::float32()},
183181
{.str = "truncate[32]", .source_type = iceberg::float64()}};
184182

183+
const std::vector<Case> null_cases = {{.str = "identity", .source_type = nullptr},
184+
{.str = "year", .source_type = nullptr},
185+
{.str = "month", .source_type = nullptr},
186+
{.str = "day", .source_type = nullptr},
187+
{.str = "hour", .source_type = nullptr},
188+
{.str = "void", .source_type = nullptr},
189+
{.str = "bucket[16]", .source_type = nullptr},
190+
{.str = "truncate[32]", .source_type = nullptr}};
191+
185192
for (const auto& c : cases) {
186193
auto result = TransformFromString(c.str);
187194
ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str;
@@ -192,6 +199,17 @@ TEST(TransformResultTypeTest, NegativeCases) {
192199
auto result_type = transformPtr.value()->ResultType();
193200
ASSERT_THAT(result_type, IsError(ErrorKind::kNotSupported));
194201
}
202+
203+
for (const auto& c : null_cases) {
204+
auto result = TransformFromString(c.str);
205+
ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str;
206+
207+
const auto& transform = result.value();
208+
auto transformPtr = transform->Bind(c.source_type);
209+
210+
ASSERT_THAT(transformPtr, IsError(ErrorKind::kNotSupported));
211+
EXPECT_THAT(transformPtr, HasErrorMessage("null is not a valid"));
212+
}
195213
}
196214

197215
} // namespace iceberg

0 commit comments

Comments
 (0)