11# FILE: tests/test_create_totalseg_subset.py
2- import unittest
3- import pandas as pd
42import os
5- from scripts .create_totalseg_subset import plot_and_save_distribution
3+ import sys
4+ import pandas as pd
5+ import pytest
6+
7+ # Ensure package import works when tests are run directly
8+ sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' )))
9+
10+ from scripts .create_totalseg_subset import plot_and_save_distribution , create_directory_structure , copy_selected_files
611
7- class TestCreateTotalSegSubset (unittest .TestCase ):
8- def setUp (self ):
12+ class TestCreateTotalSegSubset :
13+ @pytest .fixture (autouse = True )
14+ def setup_and_teardown (self , tmp_path ):
915 # Create a sample dataframe
1016 self .data = pd .DataFrame ({
1117 'age' : [25 , 35 , 45 , 55 , 65 , 75 ],
1218 'gender' : ['M' , 'F' , 'M' , 'F' , 'M' , 'F' ]
1319 })
14- self .filename = 'test_distribution.png'
20+ self .filename = tmp_path / 'test_distribution.png'
21+ yield
22+ # Remove the file after test if it exists
23+ if self .filename .exists ():
24+ self .filename .unlink ()
1525
1626 def test_plot_and_save_distribution (self ):
1727 # Call the function to plot and save the distribution
18- plot_and_save_distribution (self .data , "Test Title" , self .filename )
19-
28+ plot_and_save_distribution (self .data , "Test Title" , str (self .filename ))
2029 # Check if the file is created
21- self .assertTrue ( os . path . exists (self . filename ) )
30+ assert self .filename . exists ()
2231
23- def tearDown (self ):
24- # Remove the file after test
25- if os .path .exists (self .filename ):
26- os .remove (self .filename )
32+ def test_create_directory_structure (self , tmp_path ):
33+ subdirs = ["imagesTr" , "labelsTr" ]
34+ create_directory_structure (tmp_path , subdirs )
35+ for subdir in subdirs :
36+ assert (tmp_path / subdir ).exists ()
37+ assert (tmp_path / subdir ).is_dir ()
2738
28- if __name__ == '__main__' :
29- unittest .main ()
39+ def test_copy_selected_files (self , tmp_path ):
40+ # Setup original and new base directories
41+ original_base = tmp_path / "original"
42+ new_base = tmp_path / "new"
43+ subdirs = ["imagesTr" , "labelsTr" ]
44+ image_ids = ["img001" , "img002" ]
45+ create_directory_structure (original_base , subdirs )
46+ create_directory_structure (new_base , subdirs )
47+ # Create files in original_base
48+ for subdir in subdirs :
49+ for img_id in image_ids :
50+ file_path = original_base / subdir / f"{ img_id } _something.nii.gz"
51+ file_path .write_text ("dummy data" )
52+ # Add a file that should not be copied
53+ (original_base / subdir / "otherfile.nii.gz" ).write_text ("not copied" )
54+ # Copy selected files
55+ copy_selected_files (original_base , new_base , subdirs , image_ids )
56+ # Check that only the correct files are copied
57+ for subdir in subdirs :
58+ for img_id in image_ids :
59+ assert (new_base / subdir / f"{ img_id } _something.nii.gz" ).exists ()
60+ assert not (new_base / subdir / "otherfile.nii.gz" ).exists ()
0 commit comments