Skip to content

Commit bf94971

Browse files
committed
validate source_type and other params in Make function
1 parent 41c92a6 commit bf94971

File tree

2 files changed

+67
-82
lines changed

2 files changed

+67
-82
lines changed

src/iceberg/transform_function.cc

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,15 @@ Result<ArrowArray> BucketTransform::Transform(const ArrowArray& input) {
5252
}
5353

5454
Result<std::shared_ptr<Type>> BucketTransform::ResultType() const {
55-
auto src_type = source_type();
56-
switch (src_type->type_id()) {
55+
return iceberg::int32();
56+
}
57+
58+
Result<std::unique_ptr<TransformFunction>> BucketTransform::Make(
59+
std::shared_ptr<Type> const& source_type, int32_t num_buckets) {
60+
if (!source_type) {
61+
return NotSupported("null is not a valid input type for bucket transform");
62+
}
63+
switch (source_type->type_id()) {
5764
case TypeId::kInt:
5865
case TypeId::kLong:
5966
case TypeId::kDecimal:
@@ -65,17 +72,13 @@ Result<std::shared_ptr<Type>> BucketTransform::ResultType() const {
6572
case TypeId::kUuid:
6673
case TypeId::kFixed:
6774
case TypeId::kBinary:
68-
return iceberg::int32();
75+
break;
6976
default:
7077
return NotSupported("{} is not a valid input type for bucket transform",
71-
src_type->ToString());
78+
source_type->ToString());
7279
}
73-
}
74-
75-
Result<std::unique_ptr<TransformFunction>> BucketTransform::Make(
76-
std::shared_ptr<Type> const& source_type, int32_t num_buckets) {
77-
if (!source_type) {
78-
return NotSupported("null is not a valid input type for bucket transform");
80+
if (num_buckets <= 0) {
81+
return InvalidArgument("Number of buckets must be positive, got {}", num_buckets);
7982
}
8083
return std::make_unique<BucketTransform>(source_type, num_buckets);
8184
}
@@ -89,24 +92,27 @@ Result<ArrowArray> TruncateTransform::Transform(const ArrowArray& input) {
8992
}
9093

9194
Result<std::shared_ptr<Type>> TruncateTransform::ResultType() const {
92-
auto src_type = source_type();
93-
switch (src_type->type_id()) {
95+
return source_type();
96+
}
97+
98+
Result<std::unique_ptr<TransformFunction>> TruncateTransform::Make(
99+
std::shared_ptr<Type> const& source_type, int32_t width) {
100+
if (!source_type) {
101+
return NotSupported("null is not a valid input type for truncate transform");
102+
}
103+
switch (source_type->type_id()) {
94104
case TypeId::kInt:
95105
case TypeId::kLong:
96106
case TypeId::kDecimal:
97107
case TypeId::kString:
98108
case TypeId::kBinary:
99-
return src_type;
109+
break;
100110
default:
101111
return NotSupported("{} is not a valid input type for truncate transform",
102-
src_type->ToString());
112+
source_type->ToString());
103113
}
104-
}
105-
106-
Result<std::unique_ptr<TransformFunction>> TruncateTransform::Make(
107-
std::shared_ptr<Type> const& source_type, int32_t width) {
108-
if (!source_type) {
109-
return NotSupported("null is not a valid input type for truncate transform");
114+
if (width <= 0) {
115+
return InvalidArgument("Width must be positive, got {}", width);
110116
}
111117
return std::make_unique<TruncateTransform>(source_type, width);
112118
}
@@ -119,23 +125,23 @@ Result<ArrowArray> YearTransform::Transform(const ArrowArray& input) {
119125
}
120126

121127
Result<std::shared_ptr<Type>> YearTransform::ResultType() const {
122-
auto src_type = source_type();
123-
switch (src_type->type_id()) {
124-
case TypeId::kDate:
125-
case TypeId::kTimestamp:
126-
case TypeId::kTimestampTz:
127-
return iceberg::int32();
128-
default:
129-
return NotSupported("{} is not a valid input type for year transform",
130-
src_type->ToString());
131-
}
128+
return iceberg::int32();
132129
}
133130

134131
Result<std::unique_ptr<TransformFunction>> YearTransform::Make(
135132
std::shared_ptr<Type> const& source_type) {
136133
if (!source_type) {
137134
return NotSupported("null is not a valid input type for year transform");
138135
}
136+
switch (source_type->type_id()) {
137+
case TypeId::kDate:
138+
case TypeId::kTimestamp:
139+
case TypeId::kTimestampTz:
140+
break;
141+
default:
142+
return NotSupported("{} is not a valid input type for year transform",
143+
source_type->ToString());
144+
}
139145
return std::make_unique<YearTransform>(source_type);
140146
}
141147

@@ -147,23 +153,23 @@ Result<ArrowArray> MonthTransform::Transform(const ArrowArray& input) {
147153
}
148154

149155
Result<std::shared_ptr<Type>> MonthTransform::ResultType() const {
150-
auto src_type = source_type();
151-
switch (src_type->type_id()) {
152-
case TypeId::kDate:
153-
case TypeId::kTimestamp:
154-
case TypeId::kTimestampTz:
155-
return iceberg::int32();
156-
default:
157-
return NotSupported("{} is not a valid input type for month transform",
158-
src_type->ToString());
159-
}
156+
return iceberg::int32();
160157
}
161158

162159
Result<std::unique_ptr<TransformFunction>> MonthTransform::Make(
163160
std::shared_ptr<Type> const& source_type) {
164161
if (!source_type) {
165162
return NotSupported("null is not a valid input type for month transform");
166163
}
164+
switch (source_type->type_id()) {
165+
case TypeId::kDate:
166+
case TypeId::kTimestamp:
167+
case TypeId::kTimestampTz:
168+
break;
169+
default:
170+
return NotSupported("{} is not a valid input type for month transform",
171+
source_type->ToString());
172+
}
167173
return std::make_unique<MonthTransform>(source_type);
168174
}
169175

@@ -174,24 +180,22 @@ Result<ArrowArray> DayTransform::Transform(const ArrowArray& input) {
174180
return NotImplemented("DayTransform::Transform");
175181
}
176182

177-
Result<std::shared_ptr<Type>> DayTransform::ResultType() const {
178-
auto src_type = source_type();
179-
switch (src_type->type_id()) {
180-
case TypeId::kDate:
181-
case TypeId::kTimestamp:
182-
case TypeId::kTimestampTz:
183-
return iceberg::date();
184-
default:
185-
return NotSupported("{} is not a valid input type for day transform",
186-
src_type->ToString());
187-
}
188-
}
183+
Result<std::shared_ptr<Type>> DayTransform::ResultType() const { return iceberg::date(); }
189184

190185
Result<std::unique_ptr<TransformFunction>> DayTransform::Make(
191186
std::shared_ptr<Type> const& source_type) {
192187
if (!source_type) {
193188
return NotSupported("null is not a valid input type for day transform");
194189
}
190+
switch (source_type->type_id()) {
191+
case TypeId::kDate:
192+
case TypeId::kTimestamp:
193+
case TypeId::kTimestampTz:
194+
break;
195+
default:
196+
return NotSupported("{} is not a valid input type for day transform",
197+
source_type->ToString());
198+
}
195199
return std::make_unique<DayTransform>(source_type);
196200
}
197201

@@ -203,22 +207,22 @@ Result<ArrowArray> HourTransform::Transform(const ArrowArray& input) {
203207
}
204208

205209
Result<std::shared_ptr<Type>> HourTransform::ResultType() const {
206-
auto src_type = source_type();
207-
switch (src_type->type_id()) {
208-
case TypeId::kTimestamp:
209-
case TypeId::kTimestampTz:
210-
return iceberg::int32();
211-
default:
212-
return NotSupported("{} is not a valid input type for hour transform",
213-
src_type->ToString());
214-
}
210+
return iceberg::int32();
215211
}
216212

217213
Result<std::unique_ptr<TransformFunction>> HourTransform::Make(
218214
std::shared_ptr<Type> const& source_type) {
219215
if (!source_type) {
220216
return NotSupported("null is not a valid input type for hour transform");
221217
}
218+
switch (source_type->type_id()) {
219+
case TypeId::kTimestamp:
220+
case TypeId::kTimestampTz:
221+
break;
222+
default:
223+
return NotSupported("{} is not a valid input type for hour transform",
224+
source_type->ToString());
225+
}
222226
return std::make_unique<HourTransform>(source_type);
223227
}
224228

test/transform_test.cc

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -173,42 +173,23 @@ TEST(TransformResultTypeTest, NegativeCases) {
173173
};
174174

175175
const std::vector<Case> cases = {
176+
{.str = "identity", .source_type = nullptr},
176177
{.str = "year", .source_type = iceberg::string()},
177178
{.str = "month", .source_type = iceberg::string()},
178179
{.str = "day", .source_type = iceberg::string()},
179180
{.str = "hour", .source_type = iceberg::string()},
181+
{.str = "void", .source_type = nullptr},
180182
{.str = "bucket[16]", .source_type = iceberg::float32()},
181183
{.str = "truncate[32]", .source_type = iceberg::float64()}};
182184

183-
const std::vector<Case> null_cases = {{.str = "identity", .source_type = nullptr},
184-
{.str = "year", .source_type = nullptr},
185-
{.str = "month", .source_type = nullptr},
186-
{.str = "day", .source_type = nullptr},
187-
{.str = "hour", .source_type = nullptr},
188-
{.str = "void", .source_type = nullptr},
189-
{.str = "bucket[16]", .source_type = nullptr},
190-
{.str = "truncate[32]", .source_type = nullptr}};
191-
192185
for (const auto& c : cases) {
193186
auto result = TransformFromString(c.str);
194187
ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str;
195188

196189
const auto& transform = result.value();
197190
auto transformPtr = transform->Bind(c.source_type);
198191

199-
auto result_type = transformPtr.value()->ResultType();
200-
ASSERT_THAT(result_type, IsError(ErrorKind::kNotSupported));
201-
}
202-
203-
for (const auto& c : null_cases) {
204-
auto result = TransformFromString(c.str);
205-
ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str;
206-
207-
const auto& transform = result.value();
208-
auto transformPtr = transform->Bind(c.source_type);
209-
210192
ASSERT_THAT(transformPtr, IsError(ErrorKind::kNotSupported));
211-
EXPECT_THAT(transformPtr, HasErrorMessage("null is not a valid"));
212193
}
213194
}
214195

0 commit comments

Comments
 (0)