Skip to content

Commit 6d4ffa1

Browse files
committed
feat: implement transform ResultType
1 parent 0779a52 commit 6d4ffa1

File tree

2 files changed

+169
-7
lines changed

2 files changed

+169
-7
lines changed

src/iceberg/transform_function.cc

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,27 @@ Result<ArrowArray> BucketTransform::Transform(const ArrowArray& input) {
4848
}
4949

5050
Result<std::shared_ptr<Type>> BucketTransform::ResultType() const {
51-
return NotImplemented("BucketTransform::result_type");
51+
auto src_type = source_type();
52+
if (!src_type) {
53+
return NotSupported("null is not a valid input type for bucket transform");
54+
}
55+
switch (src_type->type_id()) {
56+
case TypeId::kInt:
57+
case TypeId::kLong:
58+
case TypeId::kDecimal:
59+
case TypeId::kDate:
60+
case TypeId::kTime:
61+
case TypeId::kTimestamp:
62+
case TypeId::kTimestampTz:
63+
case TypeId::kString:
64+
case TypeId::kUuid:
65+
case TypeId::kFixed:
66+
case TypeId::kBinary:
67+
return std::make_shared<IntType>();
68+
default:
69+
return NotSupported("{} is not a valid input type for bucket transform",
70+
src_type->ToString());
71+
}
5272
}
5373

5474
TruncateTransform::TruncateTransform(std::shared_ptr<Type> const& source_type,
@@ -60,7 +80,21 @@ Result<ArrowArray> TruncateTransform::Transform(const ArrowArray& input) {
6080
}
6181

6282
Result<std::shared_ptr<Type>> TruncateTransform::ResultType() const {
63-
return NotImplemented("TruncateTransform::result_type");
83+
auto src_type = source_type();
84+
if (!src_type) {
85+
return NotSupported("null is not a valid input type for truncate transform");
86+
}
87+
switch (src_type->type_id()) {
88+
case TypeId::kInt:
89+
case TypeId::kLong:
90+
case TypeId::kDecimal:
91+
case TypeId::kString:
92+
case TypeId::kBinary:
93+
return src_type;
94+
default:
95+
return NotSupported("{} is not a valid input type for truncate transform",
96+
src_type->ToString());
97+
}
6498
}
6599

66100
YearTransform::YearTransform(std::shared_ptr<Type> const& source_type)
@@ -71,7 +105,19 @@ Result<ArrowArray> YearTransform::Transform(const ArrowArray& input) {
71105
}
72106

73107
Result<std::shared_ptr<Type>> YearTransform::ResultType() const {
74-
return NotImplemented("YearTransform::result_type");
108+
auto src_type = source_type();
109+
if (!src_type) {
110+
return NotSupported("null is not a valid input type for year transform");
111+
}
112+
switch (src_type->type_id()) {
113+
case TypeId::kDate:
114+
case TypeId::kTimestamp:
115+
case TypeId::kTimestampTz:
116+
return std::make_shared<IntType>();
117+
default:
118+
return NotSupported("{} is not a valid input type for year transform",
119+
src_type->ToString());
120+
}
75121
}
76122

77123
MonthTransform::MonthTransform(std::shared_ptr<Type> const& source_type)
@@ -82,7 +128,19 @@ Result<ArrowArray> MonthTransform::Transform(const ArrowArray& input) {
82128
}
83129

84130
Result<std::shared_ptr<Type>> MonthTransform::ResultType() const {
85-
return NotImplemented("MonthTransform::result_type");
131+
auto src_type = source_type();
132+
if (!src_type) {
133+
return NotSupported("null is not a valid input type for month transform");
134+
}
135+
switch (src_type->type_id()) {
136+
case TypeId::kDate:
137+
case TypeId::kTimestamp:
138+
case TypeId::kTimestampTz:
139+
return std::make_shared<IntType>();
140+
default:
141+
return NotSupported("{} is not a valid input type for month transform",
142+
src_type->ToString());
143+
}
86144
}
87145

88146
DayTransform::DayTransform(std::shared_ptr<Type> const& source_type)
@@ -93,7 +151,19 @@ Result<ArrowArray> DayTransform::Transform(const ArrowArray& input) {
93151
}
94152

95153
Result<std::shared_ptr<Type>> DayTransform::ResultType() const {
96-
return NotImplemented("DayTransform::result_type");
154+
auto src_type = source_type();
155+
if (!src_type) {
156+
return NotSupported("null is not a valid input type for day transform");
157+
}
158+
switch (src_type->type_id()) {
159+
case TypeId::kDate:
160+
case TypeId::kTimestamp:
161+
case TypeId::kTimestampTz:
162+
return std::make_shared<DateType>();
163+
default:
164+
return NotSupported("{} is not a valid input type for day transform",
165+
src_type->ToString());
166+
}
97167
}
98168

99169
HourTransform::HourTransform(std::shared_ptr<Type> const& source_type)
@@ -104,7 +174,18 @@ Result<ArrowArray> HourTransform::Transform(const ArrowArray& input) {
104174
}
105175

106176
Result<std::shared_ptr<Type>> HourTransform::ResultType() const {
107-
return NotImplemented("HourTransform::result_type");
177+
auto src_type = source_type();
178+
if (!src_type) {
179+
return NotSupported("null is not a valid input type for hour transform");
180+
}
181+
switch (src_type->type_id()) {
182+
case TypeId::kTimestamp:
183+
case TypeId::kTimestampTz:
184+
return std::make_shared<IntType>();
185+
default:
186+
return NotSupported("{} is not a valid input type for hour transform",
187+
src_type->ToString());
188+
}
108189
}
109190

110191
VoidTransform::VoidTransform(std::shared_ptr<Type> const& source_type)
@@ -115,7 +196,11 @@ Result<ArrowArray> VoidTransform::Transform(const ArrowArray& input) {
115196
}
116197

117198
Result<std::shared_ptr<Type>> VoidTransform::ResultType() const {
118-
return NotImplemented("VoidTransform::result_type");
199+
auto src_type = source_type();
200+
if (!src_type) {
201+
return NotSupported("null is not a valid input type for void transform");
202+
}
203+
return src_type;
119204
}
120205

121206
} // namespace iceberg

