Skip to content

Commit db5a804

Browse files
committed
fix: pass ExtractionResult instances to Python plugins, fix CI workflows
- Fix PostProcessor and Validator bridges to pass ExtractionResult class instances instead of plain dicts, matching the documented Python API - Support both ExtractionResult (attribute access) and dict returns for backward compatibility in PostProcessor bridge - Update all Python doc snippets to use attribute access consistently - Add golangci-lint installation to ci-go.yaml for e2e test generation - Fix Deno e2e helpers to handle undefined images/pages/elements fields - Add lint script to wasm-workers e2e package.json
1 parent 957d34a commit db5a804

File tree

14 files changed

+123
-179
lines changed

14 files changed

+123
-179
lines changed

.github/workflows/ci-go.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,12 @@ jobs:
599599
PKG_CONFIG_PATH: ${{ env.PKG_CONFIG_PATH }}
600600
GO_TEST_FLAGS: ${{ runner.os == 'Windows' && '-ldflags=-linkmode=external -extldflags=-Wl,--verbose' || '' }}
601601

602+
- name: Install golangci-lint
603+
shell: bash
604+
run: |
605+
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b "$(go env GOPATH)/bin" v2.9.0
606+
echo "$(go env GOPATH)/bin" >> "$GITHUB_PATH"
607+
602608
- name: Generate Go E2E tests
603609
shell: bash
604610
run: |

crates/kreuzberg-py/src/plugins/processor_bridge.rs

Lines changed: 55 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ use kreuzberg::plugins::{Plugin, PostProcessor, ProcessingStage};
1414
use kreuzberg::types::ExtractionResult;
1515
use kreuzberg::{KreuzbergError, Result};
1616

17-
use super::common::{json_value_to_py, python_to_json, validate_plugin_object};
17+
use crate::types::ExtractionResult as PyExtractionResult;
18+
19+
use super::common::{python_to_json, validate_plugin_object};
1820

