Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ sharepoint = ["requirements/connectors/sharepoint.txt"]
singlestore = ["requirements/connectors/singlestore.txt"]
slack = ["requirements/connectors/slack.txt"]
snowflake = ["requirements/connectors/snowflake.txt"]
surrealdb = ["requirements/connectors/surrealdb.txt"]
vastdb = ["requirements/connectors/vastdb.txt"]
vectara = ["requirements/connectors/vectara.txt"]
weaviate = ["requirements/connectors/weaviate.txt"]
Expand Down
2 changes: 2 additions & 0 deletions requirements/connectors/surrealdb.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pandas
surrealdb
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"url": "ws://surrealdb:8000",
"username": "root",
"password": "root"
}
41 changes: 41 additions & 0 deletions test/integration/connectors/surrealdb/assets/surrealdb-schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
-- Schema for the integration-test destination table. Each row holds one
-- "element" emitted by unstructured-ingest; metadata that may be absent is
-- declared as option<...> so NONE values are accepted.
DEFINE TABLE elements SCHEMAFULL;

-- Identity and content of the element.
DEFINE FIELD id ON TABLE elements TYPE string;
DEFINE FIELD element_id ON TABLE elements TYPE string;
DEFINE FIELD text ON TABLE elements TYPE string;
DEFINE FIELD embeddings ON TABLE elements TYPE array<float>;
DEFINE FIELD type ON TABLE elements TYPE string;
-- Source/provenance metadata.
DEFINE FIELD system ON TABLE elements TYPE option<string>;
DEFINE FIELD layout_width ON TABLE elements TYPE option<decimal>;
DEFINE FIELD layout_height ON TABLE elements TYPE option<decimal>;
DEFINE FIELD points ON TABLE elements TYPE option<string>;
DEFINE FIELD url ON TABLE elements TYPE option<string>;
DEFINE FIELD version ON TABLE elements TYPE option<string>;
-- NOTE(review): created/modified are ints while date_processed is a float —
-- presumably epoch timestamps of differing precision; confirm against stager.
DEFINE FIELD date_created ON TABLE elements TYPE option<int>;
DEFINE FIELD date_modified ON TABLE elements TYPE option<int>;
DEFINE FIELD date_processed ON TABLE elements TYPE option<float>;
DEFINE FIELD permissions_data ON TABLE elements TYPE option<string>;
DEFINE FIELD record_locator ON TABLE elements TYPE option<string>;
-- Document-structure metadata.
DEFINE FIELD category_depth ON TABLE elements TYPE option<int>;
DEFINE FIELD parent_id ON TABLE elements TYPE option<string>;
DEFINE FIELD attached_filename ON TABLE elements TYPE option<string>;
DEFINE FIELD filetype ON TABLE elements TYPE string;
DEFINE FIELD last_modified ON TABLE elements TYPE option<datetime>;
DEFINE FIELD file_directory ON TABLE elements TYPE option<string>;
DEFINE FIELD filename ON TABLE elements TYPE string;
DEFINE FIELD languages ON TABLE elements TYPE array<string>;
DEFINE FIELD page_number ON TABLE elements TYPE string;
DEFINE FIELD links ON TABLE elements TYPE option<string>;
DEFINE FIELD page_name ON TABLE elements TYPE option<string>;
DEFINE FIELD link_urls ON TABLE elements TYPE option<array<string>>;
DEFINE FIELD link_texts ON TABLE elements TYPE option<array<string>>;
-- Email-specific metadata.
DEFINE FIELD sent_from ON TABLE elements TYPE option<array<string>>;
DEFINE FIELD sent_to ON TABLE elements TYPE option<array<string>>;
DEFINE FIELD subject ON TABLE elements TYPE option<string>;
DEFINE FIELD section ON TABLE elements TYPE option<string>;
DEFINE FIELD header_footer_type ON TABLE elements TYPE option<string>;
-- Text emphasis / rendering metadata.
DEFINE FIELD emphasized_text_contents ON TABLE elements TYPE option<array<string>>;
DEFINE FIELD emphasized_text_tags ON TABLE elements TYPE option<array<string>>;
DEFINE FIELD text_as_html ON TABLE elements TYPE option<string>;
DEFINE FIELD regex_metadata ON TABLE elements TYPE option<string>;
DEFINE FIELD detection_class_prob ON TABLE elements TYPE option<decimal>;
179 changes: 179 additions & 0 deletions test/integration/connectors/surrealdb/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from pathlib import Path

