Skip to content

Commit e166d37

Browse files
authored
Merge pull request #18 from HowieG/fix-retrieve
2 parents 30239d3 + bdad989 commit e166d37

File tree

9 files changed

+267
-247
lines changed

9 files changed

+267
-247
lines changed

poetry.lock

Lines changed: 11 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "multion"
3-
version = "1.2.0"
3+
version = "1.3.0"
44
description = ""
55
readme = "README.md"
66
authors = []
@@ -10,7 +10,7 @@ packages = [
1010

1111
[tool.poetry.dependencies]
1212
python = "^3.8"
13-
agentops = "^0.2.2"
13+
agentops = "^0.2.3"
1414
httpx = ">=0.21.2"
1515
httpx-sse = "0.4.0"
1616
pydantic = ">= 1.9.2"

src/multion/base_client.py

Lines changed: 237 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .types.internal_server_error_response import InternalServerErrorResponse
2727
from .types.mode import Mode
2828
from .types.payment_required_response import PaymentRequiredResponse
29+
from .types.retrieve_output import RetrieveOutput
2930
from .types.unauthorized_response import UnauthorizedResponse
3031

3132
# this is used as the default value for optional parameters
@@ -68,7 +69,7 @@ def __init__(
6869
follow_redirects: typing.Optional[bool] = True,
6970
httpx_client: typing.Optional[httpx.Client] = None,
7071
):
71-
_defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
72+
_defaulted_timeout = timeout if timeout is not None else 180 if httpx_client is None else None
7273
if api_key is None:
7374
raise ApiError(
7475
body="The client must be instantiated be either passing in api_key or setting MULTION_API_KEY"
@@ -131,7 +132,7 @@ def browse(
131132
api_key="YOUR_API_KEY",
132133
)
133134
client.browse(
134-
cmd="find the top post on hackernews",
135+
cmd="Find the top post on Hackernews.",
135136
url="https://news.ycombinator.com/",
136137
)
137138
"""
@@ -206,6 +207,122 @@ def browse(
206207
raise ApiError(status_code=_response.status_code, body=_response.text)
207208
raise ApiError(status_code=_response.status_code, body=_response_json)
208209

210+
def retrieve(
    self,
    *,
    cmd: str,
    url: typing.Optional[str] = OMIT,
    session_id: typing.Optional[str] = OMIT,
    local: typing.Optional[bool] = OMIT,
    fields: typing.Optional[typing.Sequence[str]] = OMIT,
    format: typing.Optional[typing.Literal["json"]] = OMIT,
    max_items: typing.Optional[float] = OMIT,
    full_page: typing.Optional[bool] = OMIT,
    render_js: typing.Optional[bool] = OMIT,
    scroll_to_bottom: typing.Optional[bool] = OMIT,
    include_screenshot: typing.Optional[bool] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> RetrieveOutput:
    """
    Retrieve data from a webpage, guided by a URL and a natural-language command
    that steers the agent's data-extraction process.

    The call can either create a new session or run as part of an existing one.

    Parameters:
        - cmd: str. A specific natural language instruction on the data the agent should extract.

        - url: typing.Optional[str]. The URL to create or continue the session from.

        - session_id: typing.Optional[str]. Continues the session with session_id if provided.

        - local: typing.Optional[bool]. Whether the session runs locally or in the cloud (Default: False). If true, the session runs locally via your Chrome extension; if false, it runs in the cloud.

        - fields: typing.Optional[typing.Sequence[str]]. List of fields (columns) to be output in the data.

        - format: typing.Optional[typing.Literal["json"]]. Format of the response data. (Default: json)

        - max_items: typing.Optional[float]. Maximum number of data items to retrieve. (Default: 100)

        - full_page: typing.Optional[bool]. Whether to retrieve the full page (Default: True). If false, data is only retrieved from the current session viewport.

        - render_js: typing.Optional[bool]. Whether to include rich JS and ARIA elements in the retrieved data. (Default: False)

        - scroll_to_bottom: typing.Optional[bool]. Whether to scroll to the bottom of the page (Default: False). If true, the page is scrolled to the bottom for a maximum of 5 seconds before data is retrieved.

        - include_screenshot: typing.Optional[bool]. Whether to include a screenshot with the response. (Default: False)

        - request_options: typing.Optional[RequestOptions]. Request-specific configuration.
    ---
    from multion.client import MultiOn

    client = MultiOn(
        api_key="YOUR_API_KEY",
    )
    client.retrieve(
        cmd="Find the top post on Hackernews and get its title and points.",
        url="https://news.ycombinator.com/",
        fields=["title", "points"],
    )
    """
    # Only arguments the caller actually supplied go into the body; OMIT is
    # the "not provided" sentinel, so an explicit None is still forwarded.
    supplied_candidates = (
        ("url", url),
        ("session_id", session_id),
        ("local", local),
        ("fields", fields),
        ("format", format),
        ("max_items", max_items),
        ("full_page", full_page),
        ("render_js", render_js),
        ("scroll_to_bottom", scroll_to_bottom),
        ("include_screenshot", include_screenshot),
    )
    _request: typing.Dict[str, typing.Any] = {"cmd": cmd}
    for body_key, body_value in supplied_candidates:
        if body_value is not OMIT:
            _request[body_key] = body_value

    endpoint = urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "retrieve")

    extra_query = None if request_options is None else request_options.get("additional_query_parameters")
    query_params = jsonable_encoder(extra_query)

    # Per-request body parameters, when present, are layered over the
    # encoded request (None-valued extras are dropped first).
    payload = jsonable_encoder(_request)
    if request_options is not None and request_options.get("additional_body_parameters") is not None:
        extra_body = request_options.get("additional_body_parameters", {})
        payload = {**payload, **(jsonable_encoder(remove_none_from_dict(extra_body)))}

    # Client-level headers first, then per-request headers on top.
    merged_headers = {
        **self._client_wrapper.get_headers(),
        **(request_options.get("additional_headers", {}) if request_options is not None else {}),
    }
    encoded_headers = jsonable_encoder(remove_none_from_dict(merged_headers))

    # A per-request timeout wins; otherwise fall back to the client default.
    timeout = None if request_options is None else request_options.get("timeout_in_seconds")
    if timeout is None:
        timeout = self._client_wrapper.get_timeout()

    retry_limit = 0 if request_options is None else request_options.get("max_retries")

    _response = self._client_wrapper.httpx_client.request(
        method="POST",
        url=endpoint,
        params=query_params,
        json=payload,
        headers=encoded_headers,
        timeout=timeout,
        retries=0,
        max_retries=retry_limit,  # type: ignore
    )

    status = _response.status_code
    if 200 <= status < 300:
        parsed = construct_type(type_=RetrieveOutput, object_=_response.json())  # type: ignore
        return typing.cast(RetrieveOutput, parsed)
    if status == 422:
        validation_error = construct_type(type_=HttpValidationError, object_=_response.json())  # type: ignore
        raise UnprocessableEntityError(typing.cast(HttpValidationError, validation_error))
    # Any other status is surfaced as ApiError, carrying the JSON body when
    # it parses and the raw text otherwise.
    try:
        error_body = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=status, body=_response.text)
    raise ApiError(status_code=status, body=error_body)
325+
209326

210327
class AsyncBaseMultiOn:
211328
"""
@@ -243,7 +360,7 @@ def __init__(
243360
follow_redirects: typing.Optional[bool] = True,
244361
httpx_client: typing.Optional[httpx.AsyncClient] = None,
245362
):
246-
_defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
363+
_defaulted_timeout = timeout if timeout is not None else 180 if httpx_client is None else None
247364
if api_key is None:
248365
raise ApiError(
249366
body="The client must be instantiated be either passing in api_key or setting MULTION_API_KEY"
@@ -306,7 +423,7 @@ async def browse(
306423
api_key="YOUR_API_KEY",
307424
)
308425
await client.browse(
309-
cmd="find the top post on hackernews",
426+
cmd="Find the top post on Hackernews.",
310427
url="https://news.ycombinator.com/",
311428
)
312429
"""
@@ -381,6 +498,122 @@ async def browse(
381498
raise ApiError(status_code=_response.status_code, body=_response.text)
382499
raise ApiError(status_code=_response.status_code, body=_response_json)
383500

501+
async def retrieve(
    self,
    *,
    cmd: str,
    url: typing.Optional[str] = OMIT,
    session_id: typing.Optional[str] = OMIT,
    local: typing.Optional[bool] = OMIT,
    fields: typing.Optional[typing.Sequence[str]] = OMIT,
    format: typing.Optional[typing.Literal["json"]] = OMIT,
    max_items: typing.Optional[float] = OMIT,
    full_page: typing.Optional[bool] = OMIT,
    render_js: typing.Optional[bool] = OMIT,
    scroll_to_bottom: typing.Optional[bool] = OMIT,
    include_screenshot: typing.Optional[bool] = OMIT,
    request_options: typing.Optional[RequestOptions] = None,
) -> RetrieveOutput:
    """
    Retrieve data from a webpage, guided by a URL and a natural-language command
    that steers the agent's data-extraction process.

    The call can either create a new session or run as part of an existing one.

    Parameters:
        - cmd: str. A specific natural language instruction on the data the agent should extract.

        - url: typing.Optional[str]. The URL to create or continue the session from.

        - session_id: typing.Optional[str]. Continues the session with session_id if provided.

        - local: typing.Optional[bool]. Whether the session runs locally or in the cloud (Default: False). If true, the session runs locally via your Chrome extension; if false, it runs in the cloud.

        - fields: typing.Optional[typing.Sequence[str]]. List of fields (columns) to be output in the data.

        - format: typing.Optional[typing.Literal["json"]]. Format of the response data. (Default: json)

        - max_items: typing.Optional[float]. Maximum number of data items to retrieve. (Default: 100)

        - full_page: typing.Optional[bool]. Whether to retrieve the full page (Default: True). If false, data is only retrieved from the current session viewport.

        - render_js: typing.Optional[bool]. Whether to include rich JS and ARIA elements in the retrieved data. (Default: False)

        - scroll_to_bottom: typing.Optional[bool]. Whether to scroll to the bottom of the page (Default: False). If true, the page is scrolled to the bottom for a maximum of 5 seconds before data is retrieved.

        - include_screenshot: typing.Optional[bool]. Whether to include a screenshot with the response. (Default: False)

        - request_options: typing.Optional[RequestOptions]. Request-specific configuration.
    ---
    from multion.client import AsyncMultiOn

    client = AsyncMultiOn(
        api_key="YOUR_API_KEY",
    )
    await client.retrieve(
        cmd="Find the top post on Hackernews and get its title and points.",
        url="https://news.ycombinator.com/",
        fields=["title", "points"],
    )
    """
    # Only arguments the caller actually supplied go into the body; OMIT is
    # the "not provided" sentinel, so an explicit None is still forwarded.
    supplied_candidates = (
        ("url", url),
        ("session_id", session_id),
        ("local", local),
        ("fields", fields),
        ("format", format),
        ("max_items", max_items),
        ("full_page", full_page),
        ("render_js", render_js),
        ("scroll_to_bottom", scroll_to_bottom),
        ("include_screenshot", include_screenshot),
    )
    _request: typing.Dict[str, typing.Any] = {"cmd": cmd}
    for body_key, body_value in supplied_candidates:
        if body_value is not OMIT:
            _request[body_key] = body_value

    endpoint = urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "retrieve")

    extra_query = None if request_options is None else request_options.get("additional_query_parameters")
    query_params = jsonable_encoder(extra_query)

    # Per-request body parameters, when present, are layered over the
    # encoded request (None-valued extras are dropped first).
    payload = jsonable_encoder(_request)
    if request_options is not None and request_options.get("additional_body_parameters") is not None:
        extra_body = request_options.get("additional_body_parameters", {})
        payload = {**payload, **(jsonable_encoder(remove_none_from_dict(extra_body)))}

    # Client-level headers first, then per-request headers on top.
    merged_headers = {
        **self._client_wrapper.get_headers(),
        **(request_options.get("additional_headers", {}) if request_options is not None else {}),
    }
    encoded_headers = jsonable_encoder(remove_none_from_dict(merged_headers))

    # A per-request timeout wins; otherwise fall back to the client default.
    timeout = None if request_options is None else request_options.get("timeout_in_seconds")
    if timeout is None:
        timeout = self._client_wrapper.get_timeout()

    retry_limit = 0 if request_options is None else request_options.get("max_retries")

    _response = await self._client_wrapper.httpx_client.request(
        method="POST",
        url=endpoint,
        params=query_params,
        json=payload,
        headers=encoded_headers,
        timeout=timeout,
        retries=0,
        max_retries=retry_limit,  # type: ignore
    )

    status = _response.status_code
    if 200 <= status < 300:
        parsed = construct_type(type_=RetrieveOutput, object_=_response.json())  # type: ignore
        return typing.cast(RetrieveOutput, parsed)
    if status == 422:
        validation_error = construct_type(type_=HttpValidationError, object_=_response.json())  # type: ignore
        raise UnprocessableEntityError(typing.cast(HttpValidationError, validation_error))
    # Any other status is surfaced as ApiError, carrying the JSON body when
    # it parses and the raw text otherwise.
    try:
        error_body = _response.json()
    except JSONDecodeError:
        raise ApiError(status_code=status, body=_response.text)
    raise ApiError(status_code=status, body=error_body)
616+
384617

385618
def _get_base_url(*, base_url: typing.Optional[str] = None, environment: MultiOnEnvironment) -> str:
386619
if base_url is not None:

src/multion/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010

1111
from .wrappers import wraps_function
12+
from .types.retrieve_output import RetrieveOutput
1213

1314

1415
# this is used as the default value for optional parameters
@@ -66,6 +67,12 @@ def browse(self, *args, **kwargs):
6667
agentops.start_session(tags=["multion-sdk"])
6768
return super().browse(*args, **kwargs)
6869

70+
@agentops.record_function(event_name="retrieve")  # type: ignore
@wraps_function(BaseMultiOn.retrieve)
def retrieve(self, *args, **kwargs) -> RetrieveOutput:
    # Start an AgentOps session tagged for this SDK, then pass every argument
    # straight through to the base client's retrieve implementation.
    agentops.start_session(tags=["multion-sdk"])
    output = super().retrieve(*args, **kwargs)
    return output
75+
6976

7077
class AsyncMultiOn(AsyncBaseMultiOn):
7178
"""
@@ -117,3 +124,9 @@ def __init__(
117124
async def browse(self, *args, **kwargs):
118125
agentops.start_session(tags=["multion-sdk"])
119126
return super().browse(*args, **kwargs)
127+
128+
@agentops.record_function(event_name="retrieve")  # type: ignore
@wraps_function(AsyncBaseMultiOn.retrieve)
async def retrieve(self, *args, **kwargs) -> RetrieveOutput:
    # Async counterpart of the sync client's retrieve wrapper: start an
    # AgentOps session tagged for this SDK, then delegate to the async base
    # client with all arguments passed through unchanged.
    #
    # The decorator references AsyncBaseMultiOn.retrieve (not the sync
    # BaseMultiOn.retrieve) so the wrapper mirrors the coroutine it actually
    # delegates to.
    agentops.start_session(tags=["multion-sdk"])
    # The base method is a coroutine function, so it must be awaited here.
    # Returning it un-awaited would hand callers a coroutine object instead
    # of the RetrieveOutput this method is annotated to return.
    return await super().retrieve(*args, **kwargs)

0 commit comments

Comments
 (0)