Skip to content

Commit f989a81

Browse files
committed
feat: implement Literal Transform
1 parent 378de75 commit f989a81

File tree

9 files changed

+591
-13
lines changed

9 files changed

+591
-13
lines changed

src/iceberg/expression/literal.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,16 @@ Literal Literal::Boolean(bool value) { return {Value{value}, iceberg::boolean()}
130130

131131
Literal Literal::Int(int32_t value) { return {Value{value}, iceberg::int32()}; }
132132

133+
/// Builds a literal of the Iceberg `date` type that stores `value` as a 32-bit integer.
Literal Literal::Date(int32_t value) {
  return Literal(Value{value}, iceberg::date());
}
134+
133135
Literal Literal::Long(int64_t value) { return {Value{value}, iceberg::int64()}; }
134136

137+
/// Builds a literal of the Iceberg `timestamp` type that stores `value` as a 64-bit
/// integer.
Literal Literal::Timestamp(int64_t value) {
  return Literal(Value{value}, iceberg::timestamp());
}
138+
139+
/// Builds a literal of the Iceberg `timestamp_tz` type that stores `value` as a 64-bit
/// integer.
Literal Literal::TimestampTz(int64_t value) {
  return Literal(Value{value}, iceberg::timestamp_tz());
}
142+
135143
Literal Literal::Float(float value) { return {Value{value}, iceberg::float32()}; }
136144

137145
Literal Literal::Double(double value) { return {Value{value}, iceberg::float64()}; }
@@ -208,12 +216,30 @@ std::partial_ordering Literal::operator<=>(const Literal& other) const {
208216
return this_val <=> other_val;
209217
}
210218

219+
case TypeId::kDate: {
220+
auto this_val = std::get<int32_t>(value_);
221+
auto other_val = std::get<int32_t>(other.value_);
222+
return this_val <=> other_val;
223+
}
224+
211225
case TypeId::kLong: {
212226
auto this_val = std::get<int64_t>(value_);
213227
auto other_val = std::get<int64_t>(other.value_);
214228
return this_val <=> other_val;
215229
}
216230

231+
case TypeId::kTimestamp: {
232+
auto this_val = std::get<int64_t>(value_);
233+
auto other_val = std::get<int64_t>(other.value_);
234+
return this_val <=> other_val;
235+
}
236+
237+
case TypeId::kTimestampTz: {
238+
auto this_val = std::get<int64_t>(value_);
239+
auto other_val = std::get<int64_t>(other.value_);
240+
return this_val <=> other_val;
241+
}
242+
217243
case TypeId::kFloat: {
218244
auto this_val = std::get<float>(value_);
219245
auto other_val = std::get<float>(other.value_);

src/iceberg/expression/literal.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ class ICEBERG_EXPORT Literal {
6363
/// \brief Factory methods for primitive types
6464
static Literal Boolean(bool value);
6565
static Literal Int(int32_t value);
66+
static Literal Date(int32_t value);
6667
static Literal Long(int64_t value);
68+
static Literal Timestamp(int64_t value);
69+
static Literal TimestampTz(int64_t value);
6770
static Literal Float(float value);
6871
static Literal Double(double value);
6972
static Literal String(std::string value);
@@ -85,6 +88,9 @@ class ICEBERG_EXPORT Literal {
8588
/// \brief Get the literal type.
8689
const std::shared_ptr<PrimitiveType>& type() const;
8790

91+
/// \brief Get the literal value.
///
/// \return A const reference to the stored `Value`; the reference is to a member of
/// this literal, so it must not be used after the literal is destroyed.
92+
const Value& value() const { return value_; }
93+
8894
/// \brief Converts this literal to a literal of the given type.
8995
///
9096
/// When a predicate is bound to a concrete data column, literals are converted to match

src/iceberg/manifest_reader_internal.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@
3333

3434
namespace iceberg {
3535

36-
#define NANOARROW_RETURN_IF_NOT_OK(status, error) \
37-
if (status != NANOARROW_OK) [[unlikely]] { \
38-
return InvalidArrowData("Nanoarrow error: {}", error.message); \
39-
}
40-
4136
#define PARSE_PRIMITIVE_FIELD(item, array_view, type) \
4237
for (size_t row_idx = 0; row_idx < array_view->length; row_idx++) { \
4338
if (!ArrowArrayViewIsNull(array_view, row_idx)) { \

src/iceberg/transform.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323

2424
#include <cstdint>
2525
#include <memory>
26+
#include <optional>
2627
#include <variant>
2728

2829
#include "iceberg/arrow_c_data.h"
30+
#include "iceberg/expression/literal.h"
2931
#include "iceberg/iceberg_export.h"
3032
#include "iceberg/result.h"
3133
#include "iceberg/type_fwd.h"
@@ -172,6 +174,8 @@ class ICEBERG_EXPORT TransformFunction {
172174
TransformFunction(TransformType transform_type, std::shared_ptr<Type> source_type);
173175
/// \brief Transform an input array to a new array
174176
virtual Result<ArrowArray> Transform(const ArrowArray& data) = 0;
177+
/// \brief Transform an input Literal to a new Literal
///
/// \param literal The literal to transform.
/// \return The transformed literal, an empty optional when the transform produces no
/// value for the input, or an error if the transform cannot be applied.
178+
virtual Result<std::optional<Literal>> Transform(const Literal& literal) = 0;
175179
/// \brief Get the transform type
176180
TransformType transform_type() const;
177181
/// \brief Get the source type of transform function

src/iceberg/transform_function.cc

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919

2020
#include "iceberg/transform_function.h"
2121

22+
#include <array>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <variant>
#include <vector>
24+
2225
#include "iceberg/type.h"
26+
#include "iceberg/util/murmurhash3_internal.h"
2327

2428
namespace iceberg {
2529

@@ -30,6 +34,10 @@ Result<ArrowArray> IdentityTransform::Transform(const ArrowArray& input) {
3034
return NotImplemented("IdentityTransform::Transform");
3135
}
3236

37+
/// The identity transform yields a copy of the input literal, unchanged.
Result<std::optional<Literal>> IdentityTransform::Transform(const Literal& literal) {
  std::optional<Literal> unchanged{literal};
  return unchanged;
}
40+
3341
Result<std::shared_ptr<Type>> IdentityTransform::ResultType() const {
3442
return source_type();
3543
}
@@ -51,6 +59,57 @@ Result<ArrowArray> BucketTransform::Transform(const ArrowArray& input) {
5159
return NotImplemented("BucketTransform::Transform");
5260
}
5361

62+
/// \brief Compute the bucket index of a literal.
///
/// The literal's value is hashed with 32-bit Murmur3 (seed 0); the sign bit of the
/// hash is cleared and the result is reduced modulo `num_buckets_`.
///
/// \param literal The literal to bucket; expected to have the transform's source type.
/// \return An `int` literal holding the bucket index, or InvalidArgument when the
/// literal is a below-min / above-max sentinel.
Result<std::optional<Literal>> BucketTransform::Transform(const Literal& literal) {
  // NOTE(review): if both operands are shared_ptr handles this checks pointer
  // identity, not structural type equality -- confirm that is what is intended for
  // parameterized types such as decimal and fixed.
  assert(literal.type() == source_type());
  if (literal.IsBelowMin() || literal.IsAboveMax()) {
    return InvalidArgument(
        "Cannot apply bucket transform to literal with value {} of type {}",
        literal.ToString(), source_type()->ToString());
  }

  int32_t hash_value = 0;
  switch (source_type()->type_id()) {
    case TypeId::kInt:
    case TypeId::kDate: {
      // The Iceberg spec hashes 32-bit integers as 64-bit longs so that a column
      // promoted from int to long keeps the same bucket assignment.
      const int64_t widened = std::get<int32_t>(literal.value());
      MurmurHash3_x86_32(&widened, sizeof(widened), 0, &hash_value);
      break;
    }
    case TypeId::kLong:
    case TypeId::kTime:
    case TypeId::kTimestamp:
    case TypeId::kTimestampTz: {
      // NOTE(review): hashing the in-memory representation assumes a little-endian
      // host, which is the byte order the spec requires.
      const int64_t value = std::get<int64_t>(literal.value());
      MurmurHash3_x86_32(&value, sizeof(value), 0, &hash_value);
      break;
    }
    case TypeId::kDecimal:
    case TypeId::kUuid: {
      // NOTE(review): the spec hashes decimals over the minimal big-endian
      // two's-complement bytes of the unscaled value; confirm the 16-byte
      // representation used here matches that for decimals.
      const auto& bytes = std::get<std::array<uint8_t, 16>>(literal.value());
      MurmurHash3_x86_32(bytes.data(), sizeof(uint8_t) * 16, 0, &hash_value);
      break;
    }
    case TypeId::kString: {
      // Bind by reference: the payload is only read, so no copy is needed.
      const auto& text = std::get<std::string>(literal.value());
      MurmurHash3_x86_32(text.data(), text.size(), 0, &hash_value);
      break;
    }
    case TypeId::kFixed:
    case TypeId::kBinary: {
      const auto& bytes = std::get<std::vector<uint8_t>>(literal.value());
      MurmurHash3_x86_32(bytes.data(), bytes.size(), 0, &hash_value);
      break;
    }
    default:
      std::unreachable();
  }

  // Clear the sign bit so the modulo result is always non-negative.
  const int32_t bucket_index =
      (hash_value & std::numeric_limits<int32_t>::max()) % num_buckets_;

  return Literal::Int(bucket_index);
}
112+
54113
Result<std::shared_ptr<Type>> BucketTransform::ResultType() const {
55114
return iceberg::int32();
56115
}
@@ -91,6 +150,46 @@ Result<ArrowArray> TruncateTransform::Transform(const ArrowArray& input) {
91150
return NotImplemented("TruncateTransform::Transform");
92151
}
93152

153+
namespace {

/// Rounds `value` down (toward negative infinity) to a multiple of `width`.
/// Requires `width > 0`. The subtraction is done in unsigned arithmetic so that
/// values close to the minimum wrap instead of triggering signed overflow.
int64_t FloorToMultipleOf(int64_t value, int64_t width) {
  int64_t remainder = value % width;
  if (remainder < 0) {
    remainder += width;  // C++ `%` keeps the dividend's sign; normalize to [0, width)
  }
  return static_cast<int64_t>(static_cast<uint64_t>(value) -
                              static_cast<uint64_t>(remainder));
}

/// Returns the byte length of the longest prefix of `text` that contains at most
/// `max_code_points` UTF-8 code points.
size_t Utf8PrefixBytes(const std::string& text, size_t max_code_points) {
  size_t code_points = 0;
  for (size_t i = 0; i < text.size(); ++i) {
    // Any byte that is not a continuation byte (10xxxxxx) starts a new code point.
    if ((static_cast<unsigned char>(text[i]) & 0xC0) != 0x80) {
      if (code_points == max_code_points) {
        return i;
      }
      ++code_points;
    }
  }
  return text.size();
}

}  // namespace

/// \brief Truncate a literal to the transform's width.
///
/// - int / long: rounded down to a multiple of `width_`
///   (`v - (((v % W) + W) % W)`), so negative values round toward negative infinity.
/// - string: cut to at most `width_` UTF-8 code points (never splits a code point).
/// - binary: cut to at most `width_` bytes.
/// - decimal: not implemented yet.
///
/// \param literal The literal to truncate; expected to have the transform's source type.
/// \return The truncated literal of the same type, InvalidArgument for below-min /
/// above-max sentinels, or NotImplemented for decimals.
Result<std::optional<Literal>> TruncateTransform::Transform(const Literal& literal) {
  assert(literal.type() == source_type());
  if (literal.IsBelowMin() || literal.IsAboveMax()) {
    return InvalidArgument(
        "Cannot apply truncate transform to literal with value {} of type {}",
        literal.ToString(), source_type()->ToString());
  }

  switch (source_type()->type_id()) {
    case TypeId::kInt: {
      const int64_t value = std::get<int32_t>(literal.value());
      return Literal::Int(static_cast<int32_t>(FloorToMultipleOf(value, width_)));
    }
    case TypeId::kLong: {
      const int64_t value = std::get<int64_t>(literal.value());
      return Literal::Long(FloorToMultipleOf(value, width_));
    }
    case TypeId::kDecimal: {
      // TODO(zhjwpku): Handle decimal truncation logic here
      return NotImplemented("Truncate for Decimal is not implemented yet");
    }
    case TypeId::kString: {
      const auto& text = std::get<std::string>(literal.value());
      const size_t keep = Utf8PrefixBytes(text, static_cast<size_t>(width_));
      return Literal::String(text.substr(0, keep));
    }
    case TypeId::kBinary: {
      const auto& bytes = std::get<std::vector<uint8_t>>(literal.value());
      const size_t limit = static_cast<size_t>(width_);
      const size_t keep = bytes.size() > limit ? limit : bytes.size();
      return Literal::Binary(std::vector<uint8_t>(bytes.begin(), bytes.begin() + keep));
    }
    default:
      std::unreachable();
  }
}
192+
94193
Result<std::shared_ptr<Type>> TruncateTransform::ResultType() const {
95194
return source_type();
96195
}
@@ -124,6 +223,34 @@ Result<ArrowArray> YearTransform::Transform(const ArrowArray& input) {
124223
return NotImplemented("YearTransform::Transform");
125224
}
126225

226+
/// \brief Compute the year partition value of a date or timestamp literal.
///
/// The result is the number of years from 1970 (1970 -> 0, 1969 -> -1), as required
/// by the Iceberg partition transform specification. Dates are days since
/// 1970-01-01; timestamps are microseconds since the epoch.
///
/// \param literal The literal to transform; expected to have the transform's source type.
/// \return An `int` literal with the year offset, or InvalidArgument for below-min /
/// above-max sentinels.
Result<std::optional<Literal>> YearTransform::Transform(const Literal& literal) {
  assert(literal.type() == source_type());
  if (literal.IsBelowMin() || literal.IsAboveMax()) {
    return InvalidArgument(
        "Cannot apply year transform to literal with value {} of type {}",
        literal.ToString(), source_type()->ToString());
  }

  using namespace std::chrono;
  constexpr int32_t kEpochYear = 1970;
  switch (source_type()->type_id()) {
    case TypeId::kDate: {
      const auto days_since_epoch = std::get<int32_t>(literal.value());
      const year_month_day ymd{sys_days{days{days_since_epoch}}};
      return Literal::Int(static_cast<int32_t>(ymd.year()) - kEpochYear);
    }
    case TypeId::kTimestamp:
    case TypeId::kTimestampTz: {
      const auto micros_since_epoch = std::get<int64_t>(literal.value());
      // floor (not truncation) so that pre-epoch instants land on the correct day.
      const year_month_day ymd{
          floor<days>(sys_time<microseconds>{microseconds{micros_since_epoch}})};
      return Literal::Int(static_cast<int32_t>(ymd.year()) - kEpochYear);
    }
    default:
      std::unreachable();
  }
}
253+
127254
Result<std::shared_ptr<Type>> YearTransform::ResultType() const {
128255
return iceberg::int32();
129256
}
@@ -152,6 +279,46 @@ Result<ArrowArray> MonthTransform::Transform(const ArrowArray& input) {
152279
return NotImplemented("MonthTransform::Transform");
153280
}
154281

282+
/// \brief Compute the month partition value of a date or timestamp literal.
///
/// The result is the number of months from 1970-01 (1970-01 -> 0, 1969-12 -> -1).
/// Dates are days since 1970-01-01; timestamps are microseconds since the epoch.
///
/// \param literal The literal to transform; expected to have the transform's source type.
/// \return An `int` literal with the month offset, or InvalidArgument for below-min /
/// above-max sentinels.
Result<std::optional<Literal>> MonthTransform::Transform(const Literal& literal) {
  assert(literal.type() == source_type());
  if (literal.IsBelowMin() || literal.IsAboveMax()) {
    return InvalidArgument(
        "Cannot apply month transform to literal with value {} of type {}",
        literal.ToString(), source_type()->ToString());
  }

  using namespace std::chrono;
  // Months elapsed since 1970-01, computed entirely in signed arithmetic.
  // `month` is 1-based (January == 1), hence the `- 1`.
  const auto months_since_epoch = [](const year_month_day& ymd) {
    constexpr int32_t kEpochYear = 1970;
    const int32_t years = static_cast<int32_t>(ymd.year()) - kEpochYear;
    const int32_t month_index =
        static_cast<int32_t>(static_cast<unsigned>(ymd.month())) - 1;
    return years * 12 + month_index;
  };

  switch (source_type()->type_id()) {
    case TypeId::kDate: {
      const auto days_since_epoch = std::get<int32_t>(literal.value());
      const year_month_day ymd{sys_days{days{days_since_epoch}}};
      return Literal::Int(months_since_epoch(ymd));
    }
    case TypeId::kTimestamp:
    case TypeId::kTimestampTz: {
      const auto micros_since_epoch = std::get<int64_t>(literal.value());
      // floor (not truncation) so that pre-epoch instants land on the correct day.
      const year_month_day ymd{
          floor<days>(sys_time<microseconds>{microseconds{micros_since_epoch}})};
      return Literal::Int(months_since_epoch(ymd));
    }
    default:
      std::unreachable();
  }
}
321+
155322
Result<std::shared_ptr<Type>> MonthTransform::ResultType() const {
156323
return iceberg::int32();
157324
}
@@ -180,6 +347,35 @@ Result<ArrowArray> DayTransform::Transform(const ArrowArray& input) {
180347
return NotImplemented("DayTransform::Transform");
181348
}
182349

350+
/// \brief Compute the day partition value of a date or timestamp literal.
///
/// A date literal is already a day count and is returned unchanged. A timestamp
/// (microseconds since the epoch) is floored to whole days since 1970-01-01.
///
/// \param literal The literal to transform; expected to have the transform's source type.
/// \return A `date` literal with the day offset, or InvalidArgument for below-min /
/// above-max sentinels.
Result<std::optional<Literal>> DayTransform::Transform(const Literal& literal) {
  assert(literal.type() == source_type());
  if (literal.IsBelowMin() || literal.IsAboveMax()) {
    return InvalidArgument(
        "Cannot apply day transform to literal with value {} of type {}",
        literal.ToString(), source_type()->ToString());
  }

  using namespace std::chrono;
  switch (source_type()->type_id()) {
    case TypeId::kDate: {
      // Day is the same as the date value
      return literal;
    }
    case TypeId::kTimestamp:
    case TypeId::kTimestampTz: {
      const auto micros_since_epoch = std::get<int64_t>(literal.value());
      // floor (not truncation) so that pre-epoch instants map to the previous day.
      const auto day_point =
          floor<days>(sys_time<microseconds>{microseconds{micros_since_epoch}});
      return Literal::Date(static_cast<int32_t>(day_point.time_since_epoch().count()));
    }
    default:
      std::unreachable();
  }
}
378+
183379
Result<std::shared_ptr<Type>> DayTransform::ResultType() const { return iceberg::date(); }
184380

185381
Result<std::unique_ptr<TransformFunction>> DayTransform::Make(
@@ -206,6 +402,32 @@ Result<ArrowArray> HourTransform::Transform(const ArrowArray& input) {
206402
return NotImplemented("HourTransform::Transform");
207403
}
208404

405+
/// \brief Compute the hour partition value of a timestamp literal.
///
/// The timestamp (microseconds since the epoch) is floored to whole hours since
/// 1970-01-01T00:00, so instants before the epoch yield negative hour counts
/// (e.g. one microsecond before the epoch is hour -1, not 0).
///
/// \param literal The literal to transform; expected to have the transform's source type.
/// \return An `int` literal with the hour offset, or InvalidArgument for below-min /
/// above-max sentinels.
Result<std::optional<Literal>> HourTransform::Transform(const Literal& literal) {
  assert(literal.type() == source_type());
  if (literal.IsBelowMin() || literal.IsAboveMax()) {
    return InvalidArgument(
        "Cannot apply hour transform to literal with value {} of type {}",
        literal.ToString(), source_type()->ToString());
  }

  using namespace std::chrono;
  switch (source_type()->type_id()) {
    case TypeId::kTimestamp:
    case TypeId::kTimestampTz: {
      const auto micros_since_epoch = std::get<int64_t>(literal.value());
      // `floor` rounds toward negative infinity; `duration_cast` would truncate
      // toward zero and mis-assign pre-epoch timestamps.
      const auto hours_since_epoch =
          floor<hours>(microseconds{micros_since_epoch}).count();
      return Literal::Int(static_cast<int32_t>(hours_since_epoch));
    }
    default:
      std::unreachable();
  }
}
430+
209431
Result<std::shared_ptr<Type>> HourTransform::ResultType() const {
210432
return iceberg::int32();
211433
}
@@ -233,6 +455,10 @@ Result<ArrowArray> VoidTransform::Transform(const ArrowArray& input) {
233455
return NotImplemented("VoidTransform::Transform");
234456
}
235457

458+
/// The void transform discards its input: the result is always an empty optional.
Result<std::optional<Literal>> VoidTransform::Transform(const Literal& literal) {
  return std::optional<Literal>{};
}
461+
236462
Result<std::shared_ptr<Type>> VoidTransform::ResultType() const { return source_type(); }
237463

238464
Result<std::unique_ptr<TransformFunction>> VoidTransform::Make(

0 commit comments

Comments
 (0)