import pytest
import subprocess
import tempfile
import os
import time
import json
import string
import random
from typing import Any, Generator

from unstructured_ingest.processes.connectors.surrealdb.surrealdb import (
SurrealDBConnectionConfig,
SurrealDBAccessConfig,
)

# Directory containing this test module, and its static assets subdirectory.
int_test_dir = Path(__file__).parent
assets_dir = int_test_dir / "assets"


# JSON file describing an already-running (e.g. containerized) SurrealDB
# instance; its presence enables the "remote_config" parametrization below.
REMOTE_CONFIG_PATH = "test/integration/connectors/surrealdb/assets/remote_config.json"


def pytest_generate_tests(metafunc):
    """Parametrize the ``config`` fixture with every reachable SurrealDB backend.

    Two backends are probed:

    * ``"local_config"`` — available when the ``surreal`` CLI is on PATH, so
      the fixture can start an embedded server.
    * ``"remote_config"`` — available when a JSON connection file exists at
      ``REMOTE_CONFIG_PATH`` (typically a containerized server).

    Tests requesting ``config`` run once per available backend; with neither
    available the parameter list is empty and the tests are skipped.
    """
    if "config" not in metafunc.fixturenames:
        return

    configs: list[str] = []
    # Local backend: requires the `surreal` binary on PATH.
    try:
        result = subprocess.run(
            ["surreal", "--version"], capture_output=True, text=True, check=False
        )
        if result.returncode == 0:
            configs.append("local_config")
        else:
            print(
                f"Skipping local SurrealDB tests because 'surreal' command not found or failed. Error: {result.stderr}"
            )
    except FileNotFoundError:
        print("Skipping local SurrealDB tests because 'surreal' command not found.")

    # Remote backend: requires the connection config file to be present.
    if Path(REMOTE_CONFIG_PATH).is_file():
        configs.append("remote_config")
    else:
        print(
            f"Skipping containerized SurrealDB tests because config file not found at: {REMOTE_CONFIG_PATH}"
        )

    metafunc.parametrize("config", configs, indirect=True)


@pytest.fixture(scope="module")
def test_namespace_name() -> str:
    """Random per-module SurrealDB namespace name (avoids cross-run collisions)."""
    letters = string.ascii_lowercase
    rand_string = "".join(random.choice(letters) for _ in range(6))
    # Fix copy-paste from the database fixture: use a namespace-specific prefix
    # so namespaces and databases are distinguishable in server logs.
    return f"test_ns_{rand_string}"


@pytest.fixture(scope="module")
def test_database_name() -> str:
    """Random per-module SurrealDB database name (avoids cross-run collisions)."""
    suffix = "".join(random.choices(string.ascii_lowercase, k=6))
    return f"test_db_{suffix}"


