Skip to content

Commit 7e3aea7

Browse files
committed
Fix protobuf enums
1 parent bffb7ba commit 7e3aea7

File tree

8 files changed

+75
-66
lines changed

8 files changed

+75
-66
lines changed

src/errors/arrow_error.rs

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,52 @@ use datafusion::arrow::error::ArrowError;
33

44
#[derive(Clone, PartialEq, ::prost::Message)]
55
pub struct ArrowErrorProto {
6-
#[prost(oneof = "ArrowErrorInnerProto", tags = "1")]
7-
pub inner: Option<ArrowErrorInnerProto>,
8-
#[prost(string, optional, tag = "2")]
6+
#[prost(string, optional, tag = "1")]
97
pub ctx: Option<String>,
8+
#[prost(
9+
oneof = "ArrowErrorInnerProto",
10+
tags = "2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19"
11+
)]
12+
pub inner: Option<ArrowErrorInnerProto>,
1013
}
1114

1215
#[derive(Clone, PartialEq, prost::Oneof)]
1316
pub enum ArrowErrorInnerProto {
14-
#[prost(string, tag = "1")]
15-
NotYetImplemented(String),
1617
#[prost(string, tag = "2")]
17-
ExternalError(String),
18+
NotYetImplemented(String),
1819
#[prost(string, tag = "3")]
19-
CastError(String),
20+
ExternalError(String),
2021
#[prost(string, tag = "4")]
21-
MemoryError(String),
22+
CastError(String),
2223
#[prost(string, tag = "5")]
23-
ParseError(String),
24+
MemoryError(String),
2425
#[prost(string, tag = "6")]
25-
SchemaError(String),
26+
ParseError(String),
2627
#[prost(string, tag = "7")]
28+
SchemaError(String),
29+
#[prost(string, tag = "8")]
2730
ComputeError(String),
28-
#[prost(bool, tag = "8")]
31+
#[prost(bool, tag = "9")]
2932
DivideByZero(bool),
30-
#[prost(string, tag = "9")]
31-
ArithmeticOverflow(String),
3233
#[prost(string, tag = "10")]
33-
CsvError(String),
34+
ArithmeticOverflow(String),
3435
#[prost(string, tag = "11")]
36+
CsvError(String),
37+
#[prost(string, tag = "12")]
3538
JsonError(String),
36-
#[prost(message, tag = "12")]
37-
IoError(IoErrorProto),
3839
#[prost(message, tag = "13")]
39-
IpcError(String),
40+
IoError(IoErrorProto),
4041
#[prost(message, tag = "14")]
41-
InvalidArgumentError(String),
42+
IpcError(String),
4243
#[prost(message, tag = "15")]
43-
ParquetError(String),
44+
InvalidArgumentError(String),
4445
#[prost(message, tag = "16")]
46+
ParquetError(String),
47+
#[prost(message, tag = "17")]
4548
CDataInterface(String),
46-
#[prost(bool, tag = "17")]
47-
DictionaryKeyOverflowError(bool),
4849
#[prost(bool, tag = "18")]
50+
DictionaryKeyOverflowError(bool),
51+
#[prost(bool, tag = "19")]
4952
RunEndIndexOverflowError(bool),
5053
}
5154

