Skip to content

Commit 90e6e27

Browse files
committed
refactor: keep take_screenshot consistent
1 parent 711ccae commit 90e6e27

File tree

2 files changed

+85
-17
lines changed

2 files changed

+85
-17
lines changed

pydoll/browser/tab.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ async def refresh(
557557

558558
async def take_screenshot(
559559
self,
560-
path: Optional[str] = None,
560+
path: Optional[str | Path] = None,
561561
quality: int = 100,
562562
beyond_viewport: bool = False,
563563
as_base64: bool = False,
@@ -582,17 +582,28 @@ async def take_screenshot(
582582
if not path and not as_base64:
583583
raise MissingScreenshotPath()
584584

585-
output_extension = path.split('.')[-1] if path else ScreenshotFormat.PNG
585+
if path and isinstance(path, str):
586+
output_extension = path.split('.')[-1]
587+
elif path and isinstance(path, Path):
588+
output_extension = path.suffix.lstrip('.')
589+
else:
590+
output_extension = ScreenshotFormat.JPEG
591+
592+
# Normalize jpg to jpeg (CDP only accepts jpeg)
593+
output_extension = output_extension.replace('jpg', 'jpeg') if output_extension == 'jpg' else output_extension
594+
586595
if not ScreenshotFormat.has_value(output_extension):
587596
raise InvalidFileExtension(f'{output_extension} extension is not supported.')
588597

598+
output_format = ScreenshotFormat.get_value(output_extension)
599+
589600
logger.info(
590601
f'Taking screenshot: path={path}, quality={quality}, '
591602
f'beyond_viewport={beyond_viewport}, as_base64={as_base64}'
592603
)
593604
response: CaptureScreenshotResponse = await self._execute_command(
594605
PageCommands.capture_screenshot(
595-
format=ScreenshotFormat.get_value(output_extension),
606+
format=output_format,
596607
quality=quality,
597608
capture_beyond_viewport=beyond_viewport,
598609
)
@@ -612,15 +623,15 @@ async def take_screenshot(
612623

613624
if path:
614625
screenshot_bytes = decode_base64_to_bytes(screenshot_data)
615-
async with aiofiles.open(path, 'wb') as file:
626+
async with aiofiles.open(str(path), 'wb') as file:
616627
await file.write(screenshot_bytes)
617628
logger.info(f'Screenshot saved to: {path}')
618629

619630
return None
620631

621632
async def print_to_pdf(
622633
self,
623-
path: str,
634+
path: Optional[str | Path] = None,
624635
landscape: bool = False,
625636
display_header_footer: bool = False,
626637
print_background: bool = True,
@@ -631,7 +642,7 @@ async def print_to_pdf(
631642
Generate PDF of current page.
632643
633644
Args:
634-
path: File path for PDF output.
645+
path: File path for PDF output. Required if as_base64=False.
635646
landscape: Use landscape orientation.
636647
display_header_footer: Include header/footer.
637648
print_background: Include background graphics.
@@ -640,6 +651,9 @@ async def print_to_pdf(
640651
641652
Returns:
642653
Base64 PDF data if as_base64=True, None otherwise.
654+
655+
Raises:
656+
ValueError: If path is not provided when as_base64=False.
643657
"""
644658
logger.info(
645659
f'Generating PDF: path={path}, landscape={landscape}, '
@@ -659,6 +673,9 @@ async def print_to_pdf(
659673
logger.info('PDF generated and returned as base64')
660674
return pdf_data
661675

676+
if path is None:
677+
raise ValueError("path is required when as_base64=False")
678+
662679
pdf_bytes = decode_base64_to_bytes(pdf_data)
663680
async with aiofiles.open(path, 'wb') as file:
664681
await file.write(pdf_bytes)
@@ -826,7 +843,7 @@ async def continue_with_auth(
826843

827844
@asynccontextmanager
828845
async def expect_file_chooser(
829-
self, files: Union[str, Path, list[Union[str, Path]]]
846+
self, files: str | Path | list[str | Path]
830847
) -> AsyncGenerator[None, None]:
831848
"""
832849
Context manager for automatic file upload handling.

pydoll/elements/web_element.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import json
55
import logging
6+
from pathlib import Path
67
from typing import TYPE_CHECKING
78

89
import aiofiles
@@ -24,6 +25,8 @@
2425
ElementNotFound,
2526
ElementNotInteractable,
2627
ElementNotVisible,
28+
InvalidFileExtension,
29+
MissingScreenshotPath,
2730
WaitElementTimeout,
2831
)
2932
from pydoll.protocol.input.types import (
@@ -226,12 +229,48 @@ async def get_siblings_elements(
226229
logger.debug(f'Siblings found: {len(siblings)}')
227230
return siblings
228231

229-
async def take_screenshot(self, path: str, quality: int = 100):
232+
async def take_screenshot(
233+
self,
234+
path: Optional[str | Path] = None,
235+
quality: int = 100,
236+
as_base64: bool = False,
237+
) -> Optional[str]:
230238
"""
231239
Capture screenshot of this element only.
232240
233241
Automatically scrolls element into view before capturing.
242+
243+
Args:
244+
path: File path for screenshot (extension determines format).
245+
quality: Image quality 0-100 (default 100).
246+
as_base64: Return as base64 string instead of saving file.
247+
248+
Returns:
249+
Base64 screenshot data if as_base64=True, None otherwise.
250+
251+
Raises:
252+
InvalidFileExtension: If file extension not supported.
253+
MissingScreenshotPath: If path is None and as_base64 is False.
234254
"""
255+
if not path and not as_base64:
256+
raise MissingScreenshotPath()
257+
258+
if path and isinstance(path, str):
259+
output_extension = path.split('.')[-1]
260+
elif path and isinstance(path, Path):
261+
output_extension = path.suffix.lstrip('.')
262+
else:
263+
output_extension = ScreenshotFormat.JPEG
264+
265+
# Normalize jpg to jpeg (CDP only accepts jpeg)
266+
if output_extension == 'jpg':
267+
output_extension = 'jpeg'
268+
269+
if not ScreenshotFormat.has_value(output_extension):
270+
raise InvalidFileExtension(f'{output_extension} extension is not supported.')
271+
272+
file_format = ScreenshotFormat.get_value(output_extension)
273+
235274
bounds = await self.get_bounds_using_js()
236275
clip = Viewport(
237276
x=bounds['x'],
@@ -241,18 +280,29 @@ async def take_screenshot(self, path: str, quality: int = 100):
241280
scale=1,
242281
)
243282
logger.debug(
244-
f'Taking element screenshot: path={path}, quality={quality}, '
283+
f'Taking element screenshot: path={path}, quality={quality}, as_base64={as_base64}, '
245284
f'clip={{x: {clip["x"]}, y: {clip["y"]}, w: {clip["width"]}, h: {clip["height"]}}}'
246285
)
286+
247287
screenshot: CaptureScreenshotResponse = await self._connection_handler.execute_command(
248288
PageCommands.capture_screenshot(
249-
format=ScreenshotFormat.JPEG, clip=clip, quality=quality
289+
format=file_format, clip=clip, quality=quality
250290
)
251291
)
252-
async with aiofiles.open(path, 'wb') as file:
253-
image_bytes = decode_base64_to_bytes(screenshot['result']['data'])
254-
await file.write(image_bytes)
255-
logger.info(f'Element screenshot saved: {path}')
292+
293+
screenshot_data = screenshot['result']['data']
294+
295+
if as_base64:
296+
logger.info('Element screenshot captured and returned as base64')
297+
return screenshot_data
298+
299+
if path:
300+
image_bytes = decode_base64_to_bytes(screenshot_data)
301+
async with aiofiles.open(str(path), 'wb') as file:
302+
await file.write(image_bytes)
303+
logger.info(f'Element screenshot saved: {path}')
304+
305+
return None
256306

257307
def get_attribute(self, name: str) -> Optional[str]:
258308
"""
@@ -417,7 +467,7 @@ async def insert_text(self, text: str):
417467
logger.info(f'Inserting text on element (length={len(text)})')
418468
await self._execute_command(InputCommands.insert_text(text))
419469

420-
async def set_input_files(self, files: list[str]):
470+
async def set_input_files(self, files: str | Path | list[str | Path]):
421471
"""
422472
Set file paths for file input element.
423473
@@ -432,9 +482,10 @@ async def set_input_files(self, files: list[str]):
432482
or self._attributes.get('type', '').lower() != 'file'
433483
):
434484
raise ElementNotAFileInput()
435-
logger.info(f'Setting input files: count={len(files)}')
485+
files_list = [str(file) for file in files] if isinstance(files, list) else [str(files)]
486+
logger.info(f'Setting input files: count={len(files_list)}')
436487
await self._execute_command(
437-
DomCommands.set_file_input_files(files=files, object_id=self._object_id)
488+
DomCommands.set_file_input_files(files=files_list, object_id=self._object_id)
438489
)
439490

440491
async def type_text(self, text: str, interval: float = 0.1):

0 commit comments

Comments
 (0)