@@ -74,16 +74,92 @@ def main(args):
7474 save_dir ))
7575 if args .dataset == 'lmd_full' :
7676 with tarfile .open (target_download_location ) as f :
77- f .extractall (save_dir )
def is_within_directory(directory, target):
    """Return True if *target* is *directory* itself or a path beneath it.

    Both arguments are converted with ``os.path.abspath`` before they are
    compared, so relative paths and ``..`` components are normalised first.
    The comparison is purely lexical: symbolic links are not resolved.

    Args:
        directory: Path of the directory that must contain *target*.
        target: Path to test.

    Returns:
        True when the longest common sub-path of the two absolute paths
        equals the absolute *directory*; False otherwise.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)

    try:
        # commonpath works on whole path components. commonprefix compares
        # character by character and would wrongly accept a sibling such as
        # "/data/out_evil" as being inside "/data/out".
        common = os.path.commonpath([abs_directory, abs_target])
    except ValueError:
        # commonpath raises when the paths cannot share a root (for example
        # different drives on Windows); such a target is not inside.
        return False

    return common == abs_directory
85+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract *tar* into *path* after rejecting entries that escape it.

    Every entry returned by ``tar.getmembers()`` is validated, even when
    *members* restricts what is actually extracted. If any entry fails
    validation, nothing is extracted.

    Args:
        tar: An open ``tarfile.TarFile``.
        path: Destination directory.
        members: Optional subset of members, forwarded to ``extractall``.
        numeric_owner: Forwarded to ``extractall``.

    Raises:
        Exception: If an entry's name, or the target of a symbolic or hard
            link entry, lies outside *path*.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")

        # A link whose target is outside `path` would let a later entry be
        # written through it, so link targets are validated as well.
        if member.issym():
            # Symlink targets are relative to the directory holding the link.
            link_target = os.path.join(
                os.path.dirname(member_path), member.linkname
            )
        elif member.islnk():
            # Hard-link targets are names relative to the archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            continue

        if not is_within_directory(path, link_target):
            raise Exception("Attempted Path Traversal in Tar File")

    # NOTE(review): the checks above are lexical and do not follow links that
    # earlier entries create inside `path`. On Python versions that provide
    # tarfile extraction filters, consider also passing filter="data".
    tar.extractall(path, members, numeric_owner=numeric_owner)
94+
95+
96+ safe_extract (f , save_dir )
7897 elif args .dataset == 'lmd_matched' :
7998 with tarfile .open (target_download_location ) as f :
80- f .extractall (save_dir )
def is_within_directory(directory, target):
    """Return True if *target* is *directory* itself or a path beneath it.

    Both arguments are converted with ``os.path.abspath`` before they are
    compared, so relative paths and ``..`` components are normalised first.
    The comparison is purely lexical: symbolic links are not resolved.

    Args:
        directory: Path of the directory that must contain *target*.
        target: Path to test.

    Returns:
        True when the longest common sub-path of the two absolute paths
        equals the absolute *directory*; False otherwise.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)

    try:
        # commonpath works on whole path components. commonprefix compares
        # character by character and would wrongly accept a sibling such as
        # "/data/out_evil" as being inside "/data/out".
        common = os.path.commonpath([abs_directory, abs_target])
    except ValueError:
        # commonpath raises when the paths cannot share a root (for example
        # different drives on Windows); such a target is not inside.
        return False

    return common == abs_directory
107+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract *tar* into *path* after rejecting entries that escape it.

    Every entry returned by ``tar.getmembers()`` is validated, even when
    *members* restricts what is actually extracted. If any entry fails
    validation, nothing is extracted.

    Args:
        tar: An open ``tarfile.TarFile``.
        path: Destination directory.
        members: Optional subset of members, forwarded to ``extractall``.
        numeric_owner: Forwarded to ``extractall``.

    Raises:
        Exception: If an entry's name, or the target of a symbolic or hard
            link entry, lies outside *path*.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")

        # A link whose target is outside `path` would let a later entry be
        # written through it, so link targets are validated as well.
        if member.issym():
            # Symlink targets are relative to the directory holding the link.
            link_target = os.path.join(
                os.path.dirname(member_path), member.linkname
            )
        elif member.islnk():
            # Hard-link targets are names relative to the archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            continue

        if not is_within_directory(path, link_target):
            raise Exception("Attempted Path Traversal in Tar File")

    # NOTE(review): the checks above are lexical and do not follow links that
    # earlier entries create inside `path`. On Python versions that provide
    # tarfile extraction filters, consider also passing filter="data".
    tar.extractall(path, members, numeric_owner=numeric_owner)
