Skip to content

Commit 826b3d1

Browse files
HeartLinkedgty404
authored andcommitted
feat: Implement Type Casting and toString for Literals (apache#206)
- Implements the complete type casting logic for `iceberg::Literal` in the `LiteralCaster` class to align with the Java reference implementation. This is critical for expression evaluation and predicate pushdown. - Add basic implementation for fixed type. - Updated `ToString()` to match Java's output format for better consistency (e.g., `X'...'` for binary). - Added comprehensive unit tests to validate all new casting logic and `ToString()` formatting.
1 parent fa1c7b4 commit 826b3d1

File tree

3 files changed

+464
-251
lines changed

3 files changed

+464
-251
lines changed

src/iceberg/expression/literal.cc

Lines changed: 173 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121

2222
#include <cmath>
2323
#include <concepts>
24+
#include <cstdint>
25+
#include <string>
2426

25-
#include "iceberg/exception.h"
27+
#include "iceberg/type_fwd.h"
28+
#include "iceberg/util/checked_cast.h"
2629
#include "iceberg/util/conversions.h"
27-
#include "iceberg/util/macros.h"
2830

2931
namespace iceberg {
3032

@@ -54,6 +56,30 @@ class LiteralCaster {
5456
/// Cast from Float type to target type.
5557
static Result<Literal> CastFromFloat(const Literal& literal,
5658
const std::shared_ptr<PrimitiveType>& target_type);
59+
60+
/// Cast from Double type to target type.
61+
static Result<Literal> CastFromDouble(
62+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);
63+
64+
/// Cast from String type to target type.
65+
static Result<Literal> CastFromString(
66+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);
67+
68+
/// Cast from Timestamp type to target type.
69+
static Result<Literal> CastFromTimestamp(
70+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);
71+
72+
/// Cast from TimestampTz type to target type.
73+
static Result<Literal> CastFromTimestampTz(
74+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);
75+
76+
/// Cast from Binary type to target type.
77+
static Result<Literal> CastFromBinary(
78+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);
79+
80+
/// Cast from Fixed type to target type.
81+
static Result<Literal> CastFromFixed(const Literal& literal,
82+
const std::shared_ptr<PrimitiveType>& target_type);
5783
};
5884

5985
Literal LiteralCaster::BelowMinLiteral(std::shared_ptr<PrimitiveType> type) {
@@ -76,6 +102,8 @@ Result<Literal> LiteralCaster::CastFromInt(
76102
return Literal::Float(static_cast<float>(int_val));
77103
case TypeId::kDouble:
78104
return Literal::Double(static_cast<double>(int_val));
105+
case TypeId::kDate:
106+
return Literal::Date(int_val);
79107
default:
80108
return NotSupported("Cast from Int to {} is not implemented",
81109
target_type->ToString());
@@ -85,15 +113,14 @@ Result<Literal> LiteralCaster::CastFromInt(
85113
Result<Literal> LiteralCaster::CastFromLong(
86114
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
87115
auto long_val = std::get<int64_t>(literal.value_);
88-
auto target_type_id = target_type->type_id();
89116

90-
switch (target_type_id) {
117+
switch (target_type->type_id()) {
91118
case TypeId::kInt: {
92119
// Check for overflow
93-
if (long_val >= std::numeric_limits<int32_t>::max()) {
120+
if (long_val > std::numeric_limits<int32_t>::max()) {
94121
return AboveMaxLiteral(target_type);
95122
}
96-
if (long_val <= std::numeric_limits<int32_t>::min()) {
123+
if (long_val < std::numeric_limits<int32_t>::min()) {
97124
return BelowMinLiteral(target_type);
98125
}
99126
return Literal::Int(static_cast<int32_t>(long_val));
@@ -102,6 +129,21 @@ Result<Literal> LiteralCaster::CastFromLong(
102129
return Literal::Float(static_cast<float>(long_val));
103130
case TypeId::kDouble:
104131
return Literal::Double(static_cast<double>(long_val));
132+
case TypeId::kDate: {
133+
if (long_val > std::numeric_limits<int32_t>::max()) {
134+
return AboveMaxLiteral(target_type);
135+
}
136+
if (long_val < std::numeric_limits<int32_t>::min()) {
137+
return BelowMinLiteral(target_type);
138+
}
139+
return Literal::Date(static_cast<int32_t>(long_val));
140+
}
141+
case TypeId::kTime:
142+
return Literal::Time(long_val);
143+
case TypeId::kTimestamp:
144+
return Literal::Timestamp(long_val);
145+
case TypeId::kTimestampTz:
146+
return Literal::TimestampTz(long_val);
105147
default:
106148
return NotSupported("Cast from Long to {} is not supported",
107149
target_type->ToString());
@@ -111,9 +153,8 @@ Result<Literal> LiteralCaster::CastFromLong(
111153
Result<Literal> LiteralCaster::CastFromFloat(
112154
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
113155
auto float_val = std::get<float>(literal.value_);
114-
auto target_type_id = target_type->type_id();
115156

116-
switch (target_type_id) {
157+
switch (target_type->type_id()) {
117158
case TypeId::kDouble:
118159
return Literal::Double(static_cast<double>(float_val));
119160
default:
@@ -122,6 +163,103 @@ Result<Literal> LiteralCaster::CastFromFloat(
122163
}
123164
}
124165

166+
Result<Literal> LiteralCaster::CastFromDouble(
167+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
168+
auto double_val = std::get<double>(literal.value_);
169+
170+
switch (target_type->type_id()) {
171+
case TypeId::kFloat: {
172+
if (double_val > static_cast<double>(std::numeric_limits<float>::max())) {
173+
return AboveMaxLiteral(target_type);
174+
}
175+
if (double_val < static_cast<double>(std::numeric_limits<float>::lowest())) {
176+
return BelowMinLiteral(target_type);
177+
}
178+
return Literal::Float(static_cast<float>(double_val));
179+
}
180+
default:
181+
return NotSupported("Cast from Double to {} is not supported",
182+
target_type->ToString());
183+
}
184+
}
185+
186+
Result<Literal> LiteralCaster::CastFromString(
187+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
188+
const auto& str_val = std::get<std::string>(literal.value_);
189+
190+
switch (target_type->type_id()) {
191+
case TypeId::kDate:
192+
case TypeId::kTime:
193+
case TypeId::kTimestamp:
194+
case TypeId::kTimestampTz:
195+
case TypeId::kUuid:
196+
return NotImplemented("Cast from String to {} is not implemented yet",
197+
target_type->ToString());
198+
default:
199+
return NotSupported("Cast from String to {} is not supported",
200+
target_type->ToString());
201+
}
202+
}
203+
204+
Result<Literal> LiteralCaster::CastFromTimestamp(
205+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
206+
auto timestamp_val = std::get<int64_t>(literal.value_);
207+
208+
switch (target_type->type_id()) {
209+
case TypeId::kDate:
210+
return NotImplemented("Cast from Timestamp to Date is not implemented yet");
211+
case TypeId::kTimestampTz:
212+
return Literal::TimestampTz(timestamp_val);
213+
default:
214+
return NotSupported("Cast from Timestamp to {} is not supported",
215+
target_type->ToString());
216+
}
217+
}
218+
219+
Result<Literal> LiteralCaster::CastFromTimestampTz(
220+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
221+
auto micros = std::get<int64_t>(literal.value_);
222+
223+
switch (target_type->type_id()) {
224+
case TypeId::kDate:
225+
return NotImplemented("Cast from TimestampTz to Date is not implemented yet");
226+
case TypeId::kTimestamp:
227+
return Literal::Timestamp(micros);
228+
default:
229+
return NotSupported("Cast from TimestampTz to {} is not supported",
230+
target_type->ToString());
231+
}
232+
}
233+
234+
Result<Literal> LiteralCaster::CastFromBinary(
235+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
236+
auto binary_val = std::get<std::vector<uint8_t>>(literal.value_);
237+
switch (target_type->type_id()) {
238+
case TypeId::kFixed: {
239+
auto target_fixed_type = internal::checked_pointer_cast<FixedType>(target_type);
240+
if (binary_val.size() == target_fixed_type->length()) {
241+
return Literal::Fixed(std::move(binary_val));
242+
}
243+
return InvalidArgument("Failed to cast Binary with length {} to Fixed({})",
244+
binary_val.size(), target_fixed_type->length());
245+
}
246+
default:
247+
return NotSupported("Cast from Binary to {} is not supported",
248+
target_type->ToString());
249+
}
250+
}
251+
252+
Result<Literal> LiteralCaster::CastFromFixed(
253+
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
254+
switch (target_type->type_id()) {
255+
case TypeId::kBinary:
256+
return Literal::Binary(std::get<std::vector<uint8_t>>(literal.value_));
257+
default:
258+
return NotSupported("Cast from Fixed to {} is not supported",
259+
target_type->ToString());
260+
}
261+
}
262+
125263
// Constructor
126264
Literal::Literal(Value value, std::shared_ptr<PrimitiveType> type)
127265
: value_(std::move(value)), type_(std::move(type)) {}
@@ -154,8 +292,8 @@ Literal Literal::Binary(std::vector<uint8_t> value) {
154292
}
155293

156294
Literal Literal::Fixed(std::vector<uint8_t> value) {
157-
auto length = static_cast<int32_t>(value.size());
158-
return {Value{std::move(value)}, fixed(length)};
295+
const auto size = value.size();
296+
return {Value{std::move(value)}, fixed(size)};
159297
}
160298

161299
Result<Literal> Literal::Deserialize(std::span<const uint8_t> data,
@@ -262,12 +400,7 @@ std::partial_ordering Literal::operator<=>(const Literal& other) const {
262400
return std::partial_ordering::unordered;
263401
}
264402

265-
case TypeId::kBinary: {
266-
auto& this_val = std::get<std::vector<uint8_t>>(value_);
267-
auto& other_val = std::get<std::vector<uint8_t>>(other.value_);
268-
return this_val <=> other_val;
269-
}
270-
403+
case TypeId::kBinary:
271404
case TypeId::kFixed: {
272405
auto& this_val = std::get<std::vector<uint8_t>>(value_);
273406
auto& other_val = std::get<std::vector<uint8_t>>(other.value_);
@@ -308,38 +441,32 @@ std::string Literal::ToString() const {
308441
return std::to_string(std::get<double>(value_));
309442
}
310443
case TypeId::kString: {
311-
return std::get<std::string>(value_);
444+
return "\"" + std::get<std::string>(value_) + "\"";
312445
}
313446
case TypeId::kUuid: {
314447
return std::get<Uuid>(value_).ToString();
315448
}
316-
case TypeId::kBinary: {
449+
case TypeId::kBinary:
450+
case TypeId::kFixed: {
317451
const auto& binary_data = std::get<std::vector<uint8_t>>(value_);
318-
std::string result;
319-
result.reserve(binary_data.size() * 2); // 2 chars per byte
452+
std::string result = "X'";
453+
result.reserve(/*prefix*/ 2 + /*suffix*/ 1 + /*data*/ binary_data.size() * 2);
320454
for (const auto& byte : binary_data) {
321455
std::format_to(std::back_inserter(result), "{:02X}", byte);
322456
}
457+
result.push_back('\'');
323458
return result;
324459
}
325-
case TypeId::kFixed: {
326-
const auto& fixed_data = std::get<std::vector<uint8_t>>(value_);
327-
std::string result;
328-
result.reserve(fixed_data.size() * 2); // 2 chars per byte
329-
for (const auto& byte : fixed_data) {
330-
std::format_to(std::back_inserter(result), "{:02X}", byte);
331-
}
332-
return result;
333-
}
334-
case TypeId::kDecimal:
335-
case TypeId::kDate:
336460
case TypeId::kTime:
337461
case TypeId::kTimestamp:
338462
case TypeId::kTimestampTz: {
339-
throw IcebergError("Not implemented: ToString for " + type_->ToString());
463+
return std::to_string(std::get<int64_t>(value_));
464+
}
465+
case TypeId::kDate: {
466+
return std::to_string(std::get<int32_t>(value_));
340467
}
341468
default: {
342-
throw IcebergError("Unknown type: " + type_->ToString());
469+
return std::format("invalid literal of type {}", type_->ToString());
343470
}
344471
}
345472
}
@@ -371,22 +498,32 @@ Result<Literal> LiteralCaster::CastTo(const Literal& literal,
371498

372499
// Delegate to specific cast functions based on source type
373500
switch (source_type_id) {
501+
case TypeId::kBoolean:
502+
// No casts defined for Boolean, other than to itself.
503+
break;
374504
case TypeId::kInt:
375505
return CastFromInt(literal, target_type);
376506
case TypeId::kLong:
377507
return CastFromLong(literal, target_type);
378508
case TypeId::kFloat:
379509
return CastFromFloat(literal, target_type);
380510
case TypeId::kDouble:
381-
case TypeId::kBoolean:
511+
return CastFromDouble(literal, target_type);
382512
case TypeId::kString:
513+
return CastFromString(literal, target_type);
383514
case TypeId::kBinary:
384-
break;
515+
return CastFromBinary(literal, target_type);
516+
case TypeId::kFixed:
517+
return CastFromFixed(literal, target_type);
518+
case TypeId::kTimestamp:
519+
return CastFromTimestamp(literal, target_type);
520+
case TypeId::kTimestampTz:
521+
return CastFromTimestampTz(literal, target_type);
385522
default:
386523
break;
387524
}
388525

389-
return NotSupported("Cast from {} to {} is not implemented", literal.type_->ToString(),
526+
return NotSupported("Cast from {} to {} is not supported", literal.type_->ToString(),
390527
target_type->ToString());
391528
}
392529

src/iceberg/expression/predicate.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,11 @@ std::string UnboundPredicate<B>::ToString() const {
100100
return values_.size() == 1 ? std::format("{} != {}", term, values_[0])
101101
: invalid_predicate_string(op);
102102
case Expression::Operation::kStartsWith:
103-
return values_.size() == 1 ? std::format("{} startsWith \"{}\"", term, values_[0])
103+
return values_.size() == 1 ? std::format("{} startsWith {}", term, values_[0])
104104
: invalid_predicate_string(op);
105105
case Expression::Operation::kNotStartsWith:
106-
return values_.size() == 1
107-
? std::format("{} notStartsWith \"{}\"", term, values_[0])
108-
: invalid_predicate_string(op);
106+
return values_.size() == 1 ? std::format("{} notStartsWith {}", term, values_[0])
107+
: invalid_predicate_string(op);
109108
case Expression::Operation::kIn:
110109
return std::format("{} in {}", term, values_);
111110
case Expression::Operation::kNotIn:

0 commit comments

Comments
 (0)