@pytest.fixture
def config(
    request, test_namespace_name: str, test_database_name: str, surrealdb_schema: Path
) -> Generator[Any, Any, Any]:
    """Yield a ``SurrealDBConnectionConfig`` for the requested backend.

    Parametrized indirectly by ``pytest_generate_tests`` with either:

    * ``"local_config"`` — starts a throwaway ``surreal`` server subprocess
      backed by a RocksDB directory inside a temp dir, and tears it down after
      the test.
    * ``"remote_config"`` — connects to an already-running server described by
      the JSON file at ``REMOTE_CONFIG_PATH``.

    In both cases the test namespace/database are defined and the schema from
    the ``surrealdb_schema`` asset is applied before yielding.

    Raises:
        RuntimeError: if the local server process exits during startup.
        ValueError: if ``request.param`` is an unknown config type.
    """
    if request.param == "local_config":
        process = None
        tmp_dir = tempfile.TemporaryDirectory()
        try:
            # RocksDB storage directory for the throwaway local server.
            db_dir_path = os.path.join(str(tmp_dir.name), "test.surrealdb")
            os.makedirs(db_dir_path, exist_ok=True)
            cmd = [
                "surreal",
                "start",
                "--allow-all",
                "--user",
                "root",
                "--pass",
                "root",
                "--log",
                "trace",  # Or "debug" for more verbose logs if needed
                "--bind",
                "0.0.0.0:8000",
                f"rocksdb://{db_dir_path}",
            ]

            # stderr is piped so a startup failure can be reported below.
            process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)

            # Wait for server to initialize.
            # A more robust check (e.g., trying to connect) could be added if necessary.
            time.sleep(3)

            # poll() returning a code means the server died during startup.
            if process.poll() is not None:
                # pylint: disable=C0301
                stderr_output = (
                    process.stderr.read().decode() if process.stderr else "No stderr output."
                )
                raise RuntimeError(
                    f"Failed to start SurrealDB server. Exit code: {process.returncode}. "
                    f"Command: {' '.join(cmd)}. Stderr: {stderr_output}"
                )

            conf = SurrealDBConnectionConfig(
                url="ws://localhost:8000",
                namespace=test_namespace_name,
                database=test_database_name,
                access_config=SurrealDBAccessConfig(
                    username="root",
                    password="root",
                ),
            )

            # Provision the namespace/database and apply the test schema.
            with conf.get_client() as client:
                client.query(f"DEFINE DATABASE {test_database_name};")
                client.query(f"DEFINE NAMESPACE {test_namespace_name};")
                client.use(test_namespace_name, test_database_name)
                with surrealdb_schema.open("r") as f:
                    query = f.read()
                client.query(query)

            # Control returns here after the test finishes; cleanup follows.
            yield conf

        finally:
            # Ensure process is terminated
            # and all resources are cleaned up
            if process:
                if process.poll() is None:
                    process.terminate()
                    try:
                        process.wait(timeout=10)
                    except subprocess.TimeoutExpired:
                        # Graceful shutdown timed out; force-kill and reap.
                        process.kill()
                        process.wait()
                if process.stderr:
                    process.stderr.close()
            tmp_dir.cleanup()

    elif request.param == "remote_config":
        # Connection details come from the checked-in/mounted JSON file;
        # username/password/token are all optional keys there.
        conf = json.loads(Path(REMOTE_CONFIG_PATH).read_text(encoding="utf-8"))
        conn_conf = SurrealDBConnectionConfig(
            url=conf["url"],
            namespace=test_namespace_name,
            database=test_database_name,
            access_config=SurrealDBAccessConfig(
                username=conf.get("username", None),
                password=conf.get("password", None),
                token=conf.get("token", None),
            ),
        )

        # Provision the namespace/database and apply the test schema.
        with conn_conf.get_client() as client:
            client.query(f"DEFINE DATABASE {test_database_name};")
            client.query(f"DEFINE NAMESPACE {test_namespace_name};")
            client.use(test_namespace_name, test_database_name)
            with surrealdb_schema.open("r") as f:
                query = f.read()
            client.query(query)

        yield conn_conf

    else:
        raise ValueError(f"Unknown config type: {request.param}")


@pytest.fixture
def surrealdb_schema() -> Path:
    """Path to the SurrealQL schema asset applied to each test database."""
    path = assets_dir / "surrealdb-schema.sql"
    assert path.exists()
    assert path.is_file()
    return path
78 changes: 78 additions & 0 deletions test/integration/connectors/surrealdb/test_surrealdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import json
from pathlib import Path
import logging

import surrealdb
import pytest
from _pytest.fixtures import TopRequest

from test.integration.connectors.utils.constants import DESTINATION_TAG, SQL_TAG
from test.integration.connectors.utils.validation.destination import (
StagerValidationConfigs,
stager_validation,
)
from unstructured_ingest.data_types.file_data import FileData, SourceIdentifiers
from unstructured_ingest.processes.connectors.surrealdb.surrealdb import (
CONNECTOR_TYPE,
SurrealDBConnectionConfig,
SurrealDBUploader,
SurrealDBUploaderConfig,
SurrealDBUploadStager,
)

logger = logging.getLogger("surrealdb_test")


def validate_surrealdb_destination(config: SurrealDBConnectionConfig, expected_num_elements: int):
    """Assert that the ``elements`` table holds exactly the expected row count."""
    with config.get_client() as client:
        rows = client.query("select count() from elements group all;")
        logger.debug(f"results: {rows}")
        actual = rows[0]["count"]
        failure_msg = f"dest check failed: got {actual}, expected {expected_num_elements}"
        assert actual == expected_num_elements, failure_msg


