Skip to content

Commit 034d7da

Browse files
authored
Add tests for safe extraction methods in test_download.py (#794)
- Implemented tests to ensure _safe_extract prevents path traversal, absolute path, device file, unsafe symlink, and absolute symlink attacks. - Verified that normal files are extracted correctly without security risks. These tests enhance the security of the extraction process for tar files in the arxiv module. Signed-off-by: Abhinav Garg <abhinavg@stanford.edu>
1 parent 4ddb9bb commit 034d7da

File tree

1 file changed

+203
-0
lines changed

1 file changed

+203
-0
lines changed

tests/test_download.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,209 @@ def test_arxiv_extractor(self) -> None:
351351
assert "Introduction" in extracted_text
352352
assert "This is the introduction" in extracted_text
353353

354+
def test_safe_extract_path_traversal_prevention(self, tmp_path: Path) -> None:
    """Test that _safe_extract prevents path traversal attacks.

    The archive holds one harmless member followed by a member named
    ``../../../evil.txt``. The extraction directory is created three levels
    below ``tmp_path`` so that the three ``..`` components resolve to
    ``tmp_path / "evil.txt"``: the escape target is then a path this test
    owns and can check, and a failing guard cannot write outside the
    per-test temporary directory.
    """
    import io

    from nemo_curator.download.arxiv import _safe_extract

    # Create a malicious tar file that tries to write outside the extraction directory
    malicious_tar_path = tmp_path / "malicious.tar"
    malicious_path = "../../../evil.txt"  # Path traversal attempt

    with tarfile.open(malicious_tar_path, "w") as tar:
        # Add a normal file first
        normal_data = io.BytesIO(b"normal content\n")
        normal_tarinfo = tarfile.TarInfo(name="normal.txt")
        normal_tarinfo.size = len(normal_data.getbuffer())
        tar.addfile(normal_tarinfo, fileobj=normal_data)

        # Add a malicious file that tries to escape the extraction directory
        malicious_data = io.BytesIO(b"malicious content\n")
        malicious_tarinfo = tarfile.TarInfo(name=malicious_path)
        malicious_tarinfo.size = len(malicious_data.getbuffer())
        tar.addfile(malicious_tarinfo, fileobj=malicious_data)

    # Nest the extraction directory so the traversal target stays inside tmp_path
    extraction_dir = tmp_path / "outer" / "inner" / "extraction"
    extraction_dir.mkdir(parents=True)

    # Where an unguarded extraction of the malicious member would actually land
    evil_file_path = (extraction_dir / malicious_path).resolve()

    # Test that _safe_extract raises ValueError for path traversal
    with (
        tarfile.open(malicious_tar_path, "r") as tar,
        pytest.raises(ValueError, match="Path traversal attempt detected"),
    ):
        _safe_extract(tar, str(extraction_dir))

    # Verify that the malicious file was not created outside the extraction directory
    assert not evil_file_path.exists(), "Malicious file should not have been created outside extraction directory"

    # Verify that whatever was extracted really lives under the extraction directory,
    # comparing resolved paths so a symlink pointing elsewhere is caught as well
    extraction_root = extraction_dir.resolve()
    for file_path in extraction_dir.rglob("*"):
        assert file_path.resolve().is_relative_to(extraction_root), (
            f"File {file_path} was extracted outside safe directory"
        )
399+
400+
def test_safe_extract_absolute_path_prevention(self, tmp_path: Path) -> None:
    """Check that _safe_extract rejects an archive member with an absolute name."""
    import io

    from nemo_curator.download.arxiv import _safe_extract

    # The member name is an absolute path (kept inside tmp_path)
    member_name = str(tmp_path / "absolute_evil.txt")
    payload = io.BytesIO(b"absolute path content\n")

    # Write an archive whose only member carries that absolute name
    archive_path = tmp_path / "absolute_path.tar"
    with tarfile.open(archive_path, "w") as archive:
        member = tarfile.TarInfo(name=member_name)
        member.size = len(payload.getbuffer())
        archive.addfile(member, fileobj=payload)

    # Destination the archive is supposed to be confined to
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # Extraction must be refused with the absolute-path error
    with (
        tarfile.open(archive_path, "r") as archive,
        pytest.raises(ValueError, match="Absolute path not allowed"),
    ):
        _safe_extract(archive, str(target_dir))
427+
428+
def test_safe_extract_normal_files(self, tmp_path: Path) -> None:
    """Check that _safe_extract unpacks a benign archive, including a subdirectory."""
    import io

    from nemo_curator.download.arxiv import _safe_extract

    # Three top-level files followed by one file inside a subdirectory
    members = [(f"file_{i}.txt", f"content of file {i}\n".encode()) for i in range(3)]
    members.append(("subdir/subfile.txt", b"subdirectory content\n"))

    archive_path = tmp_path / "normal.tar"
    with tarfile.open(archive_path, "w") as archive:
        for member_name, payload in members:
            stream = io.BytesIO(payload)
            info = tarfile.TarInfo(name=member_name)
            info.size = len(stream.getbuffer())
            archive.addfile(info, fileobj=stream)

    # Destination directory for the extraction
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # A benign archive must extract without raising
    with tarfile.open(archive_path, "r") as archive:
        _safe_extract(archive, str(target_dir))

    # Every member must be present on disk
    for relative_parts in (("file_0.txt",), ("file_1.txt",), ("file_2.txt",), ("subdir", "subfile.txt")):
        assert target_dir.joinpath(*relative_parts).exists()

    # Spot-check the content of one top-level file and the nested file
    assert (target_dir / "file_0.txt").read_text() == "content of file 0\n"
    assert (target_dir / "subdir" / "subfile.txt").read_text() == "subdirectory content\n"
470+
471+
def test_safe_extract_device_file_prevention(self, tmp_path: Path) -> None:
    """Check that _safe_extract refuses an archive containing a device node."""

    from nemo_curator.download.arxiv import _safe_extract

    # Describe a character-device member (major 1, minor 3)
    device_member = tarfile.TarInfo(name="evil_device")
    device_member.type = tarfile.CHRTYPE
    device_member.devmajor = 1
    device_member.devminor = 3

    # Write an archive holding only that device member
    archive_path = tmp_path / "device_file.tar"
    with tarfile.open(archive_path, "w") as archive:
        archive.addfile(device_member)

    # Destination directory for the extraction
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # Extraction must be refused with the device-file error
    with (
        tarfile.open(archive_path, "r") as archive,
        pytest.raises(ValueError, match="Device files not allowed"),
    ):
        _safe_extract(archive, str(target_dir))
497+
498+
def test_safe_extract_symlink_prevention(self, tmp_path: Path) -> None:
    """Check that _safe_extract rejects a symlink whose relative target escapes the destination."""
    import io

    from nemo_curator.download.arxiv import _safe_extract

    archive_path = tmp_path / "symlink_attack.tar"
    with tarfile.open(archive_path, "w") as archive:
        # A harmless regular file goes in first
        benign_stream = io.BytesIO(b"normal content\n")
        benign_member = tarfile.TarInfo(name="normal.txt")
        benign_member.size = len(benign_stream.getbuffer())
        archive.addfile(benign_member, fileobj=benign_stream)

        # Then a symlink whose target climbs out of the extraction directory
        link_member = tarfile.TarInfo(name="evil_symlink")
        link_member.type = tarfile.SYMTYPE
        link_member.linkname = "../../../etc/passwd"
        archive.addfile(link_member)

    # Destination directory for the extraction
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # Extraction must be refused with the symlink-escape error
    with (
        tarfile.open(archive_path, "r") as archive,
        pytest.raises(ValueError, match="Symlink target outside extraction directory"),
    ):
        _safe_extract(archive, str(target_dir))
530+
531+
def test_safe_extract_absolute_symlink_prevention(self, tmp_path: Path) -> None:
    """Check that _safe_extract rejects a symlink member that points at an absolute path."""

    from nemo_curator.download.arxiv import _safe_extract

    # Describe a symlink member whose target is absolute
    link_member = tarfile.TarInfo(name="absolute_symlink")
    link_member.type = tarfile.SYMTYPE
    link_member.linkname = "/etc/passwd"

    # Write an archive holding only that symlink
    archive_path = tmp_path / "absolute_symlink.tar"
    with tarfile.open(archive_path, "w") as archive:
        archive.addfile(link_member)

    # Destination directory for the extraction
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # Extraction must be refused with the absolute-symlink error
    with (
        tarfile.open(archive_path, "r") as archive,
        pytest.raises(ValueError, match="Absolute symlink target not allowed"),
    ):
        _safe_extract(archive, str(target_dir))
556+
354557

355558
class TestCommonCrawl:
356559
def test_common_crawl_downloader_existing_file(self, tmp_path: Path, monkeypatch: MonkeyPatch) -> None:

0 commit comments

Comments
 (0)