Skip to content

Commit dbb7da5

Browse files
authored
refactor: extract common logic out of split functions with style fix (#1058)
1 parent 0353098 commit dbb7da5

File tree

3 files changed: 77 additions (+77) and 182 deletions (-182).

src/ops/functions/split_by_separators.rs

Lines changed: 11 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,22 @@ use anyhow::{Context, Result};
22
use regex::Regex;
33
use std::sync::Arc;
44

5-
use crate::base::field_attrs;
65
use crate::ops::registry::ExecutorFactoryRegistry;
7-
use crate::ops::shared::split::{Position, set_output_positions};
6+
use crate::ops::shared::split::{Position, make_common_chunk_schema, set_output_positions};
87
use crate::{fields_value, ops::sdk::*};
98

109
#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
1110
#[serde(rename_all = "UPPERCASE")]
1211
enum KeepSep {
13-
NONE,
14-
LEFT,
15-
RIGHT,
12+
Left,
13+
Right,
1614
}
1715

1816
#[derive(Serialize, Deserialize)]
1917
struct Spec {
2018
// Python SDK provides defaults/values.
2119
separators_regex: Vec<String>,
22-
keep_separator: KeepSep,
20+
keep_separator: Option<KeepSep>,
2321
include_empty: bool,
2422
trim: bool,
2523
}
@@ -90,13 +88,13 @@ impl SimpleFunctionExecutor for Executor {
9088
let mut start = 0usize;
9189
for m in re.find_iter(full_text) {
9290
let end = match self.spec.keep_separator {
93-
KeepSep::LEFT => m.end(),
94-
KeepSep::NONE | KeepSep::RIGHT => m.start(),
91+
Some(KeepSep::Left) => m.end(),
92+
Some(KeepSep::Right) | None => m.start(),
9593
};
9694
add_range(start, end);
9795
start = match self.spec.keep_separator {
98-
KeepSep::RIGHT => m.start(),
99-
KeepSep::NONE | KeepSep::LEFT => m.end(),
96+
Some(KeepSep::Right) => m.start(),
97+
_ => m.end(),
10098
};
10199
}
102100
add_range(start, full_text.len());
@@ -154,50 +152,7 @@ impl SimpleFunctionFactoryBase for Factory {
154152
.required()?,
155153
};
156154

157-
// start/end structs exactly like SplitRecursively
158-
let pos_struct = schema::ValueType::Struct(schema::StructSchema {
159-
fields: Arc::new(vec![
160-
schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)),
161-
schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)),
162-
schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)),
163-
]),
164-
description: None,
165-
});
166-
167-
let mut struct_schema = StructSchema::default();
168-
let mut sb = StructSchemaBuilder::new(&mut struct_schema);
169-
sb.add_field(FieldSchema::new(
170-
"location",
171-
make_output_type(BasicValueType::Range),
172-
));
173-
sb.add_field(FieldSchema::new(
174-
"text",
175-
make_output_type(BasicValueType::Str),
176-
));
177-
sb.add_field(FieldSchema::new(
178-
"start",
179-
schema::EnrichedValueType {
180-
typ: pos_struct.clone(),
181-
nullable: false,
182-
attrs: Default::default(),
183-
},
184-
));
185-
sb.add_field(FieldSchema::new(
186-
"end",
187-
schema::EnrichedValueType {
188-
typ: pos_struct,
189-
nullable: false,
190-
attrs: Default::default(),
191-
},
192-
));
193-
let output_schema = make_output_type(TableSchema::new(
194-
TableKind::KTable(KTableInfo { num_key_parts: 1 }),
195-
struct_schema,
196-
))
197-
.with_attr(
198-
field_attrs::CHUNK_BASE_TEXT,
199-
serde_json::to_value(args_resolver.get_analyze_value(&args.text))?,
200-
);
155+
let output_schema = make_common_chunk_schema(args_resolver, &args.text)?;
201156
Ok((args, output_schema))
202157
}
203158

