diff --git a/src/unstract/sdk/__init__.py b/src/unstract/sdk/__init__.py index 3f9c0f15..ffcb125c 100644 --- a/src/unstract/sdk/__init__.py +++ b/src/unstract/sdk/__init__.py @@ -1,4 +1,4 @@ -__version__ = "v0.77.1" +__version__ = "v0.78.0" def get_sdk_version() -> str: diff --git a/src/unstract/sdk/adapters/base.py b/src/unstract/sdk/adapters/base.py index 4b4daf98..b1827bf4 100644 --- a/src/unstract/sdk/adapters/base.py +++ b/src/unstract/sdk/adapters/base.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from unstract.sdk.adapters.enums import AdapterTypes +from unstract.sdk.adapters.exceptions import AdapterError +from unstract.sdk.adapters.url_validator import URLValidator logger = logging.getLogger(__name__) @@ -32,7 +34,7 @@ def get_icon() -> str: @classmethod def get_json_schema(cls) -> str: - schema_path = getattr(cls, 'SCHEMA_PATH', None) + schema_path = getattr(cls, "SCHEMA_PATH", None) if schema_path is None: raise ValueError(f"SCHEMA_PATH not defined for {cls.__name__}") with open(schema_path) as f: @@ -43,6 +45,36 @@ def get_json_schema(cls) -> str: def get_adapter_type() -> AdapterTypes: return "" + @abstractmethod + def get_configured_urls(self) -> list[str]: + """Return all URLs that this adapter will connect to. + + This method should return a list of all URLs that the adapter + uses for external connections. These URLs will be validated + for security before allowing connection attempts. + + Returns: + list[str]: List of URLs that will be accessed by this adapter + """ + return [] + + def _validate_urls(self) -> None: + """Validate all configured URLs against security rules.""" + urls = self.get_configured_urls() + + for url in urls: + if not url: # Skip empty/None URLs + continue + + is_valid, error_message = URLValidator.validate_url(url) + if not is_valid: + # Use class name as fallback when self.name isn't set yet + adapter_name = getattr(self, "name", self.__class__.__name__) + logger.error( + f"URL validation failed for adapter '{adapter_name}': {error_message}" + ) + raise AdapterError(f"URL validation failed: {error_message}") + @abstractmethod def test_connection(self) -> bool: """Override to test connection for a adapter. diff --git a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py index 872cb398..37631a34 100644 --- a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py +++ b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py @@ -22,10 +22,14 @@ class Constants: class AzureOpenAI(EmbeddingAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("AzureOpenAIEmbedding") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -48,6 +52,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/AzureopenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + endpoint = self.config.get("azure_endpoint") + return [endpoint] if endpoint else [] + def get_embedding_instance(self) -> BaseEmbedding: try: embedding_batch_size = EmbeddingHelper.get_embedding_batch_size( diff --git a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py index 68b5b8a0..f0e921ec 100644 --- a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py @@ -15,10 +15,14 @@ class Constants: class Ollama(EmbeddingAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("Ollama") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -41,6 +45,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/ollama.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + base_url = self.config.get("base_url") + return [base_url] if base_url else [] + def get_embedding_instance(self) -> BaseEmbedding: try: embedding_batch_size = EmbeddingHelper.get_embedding_batch_size( diff --git a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py index 781e849d..c51790d0 100644 --- a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py @@ -21,10 +21,14 @@ class Constants: class OpenAI(EmbeddingAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("OpenAI") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -47,6 +51,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/OpenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + api_base = self.config.get("api_base") + return [api_base] if api_base else [] + def get_embedding_instance(self) -> BaseEmbedding: try: timeout = int(self.config.get(Constants.TIMEOUT, Constants.DEFAULT_TIMEOUT)) diff --git a/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py b/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py index 3c371ddc..5d53560e 100644 --- a/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py +++ b/src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py @@ -4,7 +4,6 @@ from llama_index.core.constants import DEFAULT_NUM_OUTPUTS from llama_index.core.llms import LLM from llama_index.llms.anyscale import Anyscale - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -20,10 +19,14 @@ class Constants: class AnyScaleLLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("AnyScale") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -46,6 +49,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/anyscale.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + api_base = self.config.get(Constants.API_BASE) + return [api_base] if api_base else [] + def get_llm_instance(self) -> LLM: try: max_tokens = int(self.config.get(Constants.MAX_TOKENS, DEFAULT_NUM_OUTPUTS)) diff --git a/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py b/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py index fe1c123c..22a62f7c 100644 --- a/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py +++ b/src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py @@ -4,7 +4,6 @@ from llama_index.core.llms import LLM from llama_index.llms.azure_openai import AzureOpenAI from llama_index.llms.openai.utils import O1_MODELS - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -26,10 +25,14 @@ class Constants: class AzureOpenAILLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("AzureOpenAI") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -52,6 +55,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/AzureopenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + endpoint = self.config.get("azure_endpoint") + return [endpoint] if endpoint else [] + def get_llm_instance(self) -> LLM: max_retries = int( self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES) @@ -74,9 +82,7 @@ def get_llm_instance(self) -> LLM: } if enable_reasoning: - llm_kwargs["reasoning_effort"] = self.config.get( - Constants.REASONING_EFFORT - ) + llm_kwargs["reasoning_effort"] = self.config.get(Constants.REASONING_EFFORT) if model not in O1_MODELS: llm_kwargs["max_completion_tokens"] = max_tokens diff --git a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py index 49e6ff13..0c35b67b 100644 --- a/src/unstract/sdk/adapters/llm/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/llm/ollama/src/ollama.py @@ -6,7 +6,6 @@ from httpx import ConnectError, HTTPStatusError from llama_index.core.llms import LLM from llama_index.llms.ollama import Ollama - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -25,10 +24,13 @@ class Constants: class OllamaLLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("Ollama") self.config = settings + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -51,6 +53,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/ollama.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + base_url = self.config.get(Constants.BASE_URL) + return [base_url] if base_url else [] + def get_llm_instance(self) -> LLM: try: llm: LLM = Ollama( diff --git a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py index d1d1b255..0baa4e18 100644 --- a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py @@ -25,10 +25,13 @@ class Constants: class OpenAILLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("OpenAI") self.config = settings + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -51,6 +54,11 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/OpenAI.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + api_base = self.config.get("api_base") + return [api_base] if api_base else [] + def get_llm_instance(self) -> LLM: try: max_tokens = self.config.get(Constants.MAX_TOKENS) diff --git a/src/unstract/sdk/adapters/url_validator.py b/src/unstract/sdk/adapters/url_validator.py new file mode 100644 index 00000000..afdeb131 --- /dev/null +++ b/src/unstract/sdk/adapters/url_validator.py @@ -0,0 +1,172 @@ +import ipaddress +import logging +import os +import socket +from dataclasses import dataclass +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +@dataclass +class WhitelistEntry: + """Represents a whitelisted endpoint with IP range and optional port.""" + + ip_network: ipaddress.IPv4Network | ipaddress.IPv6Network + port: int | None = None + + +class URLValidator: + """Validates URLs to prevent SSRF attacks by blocking private IP addresses. + + URLs are validated to block private IP addresses unless explicitly + whitelisted via WHITELISTED_ENDPOINTS. + """ + + ENV_VAR = "WHITELISTED_ENDPOINTS" + + # Private IP ranges that are blocked by default (RFC 1918 + others) + BLOCKED_PRIVATE_RANGES = [ + "127.0.0.0/8", # Localhost + "10.0.0.0/8", # Class A private + "172.16.0.0/12", # Class B private + "192.168.0.0/16", # Class C private + "169.254.0.0/16", # Link-local + "0.0.0.0/8", # Current network + "224.0.0.0/4", # Multicast + "240.0.0.0/4", # Reserved + # IPv6 ranges + "::1/128", # IPv6 localhost + "fc00::/7", # IPv6 unique local + "fe80::/10", # IPv6 link-local + ] + + @classmethod + def validate_url(cls, url: str) -> tuple[bool, str]: + """Validates a URL against security rules. + + Args: + url: The URL to validate + + Returns: + Tuple of (is_valid, error_message) + """ + try: + parsed = urlparse(url) + + if not parsed.hostname: + return False, f"Invalid URL: No hostname found in '{url}'" + + # Resolve hostname to IP address + try: + host_ip = socket.gethostbyname(parsed.hostname) + except socket.gaierror as e: + return ( + False, + f"DNS resolution failed for '{parsed.hostname}': {str(e)}", + ) + + # Check if IP is private + ip_obj = ipaddress.ip_address(host_ip) + if cls._is_private_ip(ip_obj): + # Private IP - check whitelist + port = parsed.port + if cls._is_whitelisted(ip_obj, port): + logger.info(f"Private IP {host_ip}:{port} allowed by whitelist") + return True, "" + else: + error_msg = ( + f"URL blocked: Private IP {host_ip}" + f"{':' + str(port) if port else ''} not in whitelist. " + f"Contact platform admin for assistance." + ) + return False, error_msg + + # Public IP - allowed by default + return True, "" + + except Exception as e: + logger.error(f"URL validation error for '{url}': {str(e)}") + return False, f"{str(e)}" + + @classmethod + def _is_private_ip(cls, ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + """Check if IP address is in private ranges.""" + for range_str in cls.BLOCKED_PRIVATE_RANGES: + try: + network = ipaddress.ip_network(range_str) + if ip in network: + return True + except ValueError: + continue + return False + + @classmethod + def _is_whitelisted( + cls, ip: ipaddress.IPv4Address | ipaddress.IPv6Address, port: int | None + ) -> bool: + """Check if IP:port combination is whitelisted.""" + whitelist = cls._parse_whitelist_config() + + for entry in whitelist: + if ip in entry.ip_network: + # IP matches - check port + if entry.port is None or entry.port == port: + return True + + return False + + @classmethod + def _parse_whitelist_config(cls) -> list[WhitelistEntry]: + """Parse whitelist configuration from environment variable.""" + config = os.getenv(cls.ENV_VAR, "").strip() + if not config: + return [] + + entries = [] + for item in config.split(","): + item = item.strip() + if not item: + continue + + try: + entry = cls._parse_whitelist_entry(item) + if entry: + entries.append(entry) + except Exception as e: + logger.warning(f"Invalid whitelist entry '{item}': {str(e)}") + + return entries + + @classmethod + def _parse_whitelist_entry(cls, entry: str) -> WhitelistEntry | None: + """Parse a single whitelist entry in format 'IP:PORT' or + 'IP/CIDR:PORT'.""" + port = None + ip_part = entry + + # Check if entry has port specification + if ":" in entry: + parts = entry.rsplit(":", 1) + if len(parts) == 2 and parts[1].isdigit(): + ip_part = parts[0] + port = int(parts[1]) + + # Parse IP or CIDR + try: + if "/" in ip_part: + # CIDR notation + network = ipaddress.ip_network(ip_part, strict=False) + else: + # Single IP - convert to /32 or /128 network + ip = ipaddress.ip_address(ip_part) + if isinstance(ip, ipaddress.IPv4Address): + network = ipaddress.IPv4Network(f"{ip}/32") + else: + network = ipaddress.IPv6Network(f"{ip}/128") + + return WhitelistEntry(ip_network=network, port=port) + + except ValueError as e: + logger.warning(f"Invalid IP/CIDR in whitelist entry '{ip_part}': {str(e)}") + return None diff --git a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py index 12c1fbc7..4b60fffe 100644 --- a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py +++ b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py @@ -17,10 +17,14 @@ class Constants: class Milvus(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: MilvusClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + if validate_urls: + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Milvus", self._vector_db_instance) @@ -42,6 +46,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/Milvus.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + uri = self._config.get(Constants.URI) + return [uri] if uri else [] + def get_vector_db_instance(self) -> VectorStore: return self._vector_db_instance diff --git a/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json b/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json index 22965dbb..9c493a59 100644 --- a/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json +++ b/src/unstract/sdk/adapters/vectordb/milvus/src/static/json_schema.json @@ -16,7 +16,7 @@ "type": "string", "title": "URI", "format": "uri", - "default": "localhost:19530", + "default": "http://localhost:19530", "description": "Provide the URI of the Milvus server. Example: `https://.api.gcp-us-west1.zillizcloud.com`" }, "token": { diff --git a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py index 36676129..d6e200a9 100644 --- a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py +++ b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py @@ -23,11 +23,15 @@ class Constants: class Postgres(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: connection | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._schema_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + if validate_urls: + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Postgres", self._vector_db_instance) @@ -108,6 +112,20 @@ def test_connection(self) -> bool: return test_result + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + host = self._config.get(Constants.HOST) + port = self._config.get(Constants.PORT) + + if host: + # Construct the database URL for validation + if port: + url = f"postgresql://{host}:{port}" + else: + url = f"postgresql://{host}" + return [url] + return [] + def close(self, **kwargs: Any) -> None: if self._client: self._client.close() diff --git a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py index 767a387d..36173ee6 100644 --- a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py +++ b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py @@ -20,10 +20,14 @@ class Constants: class Qdrant(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: QdrantClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + if validate_urls: + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Qdrant", self._vector_db_instance) @@ -45,6 +49,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/qdrant.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self._config.get(Constants.URL) + return [url] if url else [] + def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance diff --git a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py index 8e2c12aa..13cd9860 100644 --- a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py +++ b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py @@ -21,10 +21,14 @@ class Constants: class Weaviate(VectorDBAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): self._config = settings self._client: weaviate.Client | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME + + if validate_urls: + self._validate_urls() + self._vector_db_instance = self._get_vector_db_instance() super().__init__("Weaviate", self._vector_db_instance) @@ -46,6 +50,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/Weaviate.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self._config.get(Constants.URL) + return [url] if url else [] + def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance diff --git a/src/unstract/sdk/adapters/x2text/helper.py b/src/unstract/sdk/adapters/x2text/helper.py index 3a94872f..ad225951 100644 --- a/src/unstract/sdk/adapters/x2text/helper.py +++ b/src/unstract/sdk/adapters/x2text/helper.py @@ -5,6 +5,7 @@ from requests import Response from requests.exceptions import ConnectionError, HTTPError, Timeout from unstract.sdk.adapters.exceptions import AdapterError +from unstract.sdk.adapters.url_validator import URLValidator from unstract.sdk.adapters.utils import AdapterUtils from unstract.sdk.adapters.x2text.constants import X2TextConstants from unstract.sdk.constants import MimeType @@ -101,6 +102,13 @@ def make_request( ) -> Response: unstructured_url = unstructured_adapter_config.get(UnstructuredHelper.URL) + # Validate the unstructured URL for security + if unstructured_url: + is_valid, error_message = URLValidator.validate_url(unstructured_url) + if not is_valid: + logger.error(f"Unstructured URL validation failed: {error_message}") + raise AdapterError(f"URL validation failed: {error_message}") + x2text_service_url = unstructured_adapter_config.get(X2TextConstants.X2TEXT_HOST) x2text_service_port = unstructured_adapter_config.get(X2TextConstants.X2TEXT_PORT) platform_service_api_key = unstructured_adapter_config.get( diff --git a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py index 237a5078..d871fc2f 100644 --- a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py +++ b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py @@ -7,7 +7,9 @@ from llama_parse import LlamaParse from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.x2text.dto import TextExtractionResult -from unstract.sdk.adapters.x2text.llama_parse.src.constants import LlamaParseConfig +from unstract.sdk.adapters.x2text.llama_parse.src.constants import ( + LlamaParseConfig, +) from unstract.sdk.adapters.x2text.x2text_adapter import X2TextAdapter from unstract.sdk.file_storage import FileStorage, FileStorageProvider @@ -15,10 +17,14 @@ class LlamaParseAdapter(X2TextAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("LlamaParse") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -37,6 +43,13 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/llama-parse.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + base_url = self.config.get(LlamaParseConfig.BASE_URL) + if isinstance(base_url, str): + base_url = base_url.strip() + return [base_url] if base_url else [] + def _call_parser( self, input_file_path: str, diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py index 71ef4502..0bce7895 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py @@ -8,6 +8,7 @@ import requests from requests import Response from requests.exceptions import ConnectionError, HTTPError, Timeout + from unstract.sdk.adapters.exceptions import ExtractorError from unstract.sdk.adapters.utils import AdapterUtils from unstract.sdk.adapters.x2text.constants import X2TextConstants @@ -55,6 +56,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/LLMWhisperer.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = str(self.config.get(WhispererConfig.URL) or "").strip() + return [url.rstrip("/")] if url else [] + def _get_request_headers(self) -> dict[str, Any]: """Obtains the request headers to authenticate with LLMWhisperer. @@ -63,7 +69,9 @@ def _get_request_headers(self) -> dict[str, Any]: """ return { "accept": MimeType.JSON, - WhispererHeader.UNSTRACT_KEY: self.config.get(WhispererConfig.UNSTRACT_KEY), + WhispererHeader.UNSTRACT_KEY: self.config.get( + WhispererConfig.UNSTRACT_KEY + ), } def _make_request( @@ -109,7 +117,9 @@ def _make_request( data=data, ) else: - raise ExtractorError(f"Unsupported request method: {request_method}") + raise ExtractorError( + f"Unsupported request method: {request_method}" + ) response.raise_for_status() except ConnectionError as e: logger.error(f"Adapter error: {e}") @@ -129,7 +139,9 @@ def _make_request( raise ExtractorError(msg) return response - def _get_whisper_params(self, enable_highlight: bool = False) -> dict[str, Any]: + def _get_whisper_params( + self, enable_highlight: bool = False + ) -> dict[str, Any]: """Gets query params meant for /whisper endpoint. The params is filled based on the configuration passed. @@ -195,11 +207,14 @@ def _get_whisper_params(self, enable_highlight: bool = False) -> dict[str, Any]: if enable_highlight: params.update( - {WhispererConfig.STORE_METADATA_FOR_HIGHLIGHTING: enable_highlight} + { + WhispererConfig.STORE_METADATA_FOR_HIGHLIGHTING: enable_highlight + } ) return params def test_connection(self) -> bool: + self._make_request( request_method=HTTPMethod.GET, request_endpoint=WhispererEndpoint.TEST_CONNECTION, @@ -243,7 +258,9 @@ def _check_status_until_ready( ) if status_response.status_code == 200: status_data = status_response.json() - status = status_data.get(WhisperStatus.STATUS, WhisperStatus.UNKNOWN) + status = status_data.get( + WhisperStatus.STATUS, WhisperStatus.UNKNOWN + ) logger.info(f"Whisper status for {whisper_hash}: {status}") if status in [WhisperStatus.PROCESSED, WhisperStatus.DELIVERED]: break @@ -256,7 +273,8 @@ def _check_status_until_ready( # Exit with error if max poll count is reached if request_count >= MAX_POLLS: raise ExtractorError( - "Unable to extract text after attempting" f" {request_count} times" + "Unable to extract text after attempting" + f" {request_count} times" ) time.sleep(POLL_INTERVAL) @@ -340,7 +358,9 @@ def _extract_text_from_response( raise ExtractorError("Couldn't extract text from file") if output_file_path: self._write_output_to_file( - output_json=output_json, output_file_path=Path(output_file_path), fs=fs + output_json=output_json, + output_file_path=Path(output_file_path), + fs=fs, ) return output_json.get("text", "") @@ -381,9 +401,13 @@ def _write_output_to_file( fs.mkdir(str(metadata_dir), create_parents=True) # Remove the "text" key from the metadata metadata = { - key: value for key, value in output_json.items() if key != "text" + key: value + for key, value in output_json.items() + if key != "text" } - metadata_json = json.dumps(metadata, ensure_ascii=False, indent=4) + metadata_json = json.dumps( + metadata, ensure_ascii=False, indent=4 + ) logger.info(f"Writing metadata to {metadata_file_path}") fs.write( @@ -393,7 +417,9 @@ def _write_output_to_file( data=metadata_json, ) except Exception as e: - logger.error(f"Error while writing metadata to {metadata_file_path}: {e}") + logger.error( + f"Error while writing metadata to {metadata_file_path}: {e}" + ) except Exception as e: logger.error(f"Error while writing {output_file_path}: {e}") diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py index 166b9f92..0964894b 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py @@ -20,10 +20,14 @@ class LLMWhispererV2(X2TextAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("LLMWhispererV2") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -42,6 +46,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/LLMWhispererV2.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get("url") + return [url] if url else [] + def test_connection(self) -> bool: LLMWhispererHelper.test_connection_request( config=self.config, diff --git a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py index b205ce0c..6be3289c 100644 --- a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py +++ b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py @@ -11,10 +11,14 @@ class UnstructuredCommunity(X2TextAdapter): - def __init__(self, settings: dict[str, Any]): + def __init__(self, settings: dict[str, Any], validate_urls: bool = False): super().__init__("UnstructuredIOCommunity") self.config = settings + # Validate URLs BEFORE any network operations + if validate_urls: + self._validate_urls() + SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json" @staticmethod @@ -33,6 +37,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/UnstructuredIO.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get("url") + return [url] if url else [] + def process( self, input_file_path: str, diff --git a/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py b/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py index 908be6da..a713f409 100644 --- a/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py +++ b/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py @@ -33,6 +33,11 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/UnstructuredIO.png" + def get_configured_urls(self) -> list[str]: + """Return all URLs this adapter will connect to.""" + url = self.config.get("url") + return [url] if url else [] + def process( self, input_file_path: str, diff --git a/src/unstract/sdk/prompt.py b/src/unstract/sdk/prompt.py index b21c8315..7818656e 100644 --- a/src/unstract/sdk/prompt.py +++ b/src/unstract/sdk/prompt.py @@ -6,11 +6,8 @@ import requests from deprecated import deprecated from requests import ConnectionError, RequestException, Response -from unstract.sdk.constants import ( - MimeType, - RequestHeader, - ToolEnv, -) + +from unstract.sdk.constants import MimeType, RequestHeader, ToolEnv from unstract.sdk.helper import SdkHelper from unstract.sdk.platform import PlatformHelper from unstract.sdk.tool.base import BaseTool @@ -22,7 +19,9 @@ R = TypeVar("R") -def handle_service_exceptions(context: str) -> Callable[[Callable[P, R]], Callable[P, R]]: +def handle_service_exceptions( + context: str, +) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator to handle exceptions in PromptTool service calls. Args: @@ -39,20 +38,23 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: except ConnectionError as e: msg = f"Error while {context}. Unable to connect to prompt service." logger.error(f"{msg}\n{e}") - args[0].tool.stream_error_and_exit(msg, e) + args[0].tool.stream_error_and_exit(msg, e, None) except RequestException as e: error_message = str(e) + status_code = None response = getattr(e, "response", None) if response is not None: + status_code = response.status_code if ( - MimeType.JSON in response.headers.get("Content-Type", "").lower() + MimeType.JSON + in response.headers.get("Content-Type", "").lower() and "error" in response.json() ): error_message = response.json()["error"] elif response.text: error_message = response.text msg = f"Error while {context}. {error_message}" - args[0].tool.stream_error_and_exit(msg, e) + args[0].tool.stream_error_and_exit(msg, e, status_code) return wrapper @@ -79,7 +81,9 @@ def __init__( is_public_call (bool): Whether the call is public. Defaults to False """ self.tool = tool - self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port) + self.base_url = SdkHelper.get_platform_base_url( + prompt_host, prompt_port + ) self.is_public_call = is_public_call self.request_id = request_id if not is_public_call: @@ -168,7 +172,9 @@ def summarize( headers=headers, ) - def _get_headers(self, headers: dict[str, str] | None = None) -> dict[str, str]: + def _get_headers( + self, headers: dict[str, str] | None = None + ) -> dict[str, str]: """Get default headers for requests. Returns: @@ -218,13 +224,13 @@ def _call_service( response = requests.get(url=url, params=params, headers=req_headers) else: raise ValueError(f"Unsupported HTTP method: {method}") - response.raise_for_status() return response.json() @staticmethod @deprecated( - version="v0.71.0", reason="Use `PlatformHelper.get_prompt_studio_tool` instead" + version="v0.71.0", + reason="Use `PlatformHelper.get_prompt_studio_tool` instead", ) def get_exported_tool( tool: BaseTool, prompt_registry_id: str diff --git a/src/unstract/sdk/tool/stream.py b/src/unstract/sdk/tool/stream.py index 9286025a..ca9c84b3 100644 --- a/src/unstract/sdk/tool/stream.py +++ b/src/unstract/sdk/tool/stream.py @@ -115,18 +115,19 @@ def stream_log( } print(json.dumps(record)) - def stream_error_and_exit(self, message: str, err: Exception | None = None) -> None: + def stream_error_and_exit(self, message: str, err: Exception | None = None, status_code: int | None = None) -> None: """Stream error log and exit. Args: message (str): Error message err (Exception): Actual exception that occurred + status_code (int): HTTP status code to preserve """ self.stream_log(message, level=LogLevel.ERROR) if self._exec_by_tool: exit(1) else: - raise SdkError(message, actual_err=err) + raise SdkError(message, status_code=status_code, actual_err=err) def get_env_or_die(self, env_key: str) -> str: """Returns the value of an env variable. diff --git a/tests/test_url_validator.py b/tests/test_url_validator.py new file mode 100644 index 00000000..0560afeb --- /dev/null +++ b/tests/test_url_validator.py @@ -0,0 +1,208 @@ +import os +import unittest +from unittest.mock import patch + +from unstract.sdk.adapters.url_validator import URLValidator + + +class TestURLValidator(unittest.TestCase): + """Test cases for URL validation functionality.""" + + def setUp(self): + """Set up test environment.""" + # Clear any existing environment variables + if URLValidator.ENV_VAR in os.environ: + del os.environ[URLValidator.ENV_VAR] + + def tearDown(self): + """Clean up after tests.""" + # Clear environment variables + if URLValidator.ENV_VAR in os.environ: + del os.environ[URLValidator.ENV_VAR] + + @patch("socket.gethostbyname", return_value="1.1.1.1") + def test_public_urls_allowed(self, _): + """Test that public URLs are allowed by default.""" + test_cases = [ + "https://api.openai.com/v1/chat/completions", + "https://google.com", + "http://example.com", + "https://1.1.1.1:8080", # Public IP with port + ] + + for url in test_cases: + with self.subTest(url=url): + is_valid, error = URLValidator.validate_url(url) + self.assertTrue( + is_valid, f"Public URL should be valid: {url}, Error: {error}" + ) + @patch("socket.gethostbyname") + def test_private_ips_blocked_by_default(self, mock_gethostbyname): + """Test that private IPs are blocked when not whitelisted.""" + test_cases = [ + ("https://192.168.1.100", "192.168.1.100"), # Private class C + ("https://10.0.0.5:8080", "10.0.0.5"), # Private class A with port + ("https://172.16.5.10", "172.16.5.10"), # Private class B + ("https://127.0.0.1", "127.0.0.1"), # Localhost + ("https://169.254.169.254", "169.254.169.254"), # Link-local (AWS metadata) + ] + + for url, ip in test_cases: + with self.subTest(url=url): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(url) + self.assertFalse(is_valid, f"Private IP should be blocked: {url}") + self.assertIn("not in", error) + self.assertIn("whitelist", error) + + @patch("socket.gethostbyname") + def test_whitelisted_private_ips_allowed(self, mock_gethostbyname): + """Test that whitelisted private IPs are allowed.""" + # Set whitelist environment variable + os.environ[URLValidator.ENV_VAR] = "192.168.1.100:8080,10.0.0.0/8" + + test_cases = [ + ("https://192.168.1.100:8080", "192.168.1.100"), # Exact IP:port match + ("https://10.0.0.5:9200", "10.0.0.5"), # CIDR range match + ("https://10.255.255.255", "10.255.255.255"), # CIDR range edge + ] + + for url, ip in test_cases: + with self.subTest(url=url): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(url) + self.assertTrue( + is_valid, f"Whitelisted IP should be allowed: {url}, Error: {error}" + ) + + @patch("socket.gethostbyname") + def test_port_specific_whitelist(self, mock_gethostbyname): + """Test port-specific whitelisting.""" + os.environ[URLValidator.ENV_VAR] = "192.168.1.100:8080" + mock_gethostbyname.return_value = "192.168.1.100" + + # Port match - should be allowed + is_valid, error = URLValidator.validate_url("https://192.168.1.100:8080") + self.assertTrue(is_valid, "Matching port should be allowed") + + # Port mismatch - should be blocked + is_valid, error = URLValidator.validate_url("https://192.168.1.100:9000") + self.assertFalse(is_valid, "Non-matching port should be blocked") + + @patch("socket.gethostbyname") + def test_cidr_range_matching(self, mock_gethostbyname): + """Test CIDR range matching in whitelist.""" + os.environ[URLValidator.ENV_VAR] = "192.168.1.0/24:8080" + + test_cases = [ + ("192.168.1.1", True), # In range + ("192.168.1.255", True), # In range (edge) + ("192.168.2.1", False), # Out of range + ("192.168.0.255", False), # Out of range + ] + + for ip, should_be_valid in test_cases: + with self.subTest(ip=ip): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(f"https://{ip}:8080") + self.assertEqual( + is_valid, should_be_valid, f"CIDR matching failed for {ip}: {error}" + ) + + def test_whitelist_parsing(self): + """Test whitelist configuration parsing.""" + # Test various whitelist formats + os.environ[URLValidator.ENV_VAR] = ( + "192.168.1.100:8080,10.0.0.0/8,172.16.5.100:3000" + ) + + entries = URLValidator._parse_whitelist_config() + + self.assertEqual(len(entries), 3) + + # Check first entry (single IP with port) + self.assertEqual(str(entries[0].ip_network), "192.168.1.100/32") + self.assertEqual(entries[0].port, 8080) + + # Check second entry (CIDR without port) + self.assertEqual(str(entries[1].ip_network), "10.0.0.0/8") + self.assertIsNone(entries[1].port) + + # Check third entry (single IP with port) + self.assertEqual(str(entries[2].ip_network), "172.16.5.100/32") + self.assertEqual(entries[2].port, 3000) + + def test_invalid_whitelist_entries_ignored(self): + """Test that invalid whitelist entries are ignored gracefully.""" + os.environ[URLValidator.ENV_VAR] = ( + "192.168.1.100:8080,invalid-ip,10.0.0.0/8,bad-cidr/35" + ) + + entries = URLValidator._parse_whitelist_config() + + # Only valid entries should be parsed + self.assertEqual(len(entries), 2) + self.assertEqual(str(entries[0].ip_network), "192.168.1.100/32") + self.assertEqual(str(entries[1].ip_network), "10.0.0.0/8") + + def test_empty_whitelist_config(self): + """Test behavior with empty whitelist configuration.""" + os.environ[URLValidator.ENV_VAR] = "" + + entries = URLValidator._parse_whitelist_config() + self.assertEqual(len(entries), 0) + + @patch("socket.gethostbyname") + def test_dns_resolution_failure(self, mock_gethostbyname): + """Test handling of DNS resolution failures.""" + mock_gethostbyname.side_effect = Exception("DNS resolution failed") + + is_valid, error = URLValidator.validate_url("https://nonexistent.example.com") + self.assertFalse(is_valid) + self.assertIn("DNS resolution failed", error) + + def test_invalid_url_handling(self): + """Test handling of invalid URLs.""" + invalid_urls = [ + "not-a-url", + "https://", # No hostname + "", # Empty URL + ] + + for url in invalid_urls: + with self.subTest(url=url): + is_valid, error = URLValidator.validate_url(url) + self.assertFalse(is_valid) + self.assertTrue(len(error) > 0) + + @patch("socket.gethostbyname") + def test_localhost_blocked_by_default(self, mock_gethostbyname): + """Test that localhost is blocked when not explicitly whitelisted.""" + # No whitelist configured - localhost should be blocked + + localhost_ips = ["127.0.0.1", "127.0.0.2", "127.255.255.255"] + + for ip in localhost_ips: + with self.subTest(ip=ip): + mock_gethostbyname.return_value = ip + is_valid, error = URLValidator.validate_url(f"https://{ip}") + self.assertFalse( + is_valid, f"Localhost IP should be blocked by default: {ip}" + ) + self.assertIn("not in", error) + + @patch("socket.gethostbyname") + def test_metadata_service_blocked(self, mock_gethostbyname): + """Test that cloud metadata services are blocked.""" + # AWS/Azure metadata service + mock_gethostbyname.return_value = "169.254.169.254" + + is_valid, error = URLValidator.validate_url( + "https://169.254.169.254/latest/meta-data" + ) + self.assertFalse(is_valid) + self.assertIn("not in", error) + + +if __name__ == "__main__": + unittest.main()