Skip to content

Commit 519aee5

Browse files
authored
[Security] fix _uncompress_file_tar security problem (PaddlePaddle#76601)
* [Security] fix _uncompress_file_tar security problem * [Security] fix _uncompress_file_tar security problem * [Security] fix _uncompress_file_tar security problem
1 parent 6f80d5c commit 519aee5

File tree

1 file changed

+80
-12
lines changed

1 file changed

+80
-12
lines changed

python/paddle/utils/download.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import tarfile
2323
import time
2424
import zipfile
25+
from pathlib import Path
2526
from typing import Literal
2627

2728
import httpx
@@ -313,35 +314,102 @@ def _uncompress_file_zip(filepath):
313314
return uncompressed_path
314315

315316

316-
def _uncompress_file_tar(filepath, mode="r:*"):
317-
with tarfile.open(filepath, mode) as files:
318-
file_list_tmp = files.getnames()
319-
file_list = []
320-
for file in file_list_tmp:
321-
assert file[0] != "/", (
322-
f"uncompress file path {file} should not start with /"
317+
def _is_within_directory(directory, target):
318+
"""Check if the target path is within the given directory."""
319+
abs_directory = Path(directory).resolve()
320+
abs_target = Path(target).resolve()
321+
try:
322+
abs_target.relative_to(abs_directory)
323+
return True
324+
except ValueError:
325+
return False
326+
327+
328+
def _validate_tar_member_name(name):
329+
"""
330+
Validate tar member name for security.
331+
332+
Raises ValueError if the name contains unsafe patterns:
333+
- Absolute paths (Unix: /path, Windows: C:\\path, UNC: \\\\server\\share)
334+
- Path traversal components ('..')
335+
"""
336+
# Check for absolute paths (cross-platform)
337+
name_path = Path(name)
338+
if name_path.is_absolute():
339+
raise ValueError(f"Unsafe absolute path in tar: {name}")
340+
341+
# Check for path traversal components '..'
342+
if '..' in name_path.parts:
343+
raise ValueError(f"Unsafe path traversal '..' in tar: {name}")
344+
345+
346+
def _safe_extract(tar, path, members=None):
347+
"""
348+
Safely extract tar files to prevent path traversal attacks.
349+
350+
Security measures:
351+
1. Verify resolved paths are within target directory
352+
2. Skip symlinks, hardlinks and other special files
353+
3. Only extract regular files and directories
354+
"""
355+
members_to_check = members if members is not None else tar.getmembers()
356+
extract_members = []
357+
358+
for member in members_to_check:
359+
# Compute the target path and verify it's within the destination
360+
member_path = Path(path) / member.name
361+
if not _is_within_directory(path, member_path):
362+
raise ValueError(
363+
f"Attempted path traversal in tar file: {member.name}"
323364
)
324-
file_list.append(file.replace("../", ""))
325365

366+
# Skip symlinks, hardlinks, and other special files to prevent symlink attacks
367+
if member.issym():
368+
logger.warning(
369+
f"Skipping symbolic link in tar for security: {member.name}"
370+
)
371+
continue
372+
elif member.islnk():
373+
logger.warning(
374+
f"Skipping hard link in tar for security: {member.name}"
375+
)
376+
continue
377+
elif not (member.isfile() or member.isdir()):
378+
logger.warning(
379+
f"Skipping special file in tar for security: {member.name}"
380+
)
381+
continue
382+
383+
extract_members.append(member)
384+
385+
tar.extractall(path, members=extract_members)
386+
387+
388+
def _uncompress_file_tar(filepath, mode="r:*"):
389+
with tarfile.open(filepath, mode) as files:
390+
file_list = files.getnames()
326391
file_dir = os.path.dirname(filepath)
327392

393+
# Validate all member names before extraction
394+
for name in file_list:
395+
_validate_tar_member_name(name)
396+
328397
if _is_a_single_file(file_list):
329398
rootpath = file_list[0]
330399
uncompressed_path = os.path.join(file_dir, rootpath)
331-
files.extractall(file_dir)
400+
_safe_extract(files, file_dir)
332401
elif _is_a_single_dir(file_list):
333402
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
334403
os.sep
335404
)[-1]
336405
uncompressed_path = os.path.join(file_dir, rootpath)
337-
files.extractall(file_dir)
406+
_safe_extract(files, file_dir)
338407
else:
339408
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
340409
uncompressed_path = os.path.join(file_dir, rootpath)
341410
if not os.path.exists(uncompressed_path):
342411
os.makedirs(uncompressed_path)
343-
344-
files.extractall(os.path.join(file_dir, rootpath))
412+
_safe_extract(files, os.path.join(file_dir, rootpath))
345413

346414
return uncompressed_path
347415

0 commit comments

Comments
 (0)