116+
117+
118+ safe_extract (f , save_dir )
81119 elif args .dataset == 'lmd_aligned' :
82120 with tarfile .open (target_download_location ) as f :
83- f .extractall (save_dir )
def is_within_directory(directory, target):
    """Return True if *target* is *directory* itself or a path beneath it.

    Both arguments are converted with ``os.path.abspath`` before they are
    compared, so relative paths and ``..`` components are normalised first.
    The comparison is purely lexical: symbolic links are not resolved.

    Args:
        directory: Path of the directory that must contain *target*.
        target: Path to test.

    Returns:
        True when the longest common sub-path of the two absolute paths
        equals the absolute *directory*; False otherwise.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)

    try:
        # commonpath works on whole path components. commonprefix compares
        # character by character and would wrongly accept a sibling such as
        # "/data/out_evil" as being inside "/data/out".
        common = os.path.commonpath([abs_directory, abs_target])
    except ValueError:
        # commonpath raises when the paths cannot share a root (for example
        # different drives on Windows); such a target is not inside.
        return False

    return common == abs_directory
129+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract *tar* into *path* after rejecting entries that escape it.

    Every entry returned by ``tar.getmembers()`` is validated, even when
    *members* restricts what is actually extracted. If any entry fails
    validation, nothing is extracted.

    Args:
        tar: An open ``tarfile.TarFile``.
        path: Destination directory.
        members: Optional subset of members, forwarded to ``extractall``.
        numeric_owner: Forwarded to ``extractall``.

    Raises:
        Exception: If an entry's name, or the target of a symbolic or hard
            link entry, lies outside *path*.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")

        # A link whose target is outside `path` would let a later entry be
        # written through it, so link targets are validated as well.
        if member.issym():
            # Symlink targets are relative to the directory holding the link.
            link_target = os.path.join(
                os.path.dirname(member_path), member.linkname
            )
        elif member.islnk():
            # Hard-link targets are names relative to the archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            continue

        if not is_within_directory(path, link_target):
            raise Exception("Attempted Path Traversal in Tar File")

    # NOTE(review): the checks above are lexical and do not follow links that
    # earlier entries create inside `path`. On Python versions that provide
    # tarfile extraction filters, consider also passing filter="data".
    tar.extractall(path, members, numeric_owner=numeric_owner)
138+
139+
140+ safe_extract (f , save_dir )
84141 elif args .dataset == 'clean_midi' :
85142 with tarfile .open (target_download_location ) as f :
86- f .extractall (save_dir )
def is_within_directory(directory, target):
    """Return True if *target* is *directory* itself or a path beneath it.

    Both arguments are converted with ``os.path.abspath`` before they are
    compared, so relative paths and ``..`` components are normalised first.
    The comparison is purely lexical: symbolic links are not resolved.

    Args:
        directory: Path of the directory that must contain *target*.
        target: Path to test.

    Returns:
        True when the longest common sub-path of the two absolute paths
        equals the absolute *directory*; False otherwise.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)

    try:
        # commonpath works on whole path components. commonprefix compares
        # character by character and would wrongly accept a sibling such as
        # "/data/out_evil" as being inside "/data/out".
        common = os.path.commonpath([abs_directory, abs_target])
    except ValueError:
        # commonpath raises when the paths cannot share a root (for example
        # different drives on Windows); such a target is not inside.
        return False

    return common == abs_directory
151+
def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract *tar* into *path* after rejecting entries that escape it.

    Every entry returned by ``tar.getmembers()`` is validated, even when
    *members* restricts what is actually extracted. If any entry fails
    validation, nothing is extracted.

    Args:
        tar: An open ``tarfile.TarFile``.
        path: Destination directory.
        members: Optional subset of members, forwarded to ``extractall``.
        numeric_owner: Forwarded to ``extractall``.

    Raises:
        Exception: If an entry's name, or the target of a symbolic or hard
            link entry, lies outside *path*.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")

        # A link whose target is outside `path` would let a later entry be
        # written through it, so link targets are validated as well.
        if member.issym():
            # Symlink targets are relative to the directory holding the link.
            link_target = os.path.join(
                os.path.dirname(member_path), member.linkname
            )
        elif member.islnk():
            # Hard-link targets are names relative to the archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            continue

        if not is_within_directory(path, link_target):
            raise Exception("Attempted Path Traversal in Tar File")

    # NOTE(review): the checks above are lexical and do not follow links that
    # earlier entries create inside `path`. On Python versions that provide
    # tarfile extraction filters, consider also passing filter="data".
    tar.extractall(path, members, numeric_owner=numeric_owner)
160+
161+
162+ safe_extract (f , save_dir )
87163 elif args .dataset == 'maestro_v1' :
88164 with zipfile .ZipFile (target_download_location , 'r' ) as fobj :
89165 fobj .extractall (save_dir )
0 commit comments