Skip to content

Commit 681cb76

Browse files
committed
init decorator and tests for checking server_url
1 parent 22d2e2c commit 681cb76

File tree

3 files changed

+132
-0
lines changed

3 files changed

+132
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
import pytest
3+
4+
from unstructured_client import UnstructuredClient
5+
6+
7+
def get_api_key():
8+
api_key = os.getenv("UNS_API_KEY")
9+
if api_key is None:
10+
raise ValueError("""UNS_API_KEY environment variable not set.
11+
Set it in your current shell session with `export UNS_API_KEY=<api_key>`""")
12+
return api_key
13+
14+
15+
@pytest.mark.parametrize(
16+
("server_url"),
17+
[
18+
("https://unstructured-000mock.api.unstructuredapp.io"), # correct url
19+
("unstructured-000mock.api.unstructuredapp.io"),
20+
("http://unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
21+
("https://unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
22+
("unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
23+
]
24+
)
25+
def test_clean_server_url_on_paid_api_url(server_url: str):
26+
client = UnstructuredClient(
27+
server_url=server_url,
28+
api_key_auth=get_api_key(),
29+
)
30+
assert client.general.sdk_configuration.server_url == "https://unstructured-000mock.api.unstructuredapp.io"
31+
32+
33+
@pytest.mark.parametrize(
34+
("server_url"),
35+
[
36+
("http://localhost:8000"), # correct url
37+
("https://localhost:8000"),
38+
("localhost:8000"),
39+
("localhost:8000/general/v0/general"),
40+
("https://localhost:8000/general/v0/general"),
41+
]
42+
)
43+
def test_clean_server_url_on_localhost(server_url: str):
44+
client = UnstructuredClient(
45+
server_url=server_url,
46+
api_key_auth=get_api_key(),
47+
)
48+
assert client.general.sdk_configuration.server_url == "http://localhost:8000"
49+
50+
51+
def test_clean_server_url_on_empty_string():
52+
client = UnstructuredClient(
53+
server_url="",
54+
api_key_auth=get_api_key(),
55+
)
56+
assert client.general.sdk_configuration.server_url == ""
57+
58+
@pytest.mark.parametrize(
59+
("server_url"),
60+
[
61+
("https://unstructured-000mock.api.unstructuredapp.io"),
62+
("unstructured-000mock.api.unstructuredapp.io/general/v0/general"),
63+
]
64+
)
65+
def test_clean_server_url_with_positional_arguments(server_url: str):
66+
client = UnstructuredClient(
67+
get_api_key(),
68+
"",
69+
server_url,
70+
)
71+
assert client.general.sdk_configuration.server_url == "https://unstructured-000mock.api.unstructuredapp.io"

src/unstructured_client/sdk.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from typing import Callable, Dict, Union
77
from unstructured_client import utils
88
from unstructured_client.models import shared
9+
from unstructured_client.utils._decorators import clean_server_url
910

1011
class UnstructuredClient:
1112
r"""Unstructured Pipeline API: Partition documents with the Unstructured library"""
1213
general: General
1314

1415
sdk_configuration: SDKConfiguration
1516

17+
@clean_server_url
1618
def __init__(self,
1719
api_key_auth: Union[str, Callable[[], str]],
1820
server: str = None,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import inspect
5+
from typing import Any, Callable, Dict, Optional
6+
from typing_extensions import ParamSpec, Concatenate
7+
from urllib.parse import urlparse, urlunparse, ParseResult
8+
9+
10+
_P = ParamSpec("_P")
11+
12+
13+
def clean_server_url(func: Callable[_P, None]) -> Callable[_P, None]:
14+
15+
@functools.wraps(func)
16+
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
17+
18+
def get_call_args_applying_defaults() -> Dict[str, Any]:
19+
"""Map both explicit and default arguments of decorated func call by param name."""
20+
sig = inspect.signature(func)
21+
call_args: Dict[str, Any] = dict(
22+
**dict(zip(sig.parameters, args)), **kwargs
23+
)
24+
for param in sig.parameters.values():
25+
if param.name not in call_args and param.default is not param.empty:
26+
call_args[param.name] = param.default
27+
return call_args
28+
29+
call_args = get_call_args_applying_defaults()
30+
31+
server_url: Optional[str] = call_args.get("server_url")
32+
33+
if server_url:
34+
# -- add a url scheme if not present (urllib.parse does not work reliably without it)
35+
if "http" not in server_url:
36+
server_url = "http://" + server_url
37+
38+
parsed_url: ParseResult = urlparse(server_url)
39+
40+
if "api.unstructuredapp.io" in server_url:
41+
if parsed_url.scheme != "https":
42+
parsed_url = parsed_url._replace(scheme="https")
43+
else:
44+
# -- if not a paid api url, assume the api is hosted locally and the scheme is "http"
45+
if parsed_url.scheme != "http":
46+
parsed_url = parsed_url._replace(scheme="http")
47+
48+
# -- path should always be empty
49+
cleaned_url = parsed_url._replace(path="")
50+
call_args["server_url"] = urlunparse(cleaned_url)
51+
52+
# -- call_args contains all args and kwargs. If users define some parameters using
53+
# -- kwargs, param definitions would be duplicated. Pass only the `self`
54+
# -- param as an arg and keep the rest in kwargs to prevent duplicates.
55+
self_arg = (call_args.pop("self"),)
56+
57+
return func(*self_arg, **call_args) # type: ignore
58+
59+
return wrapper

0 commit comments

Comments
 (0)