@@ -351,6 +351,209 @@ def test_arxiv_extractor(self) -> None:
351351 assert "Introduction" in extracted_text
352352 assert "This is the introduction" in extracted_text
353353
def test_safe_extract_path_traversal_prevention(self, tmp_path: Path) -> None:
    """Test that _safe_extract prevents path traversal attacks.

    Builds an archive with one harmless member and one member named
    ``../../../evil.txt``, then checks that ``_safe_extract`` raises
    ``ValueError`` and that nothing was written at the location the
    traversal would have reached.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.
    """
    import io

    from nemo_curator.download.arxiv import _safe_extract

    # Create a malicious tar file that tries to write outside the extraction directory
    malicious_tar_path = tmp_path / "malicious.tar"
    malicious_path = "../../../evil.txt"  # Path traversal attempt

    with tarfile.open(malicious_tar_path, "w") as tar:
        # Add a normal file first
        normal_data = io.BytesIO(b"normal content\n ")
        normal_tarinfo = tarfile.TarInfo(name="normal.txt")
        normal_tarinfo.size = len(normal_data.getbuffer())
        tar.addfile(normal_tarinfo, fileobj=normal_data)

        # Add a malicious file that tries to escape the extraction directory
        malicious_data = io.BytesIO(b"malicious content\n ")
        malicious_tarinfo = tarfile.TarInfo(name=malicious_path)
        malicious_tarinfo.size = len(malicious_data.getbuffer())
        tar.addfile(malicious_tarinfo, fileobj=malicious_data)

    # Nest the extraction directory three levels below tmp_path so that the
    # three ".." components of malicious_path land exactly on tmp_path. This
    # keeps the would-be escape target inside the pytest sandbox and makes the
    # existence check below look at the place the attack would really write to.
    extraction_dir = tmp_path / "level_1" / "level_2" / "extraction"
    extraction_dir.mkdir(parents=True)

    # Test that _safe_extract raises ValueError for path traversal
    with (
        tarfile.open(malicious_tar_path, "r") as tar,
        pytest.raises(ValueError, match="Path traversal attempt detected"),
    ):
        _safe_extract(tar, str(extraction_dir))

    # Verify that the malicious file was not created outside the extraction directory
    evil_file_path = tmp_path / "evil.txt"
    assert not evil_file_path.exists(), "Malicious file should not have been created outside extraction directory"

    # Verify that nothing inside the extraction directory points outside of it.
    # Paths are resolved first so that a symlink leading elsewhere is caught;
    # comparing unresolved rglob() results against their own root is always true.
    extraction_root = extraction_dir.resolve()
    for file_path in extraction_dir.rglob("*"):
        assert file_path.resolve().is_relative_to(extraction_root), (
            f"File {file_path} was extracted outside safe directory"
        )
399+
def test_safe_extract_absolute_path_prevention(self, tmp_path: Path) -> None:
    """Test that _safe_extract prevents absolute path attacks.

    Builds an archive whose only member has an absolute name, then checks
    that ``_safe_extract`` raises ``ValueError`` and that no file appears at
    that absolute location.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.
    """
    import io

    from nemo_curator.download.arxiv import _safe_extract

    # Create a malicious tar file with absolute path
    malicious_tar_path = tmp_path / "absolute_path.tar"

    # The absolute target lives inside tmp_path so a failing implementation
    # cannot write anywhere outside the pytest sandbox.
    absolute_target = tmp_path / "absolute_evil.txt"

    with tarfile.open(malicious_tar_path, "w") as tar:
        # Add a file with absolute path
        malicious_data = io.BytesIO(b"absolute path content\n ")
        malicious_tarinfo = tarfile.TarInfo(name=str(absolute_target))
        malicious_tarinfo.size = len(malicious_data.getbuffer())
        tar.addfile(malicious_tarinfo, fileobj=malicious_data)

    # Create extraction directory
    extraction_dir = tmp_path / "extraction"
    extraction_dir.mkdir()

    # Test that _safe_extract raises ValueError for absolute path
    with (
        tarfile.open(malicious_tar_path, "r") as tar,
        pytest.raises(ValueError, match="Absolute path not allowed"),
    ):
        _safe_extract(tar, str(extraction_dir))

    # Raising is not enough on its own: make sure the member was not written
    # to its absolute location before the error was reported.
    assert not absolute_target.exists(), "Absolute-path member should not have been written"
