@@ -25,9 +25,7 @@ use crate::native::{
2525/// assert!(native.sites.len() >= 1, "Vector length is less than 1");
2626/// ```
2727pub fn parse_site_native_file ( xml_path : & Path ) -> Result < SiteNative , Error > {
28- if !xml_path. exists ( ) {
29- return Err ( Error :: FileNotFound ( xml_path. to_path_buf ( ) ) ) ;
30- }
28+ check_valid_xml_file ( xml_path) ?;
3129
3230 let xml_file = read_to_string ( xml_path) ?;
3331 let native = parse_site_native_string ( & xml_file) ?;
@@ -406,9 +404,7 @@ pub fn parse_site_native_string(xml_str: &str) -> Result<SiteNative, Error> {
406404/// assert!(native.patients.len() >= 1, "Vector length is less than 1");
407405/// ```
408406pub fn parse_subject_native_file ( xml_path : & Path ) -> Result < SubjectNative , Error > {
409- if !xml_path. exists ( ) {
410- return Err ( Error :: FileNotFound ( xml_path. to_path_buf ( ) ) ) ;
411- }
407+ check_valid_xml_file ( xml_path) ?;
412408
413409 let xml_file = read_to_string ( xml_path) ?;
414410 let native = parse_subject_native_string ( & xml_file) ?;
@@ -617,9 +613,7 @@ pub fn parse_subject_native_string(xml_str: &str) -> Result<SubjectNative, Error
617613/// assert!(native.users.len() >= 1, "Vector length is less than 1");
618614/// ```
619615pub fn parse_user_native_file ( xml_path : & Path ) -> Result < UserNative , Error > {
620- if !xml_path. exists ( ) {
621- return Err ( Error :: FileNotFound ( xml_path. to_path_buf ( ) ) ) ;
622- }
616+ check_valid_xml_file ( xml_path) ?;
623617
624618 let xml_file = read_to_string ( xml_path) ?;
625619 let native = parse_user_native_string ( & xml_file) ?;
@@ -792,3 +786,88 @@ pub fn parse_user_native_string(xml_str: &str) -> Result<UserNative, Error> {
792786
793787 Ok ( native)
794788}
789+
790+ fn check_valid_xml_file ( xml_path : & Path ) -> Result < ( ) , Error > {
791+ if !xml_path. exists ( ) {
792+ return Err ( Error :: FileNotFound ( xml_path. to_path_buf ( ) ) ) ;
793+ }
794+
795+ if let Some ( extension) = xml_path. extension ( ) {
796+ if extension != "xml" {
797+ return Err ( Error :: InvalidFileType ( xml_path. to_owned ( ) ) ) ;
798+ }
799+ } else {
800+ return Err ( Error :: Unknown ) ;
801+ }
802+
803+ Ok ( ( ) )
804+ }
805+
806+ #[ cfg( test) ]
807+ mod tests {
808+ use super :: * ;
809+ use tempfile:: { tempdir, Builder } ;
810+
811+ #[ test]
812+ fn test_site_file_not_found_error ( ) {
813+ let dir = tempdir ( ) . unwrap ( ) . path ( ) . to_path_buf ( ) ;
814+ let result = parse_site_native_file ( & dir) ;
815+ assert ! ( result. is_err( ) ) ;
816+ assert ! ( matches!( result, Err ( Error :: FileNotFound ( _) ) ) ) ;
817+ }
818+
819+ #[ test]
820+ fn test_site_invaid_file_type_error ( ) {
821+ let file = Builder :: new ( )
822+ . prefix ( "test" )
823+ . suffix ( ".csv" )
824+ . tempfile ( )
825+ . unwrap ( ) ;
826+ let result = parse_site_native_file ( file. path ( ) ) ;
827+
828+ assert ! ( result. is_err( ) ) ;
829+ assert ! ( matches!( result, Err ( Error :: InvalidFileType ( _) ) ) ) ;
830+ }
831+
832+ #[ test]
833+ fn test_subject_file_not_found_error ( ) {
834+ let dir = tempdir ( ) . unwrap ( ) . path ( ) . to_path_buf ( ) ;
835+ let result = parse_subject_native_file ( & dir) ;
836+ assert ! ( result. is_err( ) ) ;
837+ assert ! ( matches!( result, Err ( Error :: FileNotFound ( _) ) ) ) ;
838+ }
839+
840+ #[ test]
841+ fn test_subject_invaid_file_type_error ( ) {
842+ let file = Builder :: new ( )
843+ . prefix ( "test" )
844+ . suffix ( ".csv" )
845+ . tempfile ( )
846+ . unwrap ( ) ;
847+ let result = parse_subject_native_file ( file. path ( ) ) ;
848+
849+ assert ! ( result. is_err( ) ) ;
850+ assert ! ( matches!( result, Err ( Error :: InvalidFileType ( _) ) ) ) ;
851+ }
852+
853+ #[ test]
854+ fn test_user_file_not_found_error ( ) {
855+ let dir = tempdir ( ) . unwrap ( ) . path ( ) . to_path_buf ( ) ;
856+ let result = parse_user_native_file ( & dir) ;
857+ assert ! ( result. is_err( ) ) ;
858+ assert ! ( matches!( result, Err ( Error :: FileNotFound ( _) ) ) ) ;
859+ }
860+
861+ #[ test]
862+ fn test_user_invaid_file_type_error ( ) {
863+ let file = Builder :: new ( )
864+ . prefix ( "test" )
865+ . suffix ( ".csv" )
866+ . tempfile ( )
867+ . unwrap ( ) ;
868+ let result = parse_user_native_file ( file. path ( ) ) ;
869+
870+ assert ! ( result. is_err( ) ) ;
871+ assert ! ( matches!( result, Err ( Error :: InvalidFileType ( _) ) ) ) ;
872+ }
873+ }
0 commit comments