Skip to content

Commit b76780c

Browse files
authored
feat: add doc metadata extractor and ID generator classes (#34)
Signed-off-by: Panos Vagenas <[email protected]>
1 parent 4bde515 commit b76780c

File tree

7 files changed

+236
-0
lines changed

7 files changed

+236
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""Define the ID generator types."""
7+
8+
from docling_core.transforms.id_generator.base import BaseIDGenerator # noqa
9+
from docling_core.transforms.id_generator.doc_hash_id_generator import ( # noqa
10+
DocHashIDGenerator,
11+
)
12+
from docling_core.transforms.id_generator.uuid_generator import UUIDGenerator # noqa
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""Base document ID generator module."""
7+
8+
from abc import ABC, abstractmethod
9+
from typing import Any
10+
11+
from docling_core.types import Document as DLDocument
12+
13+
14+
class BaseIDGenerator(ABC):
15+
"""Document ID generator base class."""
16+
17+
@abstractmethod
18+
def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
19+
"""Generate an ID for the given document.
20+
21+
Args:
22+
doc (DLDocument): document to generate ID for
23+
24+
Raises:
25+
NotImplementedError: in this abstract implementation
26+
27+
Returns:
28+
str: the generated ID
29+
"""
30+
raise NotImplementedError()
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""Doc-hash-based ID generator module."""
7+
8+
9+
from typing import Any
10+
11+
from docling_core.transforms.id_generator import BaseIDGenerator
12+
from docling_core.types import Document as DLDocument
13+
14+
15+
class DocHashIDGenerator(BaseIDGenerator):
16+
"""Doc-hash-based ID generator class."""
17+
18+
def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
19+
"""Generate an ID for the given document.
20+
21+
Args:
22+
doc (DLDocument): document to generate ID for
23+
24+
Returns:
25+
str: the generated ID
26+
"""
27+
return doc.file_info.document_hash
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""UUID-based ID generator module."""
7+
8+
from random import Random
9+
from typing import Annotated, Any, Optional
10+
from uuid import UUID
11+
12+
from pydantic import BaseModel, Field
13+
14+
from docling_core.transforms.id_generator import BaseIDGenerator
15+
from docling_core.types import Document as DLDocument
16+
17+
18+
class UUIDGenerator(BaseModel, BaseIDGenerator):
19+
"""UUID-based ID generator class."""
20+
21+
seed: Optional[int] = None
22+
uuid_version: Annotated[int, Field(strict=True, ge=1, le=5)] = 4
23+
24+
def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
25+
"""Generate an ID for the given document.
26+
27+
Args:
28+
doc (DLDocument): document to generate ID for
29+
30+
Returns:
31+
str: the generated ID
32+
"""
33+
rd = Random(x=self.seed)
34+
return str(UUID(int=rd.getrandbits(128), version=self.uuid_version))
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""Define the metadata extractor types."""
7+
8+
from docling_core.transforms.metadata_extractor.base import ( # noqa
9+
BaseMetadataExtractor,
10+
)
11+
from docling_core.transforms.metadata_extractor.simple_metadata_extractor import ( # noqa
12+
SimpleMetadataExtractor,
13+
)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""Base metadata extractor module."""
7+
8+
9+
from abc import ABC, abstractmethod
10+
from typing import Any
11+
12+
from pydantic import BaseModel
13+
14+
from docling_core.types import Document as DLDocument
15+
16+
17+
class BaseMetadataExtractor(BaseModel, ABC):
18+
"""Metadata extractor base class."""
19+
20+
@abstractmethod
21+
def get_metadata(
22+
self, doc: DLDocument, *args: Any, **kwargs: Any
23+
) -> dict[str, Any]:
24+
"""Extract metadata for the given document.
25+
26+
Args:
27+
doc (DLDocument): document to extract metadata for
28+
29+
Raises:
30+
NotImplementedError: in this abstract implementation
31+
32+
Returns:
33+
dict[str, Any]: the extracted metadata
34+
"""
35+
raise NotImplementedError()
36+
37+
@abstractmethod
38+
def get_excluded_embed_metadata_keys(self) -> list[str]:
39+
"""Get metadata keys to exclude from embedding.
40+
41+
Raises:
42+
NotImplementedError: in this abstract implementation
43+
44+
Returns:
45+
list[str]: the metadata to exclude
46+
"""
47+
raise NotImplementedError()
48+
49+
@abstractmethod
50+
def get_excluded_llm_metadata_keys(self) -> list[str]:
51+
"""Get metadata keys to exclude from LLM generation.
52+
53+
Raises:
54+
NotImplementedError: in this abstract implementation
55+
56+
Returns:
57+
list[str]: the metadata to exclude
58+
"""
59+
raise NotImplementedError()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""Simple metadata extractor module."""
7+
8+
9+
from enum import Enum
10+
from typing import Any
11+
12+
from docling_core.transforms.metadata_extractor import BaseMetadataExtractor
13+
from docling_core.types import Document as DLDocument
14+
15+
16+
class SimpleMetadataExtractor(BaseMetadataExtractor):
17+
"""Simple metadata extractor class."""
18+
19+
class _Keys(str, Enum):
20+
DL_DOC_HASH = "dl_doc_hash"
21+
ORIGIN = "origin"
22+
23+
include_origin: bool = False
24+
25+
def get_metadata(
26+
self, doc: DLDocument, origin: str, *args: Any, **kwargs: Any
27+
) -> dict[str, Any]:
28+
"""Extract metadata for the given document.
29+
30+
Args:
31+
doc (DLDocument): document to extract metadata for
32+
origin (str): the document origin
33+
34+
Returns:
35+
dict[str, Any]: the extracted metadata
36+
"""
37+
meta: dict[str, Any] = {
38+
self._Keys.DL_DOC_HASH: doc.file_info.document_hash,
39+
}
40+
if self.include_origin:
41+
meta[self._Keys.ORIGIN] = origin
42+
return meta
43+
44+
def get_excluded_embed_metadata_keys(self) -> list[str]:
45+
"""Get metadata keys to exclude from embedding.
46+
47+
Returns:
48+
list[str]: the metadata to exclude
49+
"""
50+
excl_keys: list[str] = [self._Keys.DL_DOC_HASH]
51+
if self.include_origin:
52+
excl_keys.append(self._Keys.ORIGIN)
53+
return excl_keys
54+
55+
def get_excluded_llm_metadata_keys(self) -> list[str]:
56+
"""Get metadata keys to exclude from LLM generation.
57+
58+
Returns:
59+
list[str]: the metadata to exclude
60+
"""
61+
return self.get_excluded_embed_metadata_keys()

0 commit comments

Comments
 (0)