Skip to content

Commit 4dd581a

Browse files
authored
fix: union types validation in element enricher (#499)
1 parent 04ed313 commit 4dd581a

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

packages/ragbits-document-search/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
## Unreleased
44

5+
- fix union types validation in element enricher (#499)
6+
57
## 0.13.0 (2025-04-02)
68

79
### Changed
810

911
- ragbits-core updated to version v0.13.0
10-
1112
- DocumentSearch.ingest now raises IngestExecutionError when any errors are encountered during ingestion.
1213

1314
## 0.12.0 (2025-03-25)

packages/ragbits-document-search/src/ragbits/document_search/ingestion/enrichers/base.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
2-
from types import ModuleType
3-
from typing import ClassVar, Generic, TypeVar
2+
from types import ModuleType, UnionType
3+
from typing import ClassVar, Generic, TypeVar, get_args, get_origin
44

55
from ragbits.core.utils.config_handling import WithConstructionConfig
66
from ragbits.document_search.documents.element import Element
@@ -47,5 +47,18 @@ def validate_element_type(cls, element_type: type[Element]) -> None:
4747
Raises:
4848
EnricherElementNotSupportedError: If the element type is not supported.
4949
"""
50-
if element_type != cls.__orig_bases__[0].__args__[0]: # type: ignore
51-
raise EnricherElementNotSupportedError(enricher_name=cls.__name__, element_type=element_type)
50+
expected_element_type = cls.__orig_bases__[0].__args__[0] # type: ignore
51+
52+
# Check if expected_element_type is a Union and if element_type is in that Union
53+
if (
54+
(origin := get_origin(expected_element_type))
55+
and origin == UnionType
56+
and element_type in get_args(expected_element_type)
57+
):
58+
return
59+
60+
# Check if element_type matches expected_element_type exactly
61+
if element_type == expected_element_type:
62+
return
63+
64+
raise EnricherElementNotSupportedError(enricher_name=cls.__name__, element_type=element_type)

packages/ragbits-document-search/tests/unit/test_element_enrichers.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,27 @@
1111
from ragbits.document_search.ingestion.enrichers.image import ImageDescriberPrompt, ImageElementEnricher
1212

1313

14-
def test_enricher_validates_supported_element_types_passes() -> None:
14+
def test_enricher_validates_supported_element_types() -> None:
1515
ImageElementEnricher.validate_element_type(ImageElement)
1616

17+
with pytest.raises(EnricherElementNotSupportedError):
18+
ImageElementEnricher.validate_element_type(TextElement)
19+
1720

18-
def test_enricher_validates_supported_document_types_fails() -> None:
21+
def test_enricher_validates_supported_document_union_types() -> None:
1922
class CustomElement(Element):
23+
@property
24+
def text_representation(self) -> str:
25+
return ""
26+
27+
class CustomElementEnricher(ElementEnricher[CustomElement | TextElement]):
2028
pass
2129

30+
CustomElementEnricher.validate_element_type(CustomElement)
31+
CustomElementEnricher.validate_element_type(TextElement)
32+
2233
with pytest.raises(EnricherElementNotSupportedError):
23-
ImageElementEnricher.validate_element_type(CustomElement) # type: ignore
34+
CustomElementEnricher.validate_element_type(ImageElement)
2435

2536

2637
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)