@@ -224,7 +179,7 @@ mod tests {
224179
async fn test_split_by_separators_paragraphs() {
225180
let spec = Spec {
226181
separators_regex: vec![r"\n\n+".to_string()],
227-
keep_separator: KeepSep::NONE,
182+
keep_separator: None,
228183
include_empty: false,
229184
trim: true,
230185
};
@@ -268,7 +223,7 @@ mod tests {
268223
async fn test_split_by_separators_keep_right() {
269224
let spec = Spec {
270225
separators_regex: vec![r"\.".to_string()],
271-
keep_separator: KeepSep::RIGHT,
226+
keep_separator: Some(KeepSep::Right),
272227
include_empty: false,
273228
trim: true,
274229
};

src/ops/functions/split_recursively.rs

Lines changed: 4 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ use std::sync::LazyLock;
66
use std::{collections::HashMap, sync::Arc};
77
use unicase::UniCase;
88

9-
use crate::base::field_attrs;
10-
use crate::ops::registry::ExecutorFactoryRegistry;
9+
use crate::ops::shared::split::{Position, set_output_positions};
1110
use crate::{fields_value, ops::sdk::*};
1211

1312
#[derive(Serialize, Deserialize)]
@@ -479,36 +478,6 @@ impl<'s> AtomChunksCollector<'s> {
479478
}
480479
}
481480

482-
#[derive(Debug, Clone, PartialEq, Eq)]
483-
struct OutputPosition {
484-
char_offset: usize,
485-
line: u32,
486-
column: u32,
487-
}
488-
489-
impl OutputPosition {
490-
fn into_output(self) -> value::Value {
491-
value::Value::Struct(fields_value!(
492-
self.char_offset as i64,
493-
self.line as i64,
494-
self.column as i64
495-
))
496-
}
497-
}
498-
struct Position {
499-
byte_offset: usize,
500-
output: Option<OutputPosition>,
501-
}
502-
503-
impl Position {
504-
fn new(byte_offset: usize) -> Self {
505-
Self {
506-
byte_offset,
507-
output: None,
508-
}
509-
}
510-
}
511-
512481
struct ChunkOutput<'s> {
513482
start_pos: Position,
514483
end_pos: Position,
@@ -826,55 +795,6 @@ impl Executor {
826795
}
827796
}
828797

829-
fn set_output_positions<'a>(text: &str, positions: impl Iterator<Item = &'a mut Position>) {
830-
let mut positions = positions.collect::<Vec<_>>();
831-
positions.sort_by_key(|o| o.byte_offset);
832-
833-
let mut positions_iter = positions.iter_mut();
834-
let Some(mut next_position) = positions_iter.next() else {
835-
return;
836-
};
837-
838-
let mut char_offset = 0;
839-
let mut line = 1;
840-
let mut column = 1;
841-
for (byte_offset, ch) in text.char_indices() {
842-
while next_position.byte_offset == byte_offset {
843-
next_position.output = Some(OutputPosition {
844-
char_offset,
845-
line,
846-
column,
847-
});
848-
if let Some(position) = positions_iter.next() {
849-
next_position = position;
850-
} else {
851-
return;
852-
}
853-
}
854-
char_offset += 1;
855-
if ch == '\n' {
856-
line += 1;
857-
column = 1;
858-
} else {
859-
column += 1;
860-
}
861-
}
862-
863-
// Offsets after the last char.
864-
loop {
865-
next_position.output = Some(OutputPosition {
866-
char_offset,
867-
line,
868-
column,
869-
});
870-
if let Some(position) = positions_iter.next() {
871-
next_position = position;
872-
} else {
873-
return;
874-
}
875-
}
876-
}
877-
878798
#[async_trait]
879799
impl SimpleFunctionExecutor for Executor {
880800
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
@@ -997,49 +917,8 @@ impl SimpleFunctionFactoryBase for Factory {
997917
.optional(),
998918
};
999919

1000-
let pos_struct = schema::ValueType::Struct(schema::StructSchema {
1001-
fields: Arc::new(vec![
1002-
schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)),
1003-
schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)),
1004-
schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)),
1005-
]),
1006-
description: None,
1007-
});
1008-
1009-
let mut struct_schema = StructSchema::default();
1010-
let mut schema_builder = StructSchemaBuilder::new(&mut struct_schema);
1011-
schema_builder.add_field(FieldSchema::new(
1012-
"location",
1013-
make_output_type(BasicValueType::Range),
1014-
));
1015-
schema_builder.add_field(FieldSchema::new(
1016-
"text",
1017-
make_output_type(BasicValueType::Str),
1018-
));
1019-
schema_builder.add_field(FieldSchema::new(
1020-
"start",
1021-
schema::EnrichedValueType {
1022-
typ: pos_struct.clone(),
1023-
nullable: false,
1024-
attrs: Default::default(),
1025-
},
1026-
));
1027-
schema_builder.add_field(FieldSchema::new(
1028-
"end",
1029-
schema::EnrichedValueType {
1030-
typ: pos_struct,
1031-
nullable: false,
1032-
attrs: Default::default(),
1033-
},
1034-
));
1035-
let output_schema = make_output_type(TableSchema::new(
1036-
TableKind::KTable(KTableInfo { num_key_parts: 1 }),
1037-
struct_schema,
1038-
))
1039-
.with_attr(
1040-
field_attrs::CHUNK_BASE_TEXT,
1041-
serde_json::to_value(args_resolver.get_analyze_value(&args.text))?,
1042-
);
920+
let output_schema =
921+
crate::ops::shared::split::make_common_chunk_schema(args_resolver, &args.text)?;
1043922
Ok((args, output_schema))
1044923
}
1045924

