From a860defa2c9faa3d9d513bdc31bc12e123bfb06c Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sat, 4 Oct 2025 13:34:13 +0200 Subject: [PATCH] Use upstream composed extension codec --- src/common/composed_extension_codec.rs | 126 ------------------------- src/common/mod.rs | 2 - src/protobuf/distributed_codec.rs | 3 +- 3 files changed, 1 insertion(+), 130 deletions(-) delete mode 100644 src/common/composed_extension_codec.rs diff --git a/src/common/composed_extension_codec.rs b/src/common/composed_extension_codec.rs deleted file mode 100644 index d081989..0000000 --- a/src/common/composed_extension_codec.rs +++ /dev/null @@ -1,126 +0,0 @@ -use datafusion::common::internal_datafusion_err; -use datafusion::error::DataFusionError; -use datafusion::error::Result; -use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; -use datafusion::physical_plan::ExecutionPlan; -use datafusion_proto::physical_plan::PhysicalExtensionCodec; -use prost::Message; -use std::fmt::Debug; -use std::sync::Arc; -// Code taken from https://github.com/apache/datafusion/blob/10f41887fa40d7d425c19b07857f80115460a98e/datafusion/proto/src/physical_plan/mod.rs -// TODO: It's not yet on DF 49, once upgrading to DF 50 we can remove this - -/// DataEncoderTuple captures the position of the encoder -/// in the codec list that was used to encode the data and actual encoded data -#[derive(Clone, PartialEq, prost::Message)] -struct DataEncoderTuple { - /// The position of encoder used to encode data - /// (to be used for decoding) - #[prost(uint32, tag = 1)] - pub encoder_position: u32, - - #[prost(bytes, tag = 2)] - pub blob: Vec, -} - -/// A PhysicalExtensionCodec that tries one of multiple inner codecs -/// until one works -#[derive(Debug)] -pub struct ComposedPhysicalExtensionCodec { - codecs: Vec>, -} - -impl ComposedPhysicalExtensionCodec { - // Position in this codecs list is important as it will be used for decoding. - // If new codec is added it should go to last position. - pub fn new(codecs: Vec>) -> Self { - Self { codecs } - } - - fn decode_protobuf( - &self, - buf: &[u8], - decode: impl FnOnce(&dyn PhysicalExtensionCodec, &[u8]) -> Result, - ) -> Result { - let proto = - DataEncoderTuple::decode(buf).map_err(|e| DataFusionError::Internal(e.to_string()))?; - - let pos = proto.encoder_position as usize; - let codec = self.codecs.get(pos).ok_or_else(|| { - internal_datafusion_err!( - "Can't find required codec in position {pos} in codec list with {} elements", - self.codecs.len() - ) - })?; - - decode(codec.as_ref(), &proto.blob) - } - - fn encode_protobuf( - &self, - buf: &mut Vec, - mut encode: impl FnMut(&dyn PhysicalExtensionCodec, &mut Vec) -> Result<()>, - ) -> Result<(), DataFusionError> { - let mut data = vec![]; - let mut last_err = None; - let mut encoder_position = None; - - // find the encoder - for (position, codec) in self.codecs.iter().enumerate() { - match encode(codec.as_ref(), &mut data) { - Ok(_) => { - encoder_position = Some(position as u32); - break; - } - Err(err) => last_err = Some(err), - } - } - - let encoder_position = encoder_position.ok_or_else(|| { - last_err.unwrap_or_else(|| { - DataFusionError::NotImplemented("Empty list of composed codecs".to_owned()) - }) - })?; - - // encode with encoder position - let proto = DataEncoderTuple { - encoder_position, - blob: data, - }; - proto - .encode(buf) - .map_err(|e| DataFusionError::Internal(e.to_string())) - } -} - -impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { - fn try_decode( - &self, - buf: &[u8], - inputs: &[Arc], - registry: &dyn FunctionRegistry, - ) -> Result> { - self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, registry)) - } - - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { - self.encode_protobuf(buf, |codec, data| codec.try_encode(Arc::clone(&node), data)) - } - - fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { - self.decode_protobuf(buf, |codec, data| codec.try_decode_udf(name, data)) - } - - fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { - self.encode_protobuf(buf, |codec, data| codec.try_encode_udf(node, data)) - } - - fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { - self.decode_protobuf(buf, |codec, data| codec.try_decode_udaf(name, data)) - } - - fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { - self.encode_protobuf(buf, |codec, data| codec.try_encode_udaf(node, data)) - } -} diff --git a/src/common/mod.rs b/src/common/mod.rs index 2cdf6ee..fe9773f 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,9 +1,7 @@ mod callback_stream; -mod composed_extension_codec; mod partitioning; #[allow(unused)] pub mod ttl_map; pub(crate) use callback_stream::with_callback; -pub(crate) use composed_extension_codec::ComposedPhysicalExtensionCodec; pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props}; diff --git a/src/protobuf/distributed_codec.rs b/src/protobuf/distributed_codec.rs index 66de577..4107f46 100644 --- a/src/protobuf/distributed_codec.rs +++ b/src/protobuf/distributed_codec.rs @@ -1,5 +1,4 @@ use super::get_distributed_user_codecs; -use crate::common::ComposedPhysicalExtensionCodec; use crate::execution_plans::{NetworkCoalesceExec, NetworkCoalesceReady, NetworkShuffleReadyExec}; use crate::{NetworkShuffleExec, PartitionIsolatorExec}; use datafusion::arrow::datatypes::Schema; @@ -9,9 +8,9 @@ use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning; use datafusion_proto::physical_plan::to_proto::serialize_partitioning; +use datafusion_proto::physical_plan::{ComposedPhysicalExtensionCodec, PhysicalExtensionCodec}; use datafusion_proto::protobuf; use datafusion_proto::protobuf::proto_error; use prost::Message;