427+
def test_safe_extract_normal_files(self, tmp_path: Path) -> None:
    """Test that _safe_extract works correctly with normal files.

    Packs three top-level files and one file in a subdirectory into an
    archive, extracts it with ``_safe_extract`` and checks that every member
    exists in the extraction directory with exactly the bytes that were
    archived.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.
    """
    import io

    from nemo_curator.download.arxiv import _safe_extract

    # Single source of truth for member names and contents. The archive is
    # built from this table and the extraction is checked against it, so the
    # names/contents written and the names/contents asserted cannot drift apart.
    expected_contents = {f"file_{i}.txt": f"content of file {i}\n " for i in range(3)}
    expected_contents["subdir/subfile.txt"] = "subdirectory content\n "

    # Create a normal tar file
    normal_tar_path = tmp_path / "normal.tar"

    with tarfile.open(normal_tar_path, "w") as tar:
        for member_name, text in expected_contents.items():
            payload = text.encode()
            tarinfo = tarfile.TarInfo(name=member_name)
            tarinfo.size = len(payload)
            tar.addfile(tarinfo, fileobj=io.BytesIO(payload))

    # Create extraction directory
    extraction_dir = tmp_path / "extraction"
    extraction_dir.mkdir()

    # Test that _safe_extract works correctly with normal files
    with tarfile.open(normal_tar_path, "r") as tar:
        _safe_extract(tar, str(extraction_dir))

    # Verify every member was extracted with the archived content. Bytes are
    # compared so the check does not depend on the platform's default text
    # encoding or newline translation.
    for member_name, text in expected_contents.items():
        extracted_path = extraction_dir / member_name
        assert extracted_path.exists()
        assert extracted_path.read_bytes() == text.encode()
470+
def test_safe_extract_device_file_prevention(self, tmp_path: Path) -> None:
    """Check that _safe_extract refuses an archive containing a device file.

    The archive holds a single character-device member; extraction is
    expected to raise ``ValueError`` matching "Device files not allowed".

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.
    """

    from nemo_curator.download.arxiv import _safe_extract

    archive_path = tmp_path / "device_file.tar"

    # Describe the character-device member up front, then write it out.
    char_device = tarfile.TarInfo(name="evil_device")
    char_device.type = tarfile.CHRTYPE
    char_device.devmajor, char_device.devminor = 1, 3

    with tarfile.open(archive_path, "w") as archive:
        archive.addfile(char_device)

    # Destination the archive would be unpacked into
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # Extraction must be rejected because of the device member
    with (
        tarfile.open(archive_path, "r") as archive,
        pytest.raises(ValueError, match="Device files not allowed"),
    ):
        _safe_extract(archive, str(target_dir))
497+
def test_safe_extract_symlink_prevention(self, tmp_path: Path) -> None:
    """Check that _safe_extract refuses a symlink whose target escapes the destination.

    The archive holds one regular file followed by a symlink member whose
    link target is ``../../../etc/passwd``; extraction is expected to raise
    ``ValueError`` matching "Symlink target outside extraction directory".

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.
    """
    import io

    from nemo_curator.download.arxiv import _safe_extract

    archive_path = tmp_path / "symlink_attack.tar"

    # Regular member that precedes the hostile one
    benign_bytes = b"normal content\n "
    benign_member = tarfile.TarInfo(name="normal.txt")
    benign_member.size = len(benign_bytes)

    # Symlink member whose relative target climbs out of the destination
    escaping_link = tarfile.TarInfo(name="evil_symlink")
    escaping_link.type = tarfile.SYMTYPE
    escaping_link.linkname = "../../../etc/passwd"

    with tarfile.open(archive_path, "w") as archive:
        archive.addfile(benign_member, fileobj=io.BytesIO(benign_bytes))
        archive.addfile(escaping_link)

    # Destination the archive would be unpacked into
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # Extraction must be rejected because of the escaping symlink
    with (
        tarfile.open(archive_path, "r") as archive,
        pytest.raises(ValueError, match="Symlink target outside extraction directory"),
    ):
        _safe_extract(archive, str(target_dir))
530+
def test_safe_extract_absolute_symlink_prevention(self, tmp_path: Path) -> None:
    """Check that _safe_extract refuses a symlink with an absolute target.

    The archive holds a single symlink member pointing at ``/etc/passwd``;
    extraction is expected to raise ``ValueError`` matching
    "Absolute symlink target not allowed".

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.
    """

    from nemo_curator.download.arxiv import _safe_extract

    archive_path = tmp_path / "absolute_symlink.tar"

    # Symlink member whose target is an absolute path
    absolute_link = tarfile.TarInfo(name="absolute_symlink")
    absolute_link.type = tarfile.SYMTYPE
    absolute_link.linkname = "/etc/passwd"

    with tarfile.open(archive_path, "w") as archive:
        archive.addfile(absolute_link)

    # Destination the archive would be unpacked into
    target_dir = tmp_path / "extraction"
    target_dir.mkdir()

    # Extraction must be rejected because of the absolute link target
    with (
        tarfile.open(archive_path, "r") as archive,
        pytest.raises(ValueError, match="Absolute symlink target not allowed"),
    ):
        _safe_extract(archive, str(target_dir))
556+
354557
355558class TestCommonCrawl :
356559 def test_common_crawl_downloader_existing_file (self , tmp_path : Path , monkeypatch : MonkeyPatch ) -> None :
0 commit comments