@@ -1060,7 +939,7 @@ pub fn register(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
1060939
#[cfg(test)]
1061940
mod tests {
1062941
use super::*;
1063-
use crate::ops::functions::test_utils::test_flow_function;
942+
use crate::ops::{functions::test_utils::test_flow_function, shared::split::OutputPosition};
1064943

1065944
// Helper function to assert chunk text and its consistency with the range within the original text.
1066945
fn assert_chunk_text_consistency(

src/ops/shared/split.rs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
use crate::{fields_value, ops::sdk::value};
1+
use crate::{
2+
base::field_attrs,
3+
fields_value,
4+
ops::sdk::value,
5+
ops::sdk::{
6+
BasicValueType, EnrichedValueType, FieldSchema, KTableInfo, OpArgsResolver, StructSchema,
7+
StructSchemaBuilder, TableKind, TableSchema, make_output_type, schema,
8+
},
9+
};
10+
use anyhow::Result;
211

312
#[derive(Debug, Clone, PartialEq, Eq)]
413
pub struct OutputPosition {
@@ -79,3 +88,55 @@ pub fn set_output_positions<'a>(text: &str, positions: impl Iterator<Item = &'a
7988
}
8089
}
8190
}
91+
92+
/// Build the common chunk output schema used by splitters.
93+
/// Fields: `location: Range`, `text: Str`, `start: {offset,line,column}`, `end: {offset,line,column}`.
94+
pub fn make_common_chunk_schema<'a>(
95+
args_resolver: &OpArgsResolver<'a>,
96+
text_arg: &crate::ops::sdk::ResolvedOpArg,
97+
) -> Result<EnrichedValueType> {
98+
let pos_struct = schema::ValueType::Struct(schema::StructSchema {
99+
fields: std::sync::Arc::new(vec![
100+
schema::FieldSchema::new("offset", make_output_type(BasicValueType::Int64)),
101+
schema::FieldSchema::new("line", make_output_type(BasicValueType::Int64)),
102+
schema::FieldSchema::new("column", make_output_type(BasicValueType::Int64)),
103+
]),
104+
description: None,
105+
});
106+
107+
let mut struct_schema = StructSchema::default();
108+
let mut sb = StructSchemaBuilder::new(&mut struct_schema);
109+
sb.add_field(FieldSchema::new(
110+
"location",
111+
make_output_type(BasicValueType::Range),
112+
));
113+
sb.add_field(FieldSchema::new(
114+
"text",
115+
make_output_type(BasicValueType::Str),
116+
));
117+
sb.add_field(FieldSchema::new(
118+
"start",
119+
schema::EnrichedValueType {
120+
typ: pos_struct.clone(),
121+
nullable: false,
122+
attrs: Default::default(),
123+
},
124+
));
125+
sb.add_field(FieldSchema::new(
126+
"end",
127+
schema::EnrichedValueType {
128+
typ: pos_struct,
129+
nullable: false,
130+
attrs: Default::default(),
131+
},
132+
));
133+
let output_schema = make_output_type(TableSchema::new(
134+
TableKind::KTable(KTableInfo { num_key_parts: 1 }),
135+
struct_schema,
136+
))
137+
.with_attr(
138+
field_attrs::CHUNK_BASE_TEXT,
139+
serde_json::to_value(args_resolver.get_analyze_value(text_arg))?,
140+
);
141+
Ok(output_schema)
142+
}

0 commit comments

Comments
 (0)