1921
/// Wrapper that makes a Python PostProcessor usable from Rust.
2022
///
@@ -133,26 +135,29 @@ impl PostProcessor for PythonPostProcessor {
133135
Python::attach(|py| {
134136
let obj = self.python_obj.bind(py);
135137

136-
let result_dict = extraction_result_to_dict(py, result).map_err(|e| KreuzbergError::Plugin {
137-
message: format!("Failed to convert ExtractionResult to Python dict: {}", e),
138+
// Convert Rust ExtractionResult to Python ExtractionResult class instance
139+
let py_extraction_result =
140+
PyExtractionResult::from_rust(result.clone(), py, None, None).map_err(|e| {
141+
KreuzbergError::Plugin {
142+
message: format!("Failed to convert ExtractionResult to Python: {}", e),
143+
plugin_name: processor_name.clone(),
144+
}
145+
})?;
146+
147+
let py_result_obj = Py::new(py, py_extraction_result).map_err(|e| KreuzbergError::Plugin {
148+
message: format!("Failed to create Python ExtractionResult: {}", e),
138149
plugin_name: processor_name.clone(),
139150
})?;
140151

141-
let py_result = result_dict.bind(py);
142152
let processed = obj
143-
.call_method1("process", (py_result,))
153+
.call_method1("process", (py_result_obj,))
144154
.map_err(|e| KreuzbergError::Plugin {
145155
message: format!("Python PostProcessor '{}' failed during process: {}", processor_name, e),
146156
plugin_name: processor_name.clone(),
147157
})?;
148158

149-
let processed_dict = processed.cast_into::<PyDict>().map_err(|e| KreuzbergError::Plugin {
150-
message: format!("PostProcessor did not return a dict: {}", e),
151-
plugin_name: processor_name.clone(),
152-
})?;
153-
154159
let mut updated_result = result.clone();
155-
merge_dict_to_extraction_result(py, &processed_dict, &mut updated_result)?;
160+
merge_processed_result(py, &processed, &mut updated_result)?;
156161

157162
Ok::<ExtractionResult, KreuzbergError>(updated_result)
158163
})
@@ -167,67 +172,45 @@ impl PostProcessor for PythonPostProcessor {
167172
}
168173
}
169174

170-
/// Convert Rust ExtractionResult to Python dict.
175+
/// Merge a processed Python result back into a Rust ExtractionResult.
171176
///
172-
/// This creates a Python dict that can be passed to Python processors:
173-
/// ```python
174-
/// {
175-
/// "content": "extracted text",
176-
/// "mime_type": "application/pdf",
177-
/// "metadata": {"key": "value"},
178-
/// "tables": [...]
179-
/// }
180-
/// ```
181-
fn extraction_result_to_dict(py: Python<'_>, result: &ExtractionResult) -> PyResult<Py<PyDict>> {
182-
let dict = PyDict::new(py);
183-
184-
dict.set_item("content", &result.content)?;
185-
186-
dict.set_item("mime_type", &result.mime_type)?;
187-
188-
let metadata_dict = PyDict::new(py);
189-
190-
if let Some(title) = &result.metadata.title {
191-
metadata_dict.set_item("title", title)?;
192-
}
193-
if let Some(subject) = &result.metadata.subject {
194-
metadata_dict.set_item("subject", subject)?;
195-
}
196-
if let Some(authors) = &result.metadata.authors {
197-
metadata_dict.set_item("authors", authors)?;
198-
}
199-
if let Some(keywords) = &result.metadata.keywords {
200-
metadata_dict.set_item("keywords", keywords)?;
201-
}
202-
if let Some(language) = &result.metadata.language {
203-
metadata_dict.set_item("language", language)?;
204-
}
205-
if let Some(created_at) = &result.metadata.created_at {
206-
metadata_dict.set_item("created_at", created_at)?;
207-
}
208-
if let Some(modified_at) = &result.metadata.modified_at {
209-
metadata_dict.set_item("modified_at", modified_at)?;
210-
}
211-
if let Some(created_by) = &result.metadata.created_by {
212-
metadata_dict.set_item("created_by", created_by)?;
213-
}
214-
if let Some(modified_by) = &result.metadata.modified_by {
215-
metadata_dict.set_item("modified_by", modified_by)?;
216-
}
217-
if let Some(created_at) = &result.metadata.created_at {
218-
metadata_dict.set_item("created_at", created_at)?;
177+
/// Supports both ExtractionResult class instances (attribute access) and
178+
/// plain dicts (dict-style access) for backward compatibility.
179+
fn merge_processed_result(py: Python<'_>, processed: &Bound<'_, PyAny>, result: &mut ExtractionResult) -> Result<()> {
180+
// If processor returned a dict, use dict-style access for backward compatibility
181+
if let Ok(dict) = processed.cast::<PyDict>() {
182+
return merge_dict_to_extraction_result(py, dict, result);
219183
}
220184

221-
for (key, value) in &result.metadata.additional {
222-
let py_value = json_value_to_py(py, value)?;
223-
metadata_dict.set_item(key, py_value)?;
185+
// Use attribute access (ExtractionResult or duck-typed object)
186+
if let Ok(content) = processed.getattr("content")
187+
&& !content.is_none()
188+
{
189+
result.content = content.extract().map_err(|e| KreuzbergError::Plugin {
190+
message: format!("PostProcessor returned invalid 'content': {}", e),
191+
plugin_name: "python".to_string(),
192+
})?;
224193
}
225194

226-
dict.set_item("metadata", metadata_dict)?;
195+
if let Ok(metadata) = processed.getattr("metadata")
196+
&& !metadata.is_none()
197+
&& let Ok(meta_dict) = metadata.cast::<PyDict>()
198+
{
199+
for (key, value) in meta_dict.iter() {
200+
let key_str: String = key.extract().map_err(|_| KreuzbergError::Plugin {
201+
message: "Metadata keys must be strings".to_string(),
202+
plugin_name: "python".to_string(),
203+
})?;
227204

228-
dict.set_item("tables", pyo3::types::PyList::empty(py))?;
205+
let json_value = python_to_json(&value)?;
206+
result
207+
.metadata
208+
.additional
209+
.insert(std::borrow::Cow::Owned(key_str), json_value);
210+
}
211+
}
229212

230-
Ok(dict.unbind())
213+
Ok(())
231214
}
232215

233216
/// Merge Python dict back into ExtractionResult.
@@ -288,7 +271,7 @@ fn merge_dict_to_extraction_result(
288271
///
289272
/// The Python processor must implement:
290273
/// - `name() -> str` - Return processor name
291-
/// - `process(result: dict) -> dict` - Process and enrich the extraction result
274+
/// - `process(result: ExtractionResult) -> ExtractionResult` - Process and enrich the extraction result
292275
///
293276
/// # Optional Methods
294277
///
@@ -300,7 +283,7 @@ fn merge_dict_to_extraction_result(
300283
/// # Example
301284
///
302285
/// ```python
303-
/// from kreuzberg import register_post_processor
286+
/// from kreuzberg import register_post_processor, ExtractionResult
304287
///
305288
/// class EntityExtractor:
306289
/// def name(self) -> str:
@@ -309,10 +292,10 @@ fn merge_dict_to_extraction_result(
309292
/// def processing_stage(self) -> str:
310293
/// return "early"
311294
///
312-
/// def process(self, result: dict) -> dict:
313-
/// # Extract entities from result["content"]
295+
/// def process(self, result: ExtractionResult) -> ExtractionResult:
296+
/// # Extract entities from result.content
314297
/// entities = {"PERSON": ["John Doe"], "ORG": ["Microsoft"]}
315-
/// result["metadata"]["entities"] = entities
298+
/// result.metadata["entities"] = entities
316299
/// return result
317300
///
318301
/// register_post_processor(EntityExtractor())
@@ -369,7 +352,7 @@ pub fn register_post_processor(py: Python<'_>, processor: Py<PyAny>) -> PyResult
369352
/// def name(self) -> str:
370353
/// return "my_processor"
371354
///
372-
/// def process(self, result: dict) -> dict:
355+
/// def process(self, result: ExtractionResult) -> ExtractionResult:
373356
/// return result
374357
///
375358
/// register_post_processor(MyProcessor())
@@ -444,7 +427,7 @@ pub fn clear_post_processors(py: Python<'_>) -> PyResult<()> {
444427
/// def name(self) -> str:
445428
/// return "my_processor"
446429
///
447-
/// def process(self, result: dict) -> dict:
430+
/// def process(self, result: ExtractionResult) -> ExtractionResult:
448431
/// return result
449432
///
450433
/// # Register processor

crates/kreuzberg-py/src/plugins/validator_bridge.rs

Lines changed: 27 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
66
use async_trait::async_trait;
77
use pyo3::prelude::*;
8-
use pyo3::types::PyDict;
98
use std::sync::Arc;
109

1110
use kreuzberg::core::config::ExtractionConfig;
@@ -14,7 +13,9 @@ use kreuzberg::plugins::{Plugin, Validator};
1413
use kreuzberg::types::ExtractionResult;
1514
use kreuzberg::{KreuzbergError, Result};
1615

17-
use super::common::{json_value_to_py, validate_plugin_object};
16+
use crate::types::ExtractionResult as PyExtractionResult;
17+
18+
use super::common::validate_plugin_object;
1819

1920
/// Wrapper that makes a Python Validator usable from Rust.
2021
///
@@ -127,13 +128,21 @@ impl Validator for PythonValidator {
127128
Python::attach(|py| {
128129
let obj = self.python_obj.bind(py);
129130

130-
let result_dict = extraction_result_to_dict(py, result).map_err(|e| KreuzbergError::Plugin {
131-
message: format!("Failed to convert ExtractionResult to Python dict: {}", e),
131+
// Convert Rust ExtractionResult to Python ExtractionResult class instance
132+
let py_extraction_result =
133+
PyExtractionResult::from_rust(result.clone(), py, None, None).map_err(|e| {
134+
KreuzbergError::Plugin {
135+
message: format!("Failed to convert ExtractionResult to Python: {}", e),
136+
plugin_name: validator_name.clone(),
137+
}
138+
})?;
139+
140+
let py_result_obj = Py::new(py, py_extraction_result).map_err(|e| KreuzbergError::Plugin {
141+
message: format!("Failed to create Python ExtractionResult: {}", e),
132142
plugin_name: validator_name.clone(),
133143
})?;
134144

135-
let py_result = result_dict.bind(py);
136-
obj.call_method1("validate", (py_result,)).map_err(|e| {
145+
obj.call_method1("validate", (py_result_obj,)).map_err(|e| {
137146
let is_validation_error = e.is_instance_of::<pyo3::exceptions::PyValueError>(py)
138147
|| e.get_type(py)
139148
.name()
@@ -180,9 +189,10 @@ impl Validator for PythonValidator {
180189
.unwrap_or(false);
181190

182191
if has_should_validate {
183-
let result_dict = extraction_result_to_dict(py, result).ok()?;
184-
let py_result = result_dict.bind(py);
185-
obj.call_method1("should_validate", (py_result,))
192+
let py_extraction_result =
193+
PyExtractionResult::from_rust(result.clone(), py, None, None).ok()?;
194+
let py_result_obj = Py::new(py, py_extraction_result).ok()?;
195+
obj.call_method1("should_validate", (py_result_obj,))
186196
.and_then(|v| v.extract::<bool>())
187197
.ok()
188198
} else {
@@ -197,69 +207,6 @@ impl Validator for PythonValidator {
197207
}
198208
}
199209

200-
/// Convert Rust ExtractionResult to Python dict.
201-
///
202-
/// This creates a Python dict that can be passed to Python validators:
203-
/// ```python
204-
/// {
205-
/// "content": "extracted text",
206-
/// "mime_type": "application/pdf",
207-
/// "metadata": {"key": "value"},
208-
/// "tables": [...]
209-
/// }
210-
/// ```
211-
fn extraction_result_to_dict(py: Python<'_>, result: &ExtractionResult) -> PyResult<Py<PyDict>> {
212-
let dict = PyDict::new(py);
213-
214-
dict.set_item("content", &result.content)?;
215-
216-
dict.set_item("mime_type", &result.mime_type)?;
217-
218-
let metadata_dict = PyDict::new(py);
219-
220-
if let Some(title) = &result.metadata.title {
221-
metadata_dict.set_item("title", title)?;
222-
}
223-
if let Some(subject) = &result.metadata.subject {
224-
metadata_dict.set_item("subject", subject)?;
225-
}
226-
if let Some(authors) = &result.metadata.authors {
227-
metadata_dict.set_item("authors", authors)?;
228-
}
229-
if let Some(keywords) = &result.metadata.keywords {
230-
metadata_dict.set_item("keywords", keywords)?;
231-
}
232-
if let Some(language) = &result.metadata.language {
233-
metadata_dict.set_item("language", language)?;
234-
}
235-
if let Some(created_at) = &result.metadata.created_at {
236-
metadata_dict.set_item("created_at", created_at)?;
237-
}
238-
if let Some(modified_at) = &result.metadata.modified_at {
239-
metadata_dict.set_item("modified_at", modified_at)?;
240-
}
241-
if let Some(created_by) = &result.metadata.created_by {
242-
metadata_dict.set_item("created_by", created_by)?;
243-
}
244-
if let Some(modified_by) = &result.metadata.modified_by {
245-
metadata_dict.set_item("modified_by", modified_by)?;
246-
}
247-
if let Some(created_at) = &result.metadata.created_at {
248-
metadata_dict.set_item("created_at", created_at)?;
249-
}
250-
251-
for (key, value) in &result.metadata.additional {
252-
let py_value = json_value_to_py(py, value)?;
253-
metadata_dict.set_item(key, py_value)?;
254-
}
255-
256-
dict.set_item("metadata", metadata_dict)?;
257-
258-
dict.set_item("tables", pyo3::types::PyList::empty(py))?;
259-
260-
Ok(dict.unbind())
261-
}
262-
263210
/// Register a Python Validator with the Rust core.
264211
///
265212
/// This function validates the Python validator object, wraps it in a Rust
@@ -275,11 +222,11 @@ fn extraction_result_to_dict(py: Python<'_>, result: &ExtractionResult) -> PyRes
275222
///
276223
/// The Python validator must implement:
277224
/// - `name() -> str` - Return validator name
278-
/// - `validate(result: dict) -> None` - Validate the extraction result (raise error to fail)
225+
/// - `validate(result: ExtractionResult) -> None` - Validate the extraction result (raise error to fail)
279226
///
280227
/// # Optional Methods
281228
///
282-
/// - `should_validate(result: dict) -> bool` - Check if validator should run (defaults to True)
229+
/// - `should_validate(result: ExtractionResult) -> bool` - Check if validator should run (defaults to True)
283230
/// - `priority() -> int` - Return priority (defaults to 50, higher runs first)
284231
/// - `initialize()` - Called when validator is registered
285232
/// - `shutdown()` - Called when validator is unregistered
@@ -288,7 +235,7 @@ fn extraction_result_to_dict(py: Python<'_>, result: &ExtractionResult) -> PyRes
288235
/// # Example
289236
///
290237
/// ```python
291-
/// from kreuzberg import register_validator
238+
/// from kreuzberg import register_validator, ExtractionResult
292239
/// from kreuzberg.exceptions import ValidationError
293240
///
294241
/// class MinLengthValidator:
@@ -298,10 +245,10 @@ fn extraction_result_to_dict(py: Python<'_>, result: &ExtractionResult) -> PyRes
298245
/// def priority(self) -> int:
299246
/// return 100 # Run early
300247
///
301-
/// def validate(self, result: dict) -> None:
302-
/// if len(result["content"]) < 100:
248+
/// def validate(self, result: ExtractionResult) -> None:
249+
/// if len(result.content) < 100:
303250
/// raise ValidationError(
304-
/// f"Content too short: {len(result['content'])} < 100 characters"
251+
/// f"Content too short: {len(result.content)} < 100 characters"
305252
/// )
306253
///
307254
/// register_validator(MinLengthValidator())
@@ -358,7 +305,7 @@ pub fn register_validator(py: Python<'_>, validator: Py<PyAny>) -> PyResult<()>
358305
/// def name(self) -> str:
359306
/// return "my_validator"
360307
///
361-
/// def validate(self, result: dict) -> None:
308+
/// def validate(self, result: ExtractionResult) -> None:
362309
/// pass
363310
///
364311
/// register_validator(MyValidator())
@@ -433,7 +380,7 @@ pub fn clear_validators(py: Python<'_>) -> PyResult<()> {
433380
/// def name(self) -> str:
434381
/// return "my_validator"
435382
///
436-
/// def validate(self, result: dict) -> None:
383+
/// def validate(self, result: ExtractionResult) -> None:
437384
/// pass
438385
///
439386
/// # Register validator

0 commit comments

Comments
 (0)