Skip to content

Commit e1dd11d

Browse files
committed
Resolve comments and add impl for compare / cast
1 parent 082c7e1 commit e1dd11d

File tree

2 files changed

+265
-54
lines changed

2 files changed

+265
-54
lines changed

src/iceberg/datum.cc

Lines changed: 218 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,39 @@ PrimitiveLiteral::PrimitiveLiteral(PrimitiveLiteralValue value,
3232

3333
// Factory methods
3434
PrimitiveLiteral PrimitiveLiteral::Boolean(bool value) {
35-
return PrimitiveLiteral(value, std::make_shared<BooleanType>());
35+
return {PrimitiveLiteralValue{value}, std::make_shared<BooleanType>()};
3636
}
3737

38-
PrimitiveLiteral PrimitiveLiteral::Integer(int32_t value) {
39-
return PrimitiveLiteral(value, std::make_shared<IntType>());
38+
PrimitiveLiteral PrimitiveLiteral::Int(int32_t value) {
39+
return {PrimitiveLiteralValue{value}, std::make_shared<IntType>()};
4040
}
4141

4242
PrimitiveLiteral PrimitiveLiteral::Long(int64_t value) {
43-
return PrimitiveLiteral(value, std::make_shared<LongType>());
43+
return {PrimitiveLiteralValue{value}, std::make_shared<LongType>()};
4444
}
4545

4646
PrimitiveLiteral PrimitiveLiteral::Float(float value) {
47-
return PrimitiveLiteral(value, std::make_shared<FloatType>());
47+
return {PrimitiveLiteralValue{value}, std::make_shared<FloatType>()};
4848
}
4949

5050
PrimitiveLiteral PrimitiveLiteral::Double(double value) {
51-
return PrimitiveLiteral(value, std::make_shared<DoubleType>());
51+
return {PrimitiveLiteralValue{value}, std::make_shared<DoubleType>()};
5252
}
5353

5454
PrimitiveLiteral PrimitiveLiteral::String(std::string value) {
55-
return PrimitiveLiteral(std::move(value), std::make_shared<StringType>());
55+
return {PrimitiveLiteralValue{std::move(value)}, std::make_shared<StringType>()};
5656
}
5757

5858
PrimitiveLiteral PrimitiveLiteral::Binary(std::vector<uint8_t> value) {
59-
return PrimitiveLiteral(std::move(value), std::make_shared<BinaryType>());
59+
return {PrimitiveLiteralValue{std::move(value)}, std::make_shared<BinaryType>()};
60+
}
61+
62+
// Builds a literal of the given type whose value is the BelowMin sentinel.
PrimitiveLiteral PrimitiveLiteral::BelowMinLiteral(std::shared_ptr<PrimitiveType> type) {
  PrimitiveLiteralValue sentinel{BelowMin{}};
  return PrimitiveLiteral(std::move(sentinel), std::move(type));
}
65+
66+
// Builds a literal of the given type whose value is the AboveMax sentinel.
PrimitiveLiteral PrimitiveLiteral::AboveMaxLiteral(std::shared_ptr<PrimitiveType> type) {
  PrimitiveLiteralValue sentinel{AboveMax{}};
  return PrimitiveLiteral(std::move(sentinel), std::move(type));
}
6169

6270
Result<PrimitiveLiteral> PrimitiveLiteral::Deserialize(std::span<const uint8_t> data) {
@@ -68,7 +76,6 @@ Result<std::vector<uint8_t>> PrimitiveLiteral::Serialize() const {
6876
}
6977

7078
// Getters
71-
const PrimitiveLiteralValue& PrimitiveLiteral::value() const { return value_; }
7279

7380
const std::shared_ptr<PrimitiveType>& PrimitiveLiteral::type() const { return type_; }
7481

@@ -80,8 +87,109 @@ Result<PrimitiveLiteral> PrimitiveLiteral::CastTo(
8087
return PrimitiveLiteral(value_, target_type);
8188
}
8289

83-
return NotImplemented("Cast from {} to {} is not implemented", type_->ToString(),
84-
target_type->ToString());
90+
// Handle special values
91+
if (std::holds_alternative<BelowMin>(value_) ||
92+
std::holds_alternative<AboveMax>(value_)) {
93+
// Cannot cast type for special values
94+
return NotSupported("Cannot cast type for {}", ToString());
95+
}
96+
97+
auto source_type_id = type_->type_id();
98+
auto target_type_id = target_type->type_id();
99+
100+
// Delegate to specific cast functions based on source type
101+
switch (source_type_id) {
102+
case TypeId::kInt:
103+
return CastFromInt(target_type_id);
104+
case TypeId::kLong:
105+
return CastFromLong(target_type_id);
106+
case TypeId::kFloat:
107+
return CastFromFloat(target_type_id);
108+
case TypeId::kDouble:
109+
return CastFromDouble(target_type_id);
110+
case TypeId::kBoolean:
111+
case TypeId::kString:
112+
case TypeId::kBinary:
113+
// These types only support conversion to string (handled above)
114+
break;
115+
default:
116+
break;
117+
}
118+
119+
return NotSupported("Cast from {} to {} is not implemented", type_->ToString(),
120+
target_type->ToString());
121+
}
122+
123+
// Casts the stored int32_t value to the primitive type identified by
// `target_type_id`. Long, float and double targets are converted with a
// static_cast; any other target returns the NotSupported error below.
Result<PrimitiveLiteral> PrimitiveLiteral::CastFromInt(TypeId target_type_id) const {
  const int32_t source = std::get<int32_t>(value_);

  if (target_type_id == TypeId::kLong) {
    return PrimitiveLiteral::Long(static_cast<int64_t>(source));
  }
  if (target_type_id == TypeId::kFloat) {
    return PrimitiveLiteral::Float(static_cast<float>(source));
  }
  if (target_type_id == TypeId::kDouble) {
    return PrimitiveLiteral::Double(static_cast<double>(source));
  }

  // TODO(mwish): Support casts to date and literal
  return NotSupported("Cast from Int to {} is not implemented",
                      static_cast<int>(target_type_id));
}
139+
140+
// Casts the stored int64_t value to the primitive type identified by
// `target_type_id`.
//
// - kInt: range-checked narrowing. A value strictly greater than INT32_MAX
//   yields an AboveMax literal and a value strictly less than INT32_MIN yields
//   a BelowMin literal; both sentinels carry the int type, i.e. the type being
//   cast to. Values equal to either bound fit in int32_t and are converted.
// - kFloat / kDouble: converted with a static_cast.
// - anything else: returns the NotImplemented error below.
Result<PrimitiveLiteral> PrimitiveLiteral::CastFromLong(TypeId target_type_id) const {
  const auto long_val = std::get<int64_t>(value_);

  switch (target_type_id) {
    case TypeId::kInt: {
      // INT32_MAX and INT32_MIN are themselves representable, so only values
      // strictly outside [INT32_MIN, INT32_MAX] overflow.
      if (long_val > std::numeric_limits<int32_t>::max()) {
        return PrimitiveLiteral::AboveMaxLiteral(std::make_shared<IntType>());
      }
      if (long_val < std::numeric_limits<int32_t>::min()) {
        return PrimitiveLiteral::BelowMinLiteral(std::make_shared<IntType>());
      }
      return PrimitiveLiteral::Int(static_cast<int32_t>(long_val));
    }
    case TypeId::kFloat:
      return PrimitiveLiteral::Float(static_cast<float>(long_val));
    case TypeId::kDouble:
      return PrimitiveLiteral::Double(static_cast<double>(long_val));
    default:
      return NotImplemented("Cast from Long to {} is not implemented",
                            static_cast<int>(target_type_id));
  }
}
163+
164+
// Casts the stored float value to the primitive type identified by
// `target_type_id`. Only widening to double is handled; every other target
// returns the NotImplemented error below.
Result<PrimitiveLiteral> PrimitiveLiteral::CastFromFloat(TypeId target_type_id) const {
  const float source = std::get<float>(value_);

  if (target_type_id != TypeId::kDouble) {
    return NotImplemented("Cast from Float to {} is not implemented",
                          static_cast<int>(target_type_id));
  }
  return PrimitiveLiteral::Double(static_cast<double>(source));
}
175+
176+
// Casts the stored double value to the primitive type identified by
// `target_type_id`.
//
// - kFloat: range-checked narrowing. A value greater than the largest finite
//   float yields an AboveMax literal and a value less than the lowest finite
//   float yields a BelowMin literal; both sentinels carry the float type, i.e.
//   the type being cast to. NaN compares false against both bounds and is
//   therefore converted with a static_cast like any in-range value.
// - anything else: returns the NotImplemented error below.
Result<PrimitiveLiteral> PrimitiveLiteral::CastFromDouble(TypeId target_type_id) const {
  const auto double_val = std::get<double>(value_);

  switch (target_type_id) {
    case TypeId::kFloat: {
      if (double_val > std::numeric_limits<float>::max()) {
        return PrimitiveLiteral::AboveMaxLiteral(std::make_shared<FloatType>());
      }
      if (double_val < std::numeric_limits<float>::lowest()) {
        return PrimitiveLiteral::BelowMinLiteral(std::make_shared<FloatType>());
      }
      return PrimitiveLiteral::Float(static_cast<float>(double_val));
    }
    default:
      return NotImplemented("Cast from Double to {} is not implemented",
                            static_cast<int>(target_type_id));
  }
}
86194

87195
// Three-way comparison operator
@@ -90,14 +198,109 @@ std::partial_ordering PrimitiveLiteral::operator<=>(const PrimitiveLiteral& othe
90198
if (type_->type_id() != other.type_->type_id()) {
91199
return std::partial_ordering::unordered;
92200
}
93-
if (value_ == other.value_) {
94-
return std::partial_ordering::equivalent;
201+
202+
// Same type comparison
203+
switch (type_->type_id()) {
204+
case TypeId::kBoolean: {
205+
auto this_val = std::get<bool>(value_);
206+
auto other_val = std::get<bool>(other.value_);
207+
if (this_val == other_val) return std::partial_ordering::equivalent;
208+
return this_val ? std::partial_ordering::greater : std::partial_ordering::less;
209+
}
210+
211+
case TypeId::kInt: {
212+
auto this_val = std::get<int32_t>(value_);
213+
auto other_val = std::get<int32_t>(other.value_);
214+
return this_val <=> other_val;
215+
}
216+
217+
case TypeId::kLong: {
218+
auto this_val = std::get<int64_t>(value_);
219+
auto other_val = std::get<int64_t>(other.value_);
220+
return this_val <=> other_val;
221+
}
222+
223+
case TypeId::kFloat: {
224+
auto this_val = std::get<float>(value_);
225+
auto other_val = std::get<float>(other.value_);
226+
// Use strong_ordering for floating point as spec requests
227+
return std::strong_order(this_val, other_val);
228+
}
229+
230+
case TypeId::kDouble: {
231+
auto this_val = std::get<double>(value_);
232+
auto other_val = std::get<double>(other.value_);
233+
// Use strong_ordering for floating point as spec requests
234+
return std::strong_order(this_val, other_val);
235+
}
236+
237+
case TypeId::kString: {
238+
auto& this_val = std::get<std::string>(value_);
239+
auto& other_val = std::get<std::string>(other.value_);
240+
return this_val <=> other_val;
241+
}
242+
243+
case TypeId::kBinary: {
244+
auto& this_val = std::get<std::vector<uint8_t>>(value_);
245+
auto& other_val = std::get<std::vector<uint8_t>>(other.value_);
246+
return this_val <=> other_val;
247+
}
248+
249+
default:
250+
// For unsupported types, return unordered
251+
return std::partial_ordering::unordered;
95252
}
96-
throw IcebergError("Not implemented: comparison between different primitive types");
97253
}
98254

99255
std::string PrimitiveLiteral::ToString() const {
100-
throw NotImplemented("ToString for PrimitiveLiteral is not implemented yet");
256+
if (std::holds_alternative<BelowMin>(value_)) {
257+
return "BelowMin";
258+
}
259+
if (std::holds_alternative<AboveMax>(value_)) {
260+
return "AboveMax";
261+
}
262+
263+
switch (type_->type_id()) {
264+
case TypeId::kBoolean: {
265+
return std::get<bool>(value_) ? "true" : "false";
266+
}
267+
case TypeId::kInt: {
268+
return std::to_string(std::get<int32_t>(value_));
269+
}
270+
case TypeId::kLong: {
271+
return std::to_string(std::get<int64_t>(value_));
272+
}
273+
case TypeId::kFloat: {
274+
return std::to_string(std::get<float>(value_));
275+
}
276+
case TypeId::kDouble: {
277+
return std::to_string(std::get<double>(value_));
278+
}
279+
case TypeId::kString: {
280+
return std::get<std::string>(value_);
281+
}
282+
case TypeId::kBinary: {
283+
const auto& binary_data = std::get<std::vector<uint8_t>>(value_);
284+
std::string result;
285+
result.reserve(binary_data.size() * 2); // 2 chars per byte
286+
for (const auto& byte : binary_data) {
287+
result += std::format("{:02X}", byte);
288+
}
289+
return result;
290+
}
291+
case TypeId::kDecimal:
292+
case TypeId::kUuid:
293+
case TypeId::kFixed:
294+
case TypeId::kDate:
295+
case TypeId::kTime:
296+
case TypeId::kTimestamp:
297+
case TypeId::kTimestampTz: {
298+
throw IcebergError("Not implemented: ToString for " + type_->ToString());
299+
}
300+
default: {
301+
throw IcebergError("Unknown type: " + type_->ToString());
302+
}
303+
}
101304
}
102305

103306
} // namespace iceberg

src/iceberg/datum.h

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,45 +30,44 @@
3030

3131
namespace iceberg {
3232

33-
/// \brief Exception type for values that are below the minimum allowed value for a
34-
/// primitive type.
35-
///
36-
/// When casting a value to a narrow primitive type, if the value exceeds the maximum of
37-
/// dest type, it might be above the maximum allowed value for that type.
38-
struct BelowMin {
39-
bool operator==(const BelowMin&) const = default;
40-
std::strong_ordering operator<=>(const BelowMin&) const = default;
41-
};
42-
43-
/// \brief Exception type for values that are above the maximum allowed value for a
44-
/// primitive type.
45-
///
46-
/// When casting a value to a narrow primitive type, if the value exceeds the maximum of
47-
/// dest type, it might be above the maximum allowed value for that type.
48-
struct AboveMax {
49-
bool operator==(const AboveMax&) const = default;
50-
std::strong_ordering operator<=>(const AboveMax&) const = default;
51-
};
52-
53-
using PrimitiveLiteralValue =
54-
std::variant<bool, // for boolean
55-
int32_t, // for int, date
56-
int64_t, // for long, timestamp, timestamp_tz, time
57-
float, // for float
58-
double, // for double
59-
std::string, // for string
60-
std::vector<uint8_t>, // for binary, fixed, decimal and uuid
61-
BelowMin, AboveMax>;
62-
6333
/// \brief PrimitiveLiteral is owned literal of a primitive type.
64-
class PrimitiveLiteral {
65-
public:
66-
explicit PrimitiveLiteral(PrimitiveLiteralValue value,
67-
std::shared_ptr<PrimitiveType> type);
34+
class ICEBERG_EXPORT PrimitiveLiteral {
35+
private:
36+
/// \brief Exception type for values that are below the minimum allowed value for a
37+
/// primitive type.
38+
///
39+
/// When casting a value to a narrower primitive type, if the value is less than the
40+
/// minimum of the dest type, it is below the minimum allowed value for that type.
41+
struct BelowMin {
42+
bool operator==(const BelowMin&) const = default;
43+
std::strong_ordering operator<=>(const BelowMin&) const = default;
44+
};
45+
46+
/// \brief Exception type for values that are above the maximum allowed value for a
47+
/// primitive type.
48+
///
49+
/// When casting a value to a narrow primitive type, if the value exceeds the maximum of
50+
/// dest type, it might be above the maximum allowed value for that type.
51+
struct AboveMax {
52+
bool operator==(const AboveMax&) const = default;
53+
std::strong_ordering operator<=>(const AboveMax&) const = default;
54+
};
55+
56+
using PrimitiveLiteralValue =
57+
std::variant<bool, // for boolean
58+
int32_t, // for int, date
59+
int64_t, // for long, timestamp, timestamp_tz, time
60+
float, // for float
61+
double, // for double
62+
std::string, // for string
63+
std::vector<uint8_t>, // for binary, fixed
64+
std::array<uint8_t, 16>, // for uuid and decimal
65+
BelowMin, AboveMax>;
6866

69-
// Factory methods for primitive types
67+
public:
68+
/// Factory methods for primitive types
7069
static PrimitiveLiteral Boolean(bool value);
71-
static PrimitiveLiteral Integer(int32_t value);
70+
static PrimitiveLiteral Int(int32_t value);
7271
static PrimitiveLiteral Long(int64_t value);
7372
static PrimitiveLiteral Float(float value);
7473
static PrimitiveLiteral Double(double value);
@@ -86,9 +85,6 @@ class PrimitiveLiteral {
8685
/// for reference.
8786
Result<std::vector<uint8_t>> Serialize() const;
8887

89-
/// Get the value as a variant
90-
const PrimitiveLiteralValue& value() const;
91-
9288
/// Get the Iceberg Type of the literal
9389
const std::shared_ptr<PrimitiveType>& type() const;
9490

@@ -100,6 +96,18 @@ class PrimitiveLiteral {
10096

10197
std::string ToString() const;
10298

99+
private:
100+
PrimitiveLiteral(PrimitiveLiteralValue value, std::shared_ptr<PrimitiveType> type);
101+
102+
static PrimitiveLiteral BelowMinLiteral(std::shared_ptr<PrimitiveType> type);
103+
static PrimitiveLiteral AboveMaxLiteral(std::shared_ptr<PrimitiveType> type);
104+
105+
// Helper methods for type casting
106+
Result<PrimitiveLiteral> CastFromInt(TypeId target_type_id) const;
107+
Result<PrimitiveLiteral> CastFromLong(TypeId target_type_id) const;
108+
Result<PrimitiveLiteral> CastFromFloat(TypeId target_type_id) const;
109+
Result<PrimitiveLiteral> CastFromDouble(TypeId target_type_id) const;
110+
103111
private:
104112
PrimitiveLiteralValue value_;
105113
std::shared_ptr<PrimitiveType> type_;

0 commit comments

Comments
 (0)