@@ -180,6 +183,7 @@ impl ArrowErrorProto {
180183
#[cfg(test)]
181184
mod tests {
182185
use super::*;
186+
use prost::Message;
183187
use std::io::{Error as IoError, ErrorKind};
184188

185189
#[test]
@@ -216,12 +220,20 @@ mod tests {
216220
&original_error,
217221
Some(&"test context".to_string()),
218222
);
223+
let proto = ArrowErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
219224
let (recovered_error, recovered_ctx) = proto.to_arrow_error();
220225

226+
if original_error.to_string() != recovered_error.to_string() {
227+
println!("original error: {}", original_error.to_string());
228+
println!("recovered error: {}", recovered_error.to_string());
229+
}
230+
221231
assert_eq!(original_error.to_string(), recovered_error.to_string());
222232
assert_eq!(recovered_ctx, Some("test context".to_string()));
223233

224234
let proto_no_ctx = ArrowErrorProto::from_arrow_error(&original_error, None);
235+
let proto_no_ctx =
236+
ArrowErrorProto::decode(proto_no_ctx.encode_to_vec().as_ref()).unwrap();
225237
let (recovered_error_no_ctx, recovered_ctx_no_ctx) = proto_no_ctx.to_arrow_error();
226238

227239
assert_eq!(

src/errors/datafusion_error.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use std::sync::Arc;
1111

1212
#[derive(Clone, PartialEq, ::prost::Message)]
1313
pub struct DataFusionErrorProto {
14-
#[prost(oneof = "DataFusionErrorInnerProto", tags = "1")]
14+
#[prost(
15+
oneof = "DataFusionErrorInnerProto",
16+
tags = "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19"
17+
)]
1518
pub inner: Option<DataFusionErrorInnerProto>,
1619
}
1720

@@ -175,9 +178,7 @@ impl DataFusionErrorProto {
175178

176179
pub fn to_datafusion_err(&self) -> DataFusionError {
177180
let Some(ref inner) = self.inner else {
178-
return DataFusionError::Internal(
179-
"Malformed DataFusion error proto message".to_string(),
180-
);
181+
return DataFusionError::Internal("DataFusionError proto message is empty".to_string());
181182
};
182183

183184
match inner {
@@ -274,6 +275,7 @@ mod tests {
274275
use datafusion::logical_expr::sqlparser::parser::ParserError;
275276
use datafusion::parquet::errors::ParquetError;
276277
use object_store::Error as ObjectStoreError;
278+
use prost::Message;
277279
use std::io::{Error as IoError, ErrorKind};
278280
use std::sync::Arc;
279281

@@ -326,6 +328,7 @@ mod tests {
326328

327329
for original_error in test_cases {
328330
let proto = DataFusionErrorProto::from_datafusion_error(&original_error);
331+
let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
329332
let recovered_error = proto.to_datafusion_err();
330333

331334
assert_eq!(original_error.to_string(), recovered_error.to_string());
@@ -350,6 +353,7 @@ mod tests {
350353
);
351354

352355
let proto = DataFusionErrorProto::from_datafusion_error(&nested_error);
356+
let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
353357
let recovered_error = proto.to_datafusion_err();
354358

355359
assert_eq!(nested_error.to_string(), recovered_error.to_string());
@@ -364,6 +368,7 @@ mod tests {
364368
]);
365369

366370
let proto = DataFusionErrorProto::from_datafusion_error(&collection_error);
371+
let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
367372
let recovered_error = proto.to_datafusion_err();
368373

369374
assert_eq!(collection_error.to_string(), recovered_error.to_string());
@@ -377,6 +382,7 @@ mod tests {
377382
);
378383

379384
let proto = DataFusionErrorProto::from_datafusion_error(&sql_error);
385+
let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
380386
let recovered_error = proto.to_datafusion_err();
381387

382388
if let DataFusionError::SQL(_, backtrace) = recovered_error {

src/errors/io_error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ mod tests {
140140
for (kind, msg) in test_cases {
141141
let original_error = IoError::new(kind, msg);
142142
let proto = IoErrorProto::from_io_error("test message", &original_error);
143+
let proto = IoErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
143144
let (recovered_error, recovered_message) = proto.to_io_error();
144145

145146
assert_eq!(original_error.kind(), recovered_error.kind());

src/errors/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ mod parser_error;
1212
mod schema_error;
1313

1414
pub fn datafusion_error_to_tonic_status(err: &DataFusionError) -> tonic::Status {
15-
let err = DataFusionErrorProto::from_datafusion_error(err).encode_to_vec();
15+
let err = DataFusionErrorProto::from_datafusion_error(err);
16+
let err = err.encode_to_vec();
1617
let status = tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into());
1718
status
1819
}
@@ -27,7 +28,10 @@ pub fn tonic_status_to_datafusion_error(status: &tonic::Status) -> Option<DataFu
2728
}
2829

2930
match DataFusionErrorProto::decode(status.details()) {
30-
Ok(err_proto) => Some(err_proto.to_datafusion_err()),
31+
Ok(err_proto) => {
32+
dbg!(&err_proto);
33+
Some(err_proto.to_datafusion_err())
34+
}
3135
Err(err) => Some(internal_datafusion_err!(
3236
"Cannot decode DataFusionError: {err}"
3337
)),

src/errors/objectstore_error.rs

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#[derive(Clone, PartialEq, ::prost::Message)]
22
pub struct ObjectStoreErrorProto {
3-
#[prost(oneof = "ObjectStoreErrorInnerProto", tags = "1")]
3+
#[prost(
4+
oneof = "ObjectStoreErrorInnerProto",
5+
tags = "1,2,3,4,5,6,7,8,9,10,11,12"
6+
)]
47
pub inner: Option<ObjectStoreErrorInnerProto>,
58
}
69

@@ -255,6 +258,7 @@ fn parse_store(store: &str) -> &'static str {
255258
mod tests {
256259
use super::*;
257260
use object_store::Error as ObjectStoreError;
261+
use prost::Message;
258262
use std::io::ErrorKind;
259263

260264
#[test]
@@ -304,30 +308,13 @@ mod tests {
304308

305309
for original_error in test_cases {
306310
let proto = ObjectStoreErrorProto::from_object_store_error(&original_error);
311+
let proto = ObjectStoreErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
307312
let recovered_error = proto.to_object_store_error();
308313

309314
assert_eq!(original_error.to_string(), recovered_error.to_string());
310315
}
311316
}
312317

313-
#[test]
314-
fn test_unknown_store_handling() {
315-
// Test that unknown store names get mapped to "Unknown"
316-
let original_error = ObjectStoreError::Generic {
317-
store: "unknown_store",
318-
source: Box::new(std::io::Error::new(ErrorKind::Other, "generic error")),
319-
};
320-
let proto = ObjectStoreErrorProto::from_object_store_error(&original_error);
321-
let recovered_error = proto.to_object_store_error();
322-
323-
// The store name will be changed from "unknown_store" to "Unknown"
324-
assert_eq!(
325-
recovered_error.to_string(),
326-
"Generic Unknown error: generic error"
327-
);
328-
assert_ne!(original_error.to_string(), recovered_error.to_string());
329-
}
330-
331318
#[test]
332319
fn test_malformed_protobuf_message() {
333320
let malformed_proto = ObjectStoreErrorProto { inner: None };

src/errors/parquet_error.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use datafusion::parquet::errors::ParquetError;
22

33
#[derive(Clone, PartialEq, ::prost::Message)]
44
pub struct ParquetErrorProto {
5-
#[prost(oneof = "ParquetErrorInnerProto", tags = "1")]
5+
#[prost(oneof = "ParquetErrorInnerProto", tags = "1,2,3,4,5,6,7")]
66
pub inner: Option<ParquetErrorInnerProto>,
77
}
88

@@ -113,22 +113,13 @@ mod tests {
113113

114114
for original_error in test_cases {
115115
let proto = ParquetErrorProto::from_parquet_error(&original_error);
116+
let proto = ParquetErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
116117
let recovered_error = proto.to_parquet_error();
117118

118119
assert_eq!(original_error.to_string(), recovered_error.to_string());
119120
}
120121
}
121122

122-
#[test]
123-
fn test_protobuf_serialization() {
124-
let original_error = ParquetError::General("general error".to_string());
125-
let proto = ParquetErrorProto::from_parquet_error(&original_error);
126-
let proto = ParquetErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
127-
let recovered_error = proto.to_parquet_error();
128-
129-
assert_eq!(original_error.to_string(), recovered_error.to_string());
130-
}
131-
132123
#[test]
133124
fn test_malformed_protobuf_message() {
134125
let malformed_proto = ParquetErrorProto { inner: None };

src/errors/parser_error.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use datafusion::sql::sqlparser::parser::ParserError;
22

33
#[derive(Clone, PartialEq, ::prost::Message)]
44
pub struct ParserErrorProto {
5-
#[prost(oneof = "ParserErrorInnerProto", tags = "1")]
5+
#[prost(oneof = "ParserErrorInnerProto", tags = "1,2,3")]
66
pub inner: Option<ParserErrorInnerProto>,
77
}
88

@@ -50,6 +50,7 @@ impl ParserErrorProto {
5050
mod tests {
5151
use super::*;
5252
use datafusion::sql::sqlparser::parser::ParserError;
53+
use prost::Message;
5354

5455
#[test]
5556
fn test_parser_error_roundtrip() {
@@ -61,6 +62,7 @@ mod tests {
6162

6263
for original_error in test_cases {
6364
let proto = ParserErrorProto::from_parser_error(&original_error);
65+
let proto = ParserErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
6466
let recovered_error = proto.to_parser_error();
6567

6668
assert_eq!(original_error.to_string(), recovered_error.to_string());

src/errors/schema_error.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@ use datafusion::common::{Column, SchemaError, TableReference};
22

33
#[derive(Clone, PartialEq, ::prost::Message)]
44
pub struct SchemaErrorProto {
5-
#[prost(oneof = "SchemaErrorInnerProto", tags = "1")]
6-
pub inner: Option<SchemaErrorInnerProto>,
7-
#[prost(string, optional, tag = "2")]
5+
#[prost(string, optional, tag = "1")]
86
pub backtrace: Option<String>,
7+
#[prost(oneof = "SchemaErrorInnerProto", tags = "2,3,4,5")]
8+
pub inner: Option<SchemaErrorInnerProto>,
99
}
1010

1111
#[derive(Clone, PartialEq, prost::Oneof)]
1212
pub enum SchemaErrorInnerProto {
13-
#[prost(message, tag = "1")]
14-
AmbiguousReference(AmbiguousReferenceProto),
1513
#[prost(message, tag = "2")]
16-
DuplicateQualifiedField(DuplicateQualifiedFieldProto),
14+
AmbiguousReference(AmbiguousReferenceProto),
1715
#[prost(message, tag = "3")]
18-
DuplicateUnqualifiedField(DuplicateUnqualifiedFieldProto),
16+
DuplicateQualifiedField(DuplicateQualifiedFieldProto),
1917
#[prost(message, tag = "4")]
18+
DuplicateUnqualifiedField(DuplicateUnqualifiedFieldProto),
19+
#[prost(message, tag = "5")]
2020
FieldNotFound(FieldNotFoundProto),
2121
}
2222

@@ -75,7 +75,7 @@ impl ColumnProto {
7575

7676
#[derive(Clone, PartialEq, ::prost::Message)]
7777
pub struct TableReferenceProto {
78-
#[prost(oneof = "TableReferenceInnerProto", tags = "1")]
78+
#[prost(oneof = "TableReferenceInnerProto", tags = "1,2,3")]
7979
pub inner: Option<TableReferenceInnerProto>,
8080
}
8181

@@ -260,6 +260,7 @@ impl SchemaErrorProto {
260260
mod tests {
261261
use super::*;
262262
use datafusion::common::{Column, SchemaError, TableReference};
263+
use prost::Message;
263264

264265
#[test]
265266
fn test_schema_error_roundtrip() {
@@ -291,12 +292,15 @@ mod tests {
291292
&original_error,
292293
Some(&"test backtrace".to_string()),
293294
);
295+
let proto = SchemaErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap();
294296
let (recovered_error, recovered_backtrace) = proto.to_schema_error();
295297

296298
assert_eq!(original_error.to_string(), recovered_error.to_string());
297299
assert_eq!(recovered_backtrace, Some("test backtrace".to_string()));
298300

299301
let proto_no_backtrace = SchemaErrorProto::from_schema_error(&original_error, None);
302+
let proto_no_backtrace =
303+
SchemaErrorProto::decode(proto_no_backtrace.encode_to_vec().as_ref()).unwrap();
300304
let (recovered_error_no_backtrace, recovered_backtrace_no_backtrace) =
301305
proto_no_backtrace.to_schema_error();
302306

@@ -328,6 +332,7 @@ mod tests {
328332

329333
for original_ref in test_cases {
330334
let proto = TableReferenceProto::from_table_reference(&original_ref);
335+
let proto = TableReferenceProto::decode(proto.encode_to_vec().as_ref()).unwrap();
331336
let recovered_ref = proto.to_table_reference();
332337

333338
assert_eq!(original_ref.to_string(), recovered_ref.to_string());
@@ -344,6 +349,7 @@ mod tests {
344349

345350
for original_column in test_cases {
346351
let proto = ColumnProto::from_column(&original_column);
352+
let proto = ColumnProto::decode(proto.encode_to_vec().as_ref()).unwrap();
347353
let recovered_column = proto.to_column();
348354

349355
assert_eq!(original_column.name, recovered_column.name);

0 commit comments

Comments
 (0)