@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, "surrealdb", SQL_TAG)
def test_surrealdb_destination(
    upload_file: Path, config: SurrealDBConnectionConfig, temp_dir: Path
):
    """End-to-end destination check: stage an elements file, upload it, and
    verify the row count in SurrealDB matches the staged element count."""
    mock_file_data = FileData(
        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
        connector_type=CONNECTOR_TYPE,
        identifier="mock-file-data",
    )

    # Stage the raw elements file into the connector's upload format.
    staged = SurrealDBUploadStager().run(
        elements_filepath=upload_file,
        file_data=mock_file_data,
        output_dir=temp_dir,
        output_filename=upload_file.name,
    )

    # Push the staged content into the database under test.
    uploader = SurrealDBUploader(
        connection_config=config, upload_config=SurrealDBUploaderConfig()
    )
    uploader.run(path=staged, file_data=mock_file_data)

    # The staged file is JSON; its length is the expected destination count.
    with staged.open() as f:
        expected = len(json.load(f))
    validate_surrealdb_destination(config=config, expected_num_elements=expected)


@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, "surrealdb", SQL_TAG)
@pytest.mark.parametrize("upload_file_str", ["upload_file_ndjson", "upload_file"])
def test_surrealdb_stager(
    request: TopRequest,
    upload_file_str: str,
    tmp_path: Path,
):
    """Validate the stager output for both JSON and NDJSON inputs.

    Renamed from ``surrealdb_stager``: without the ``test_`` prefix pytest
    never collected this function, so the stager validation silently never ran.
    """
    upload_file: Path = request.getfixturevalue(upload_file_str)
    stager = SurrealDBUploadStager()
    stager_validation(
        configs=StagerValidationConfigs(test_id=CONNECTOR_TYPE, expected_count=22),
        input_file=upload_file,
        stager=stager,
        tmp_dir=tmp_path,
    )
1 change: 1 addition & 0 deletions unstructured_ingest/processes/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unstructured_ingest.processes.connectors.lancedb # noqa: F401
import unstructured_ingest.processes.connectors.qdrant # noqa: F401
import unstructured_ingest.processes.connectors.sql # noqa: F401
import unstructured_ingest.processes.connectors.surrealdb # noqa: F401
import unstructured_ingest.processes.connectors.weaviate # noqa: F401
from unstructured_ingest.processes.connector_registry import (
add_destination_entry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ def can_have_children() -> bool:
@classmethod
def from_dict(cls, data: dict):
"""Create OriginalSyncedBlock from dictionary data.

Original blocks contain children content.
"""
if "children" not in data:
raise ValueError(f"OriginalSyncedBlock data missing 'children': {data}")
raise ValueError(f"OriginalSyncedBlock data missing 'children': {data}")
return cls(children=data["children"])

def get_html(self) -> Optional[HtmlTag]:
Expand All @@ -38,7 +38,7 @@ class DuplicateSyncedBlock(BlockBase):
@staticmethod
def can_have_children() -> bool:
"""Check if duplicate synced blocks can have children.

Duplicate blocks themselves don't have children directly fetched here,
but they represent content that does, so Notion API might report has_children=True
on the parent block object. The actual children are fetched from the original block.
Expand All @@ -48,7 +48,7 @@ def can_have_children() -> bool:
@classmethod
def from_dict(cls, data: dict):
"""Create DuplicateSyncedBlock from dictionary data.

Duplicate blocks contain a 'synced_from' reference.
"""
synced_from_data = data.get("synced_from")
Expand All @@ -63,7 +63,7 @@ def from_dict(cls, data: dict):

def get_html(self) -> Optional[HtmlTag]:
"""Get HTML representation of the duplicate synced block.

HTML representation might need fetching the original block's content,
which is outside the scope of this simple data class.
"""
Expand All @@ -74,15 +74,15 @@ class SyncBlock(BlockBase):
@staticmethod
def can_have_children() -> bool:
"""Check if synced blocks can have children.

Synced blocks (both original and duplicate) can conceptually have children.
"""
return True

@classmethod
def from_dict(cls, data: dict):
"""Create appropriate SyncedBlock subclass from dictionary data.

Determine if it's a duplicate (has 'synced_from') or original (has 'children').
"""
if data.get("synced_from") is not None:
Expand All @@ -99,10 +99,9 @@ def from_dict(cls, data: dict):
# Consider logging a warning here if strictness is needed.
return OriginalSyncedBlock(children=[])


def get_html(self) -> Optional[HtmlTag]:
"""Get HTML representation of the synced block.

The specific instance returned by from_dict (Original or Duplicate)
will handle its own get_html logic.
This method on the base SyncBlock might not be directly called.
Expand Down
Loading