test/transform_test.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,81 @@ TEST(TransformFromStringTest, NegativeCases) {
117117
}
118118
}
119119

120+
TEST(TransformResultTypeTest, PositiveCases) {
121+
struct Case {
122+
std::string str;
123+
std::shared_ptr<Type> source_type;
124+
std::shared_ptr<Type> expected_result_type;
125+
};
126+
127+
const std::vector<Case> cases = {
128+
{.str = "identity",
129+
.source_type = std::make_shared<StringType>(),
130+
.expected_result_type = std::make_shared<StringType>()},
131+
{.str = "year",
132+
.source_type = std::make_shared<TimestampType>(),
133+
.expected_result_type = std::make_shared<IntType>()},
134+
{.str = "month",
135+
.source_type = std::make_shared<TimestampType>(),
136+
.expected_result_type = std::make_shared<IntType>()},
137+
{.str = "day",
138+
.source_type = std::make_shared<TimestampType>(),
139+
.expected_result_type = std::make_shared<DateType>()},
140+
{.str = "hour",
141+
.source_type = std::make_shared<TimestampType>(),
142+
.expected_result_type = std::make_shared<IntType>()},
143+
{.str = "void",
144+
.source_type = std::make_shared<StringType>(),
145+
.expected_result_type = std::make_shared<StringType>()},
146+
{.str = "bucket[16]",
147+
.source_type = std::make_shared<StringType>(),
148+
.expected_result_type = std::make_shared<IntType>()},
149+
{.str = "truncate[32]",
150+
.source_type = std::make_shared<StringType>(),
151+
.expected_result_type = std::make_shared<StringType>()},
152+
};
153+
154+
for (const auto& c : cases) {
155+
auto result = TransformFromString(c.str);
156+
ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str;
157+
158+
const auto& transform = result.value();
159+
const auto transformPtr = transform->Bind(c.source_type);
160+
ASSERT_TRUE(transformPtr.has_value()) << "Failed to bind: " << c.str;
161+
162+
auto result_type = transformPtr.value()->ResultType();
163+
ASSERT_TRUE(result_type.has_value()) << "Failed to get result type for: " << c.str;
164+
EXPECT_EQ(result_type.value()->type_id(), c.expected_result_type->type_id())
165+
<< "Unexpected result type for: " << c.str;
166+
}
167+
}
168+
169+
TEST(TransformResultTypeTest, NegativeCases) {
170+
struct Case {
171+
std::string str;
172+
std::shared_ptr<Type> source_type;
173+
};
174+
175+
const std::vector<Case> cases = {
176+
{.str = "identity", .source_type = nullptr},
177+
{.str = "year", .source_type = std::make_shared<StringType>()},
178+
{.str = "month", .source_type = std::make_shared<StringType>()},
179+
{.str = "day", .source_type = std::make_shared<StringType>()},
180+
{.str = "hour", .source_type = std::make_shared<StringType>()},
181+
{.str = "void", .source_type = nullptr},
182+
{.str = "bucket[16]", .source_type = std::make_shared<FloatType>()},
183+
{.str = "truncate[32]", .source_type = std::make_shared<DoubleType>()}};
184+
185+
for (const auto& c : cases) {
186+
auto result = TransformFromString(c.str);
187+
ASSERT_TRUE(result.has_value()) << "Failed to parse: " << c.str;
188+
189+
const auto& transform = result.value();
190+
auto transformPtr = transform->Bind(c.source_type);
191+
192+
auto result_type = transformPtr.value()->ResultType();
193+
ASSERT_THAT(result_type, IsError(ErrorKind::kNotSupported));
194+
}
195+
}
196+
120197
} // namespace iceberg

0 commit comments

Comments
 (0)