Skip to content

Commit 638bee7

Browse files
committed
fix: fixed custom hooks incorrectly handle URLs
1 parent e89c333 commit 638bee7

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

_test_unstructured_client/unit/test_request_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import httpx
33
import pytest
44

5-
from unstructured_client._hooks.custom.request_utils import create_pdf_chunk_request_params, get_multipart_stream_fields
5+
from unstructured_client._hooks.custom.request_utils import (
6+
create_pdf_chunk_request_params,
7+
get_base_url,
8+
get_multipart_stream_fields,
9+
)
610
from unstructured_client.models import shared
711

812

@@ -70,3 +74,16 @@ def test_multipart_stream_fields_raises_value_error_when_filename_is_not_set():
7074
def test_create_pdf_chunk_request_params(input_form_data, page_number, expected_form_data):
7175
form_data = create_pdf_chunk_request_params(input_form_data, page_number)
7276
assert form_data == expected_form_data
77+
78+
79+
@pytest.mark.parametrize(
80+
("url", "expected_base_url"),
81+
[
82+
("https://api.unstructuredapp.io/general/v0/general", "https://api.unstructuredapp.io"),
83+
("https://api.unstructuredapp.io/general/v0/general?some_param=23", "https://api.unstructuredapp.io"),
84+
("http://localhost:3000/general/v0/general", "http://localhost:3000"),
85+
],
86+
)
87+
def test_get_base_url(url: str, expected_base_url: str):
88+
assert get_base_url(url) == expected_base_url
89+

src/unstructured_client/_hooks/custom/request_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import json
66
import logging
77
from typing import Tuple, Any, BinaryIO
8+
from urllib.parse import urlparse
89

910
import httpx
11+
from httpx import URL
1012
from httpx._multipart import DataField, FileField
1113

1214
from unstructured_client._hooks.custom.common import UNSTRUCTURED_CLIENT_LOGGER_NAME
@@ -224,3 +226,15 @@ def create_response(elements: list) -> httpx.Response:
224226
response.headers.update({"Content-Length": content_length})
225227
setattr(response, "_content", content)
226228
return response
229+
230+
def get_base_url(url: str | URL) -> str:
231+
"""Extracts the base URL from the given URL.
232+
233+
Args:
234+
url: The URL.
235+
236+
Returns:
237+
The base URL.
238+
"""
239+
parsed_url = urlparse(str(url))
240+
return f"{parsed_url.scheme}://{parsed_url.netloc}"

src/unstructured_client/_hooks/custom/split_pdf_hook.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
PARTITION_FORM_SPLIT_PDF_PAGE_KEY,
3232
PARTITION_FORM_STARTING_PAGE_NUMBER_KEY,
3333
)
34+
from unstructured_client._hooks.custom.request_utils import get_base_url
3435
from unstructured_client._hooks.types import (
3536
AfterErrorContext,
3637
AfterErrorHook,
@@ -156,7 +157,8 @@ class SplitPdfHook(SDKInitHook, BeforeRequestHook, AfterSuccessHook, AfterErrorH
156157

157158
def __init__(self) -> None:
158159
self.client: Optional[HttpClient] = None
159-
self.base_url: Optional[str] = None
160+
self.partition_base_url: Optional[str] = None
161+
self.is_partition_request: bool = False
160162
self.async_client: Optional[AsyncHttpClient] = None
161163
self.coroutines_to_execute: dict[
162164
str, list[partial[Coroutine[Any, Any, httpx.Response]]]
@@ -212,7 +214,9 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
212214
# return await self.base_transport.handle_async_request(request)
213215

214216
# Instead, save the base url so we can use it for our dummy request
215-
self.base_url = base_url
217+
# As this can be overwritten with Platform API URL, we need to get it again in
218+
# `before_request` hook from the request object as the real URL is not available here.
219+
self.partition_base_url = base_url
216220

217221
# Explicit cast to httpx.Client to avoid a typing error
218222
httpx_client = cast(httpx.Client, client)
@@ -246,6 +250,16 @@ def before_request(
246250
Union[httpx.PreparedRequest, Exception]: If `splitPdfPage` is set to `true`,
247251
the last page request; otherwise, the original request.
248252
"""
253+
254+
# Actually the general.partition operation overwrites the default client's base url (as
255+
# the platform operations do). Here we need to get the base url from the request object.
256+
if hook_ctx.operation_id == "partition":
257+
self.partition_base_url = get_base_url(request.url)
258+
self.is_partition_request = True
259+
else:
260+
self.is_partition_request = False
261+
return request
262+
249263
if self.client is None:
250264
logger.warning("HTTP client not accessible! Continuing without splitting.")
251265
return request
@@ -391,7 +405,7 @@ def before_request(
391405
# dummy_request = httpx.Request("GET", "http://no-op")
392406
return httpx.Request(
393407
"GET",
394-
f"{self.base_url}/general/docs",
408+
f"{self.partition_base_url}/general/docs",
395409
headers={"operation_id": operation_id},
396410
)
397411

@@ -644,6 +658,9 @@ def after_success(
644658
combined response object; otherwise, the original response. Can return
645659
exception if it ocurred during the execution.
646660
"""
661+
if not self.is_partition_request:
662+
return response
663+
647664
# Grab the correct id out of the dummy request
648665
operation_id = response.request.headers.get("operation_id")
649666

0 commit comments

Comments
 (0)