diff --git a/src/odoc/odoc_file.ml b/src/odoc/odoc_file.ml index 67f9b07098..cc96b68ef2 100644 --- a/src/odoc/odoc_file.ml +++ b/src/odoc/odoc_file.ml @@ -29,13 +29,24 @@ type content = type t = { content : content; warnings : Odoc_model.Error.t list } (** Written at the top of the files. Checked when loading. *) -let magic = "odoc-%%VERSION%%" +let magic = "ODOC" + +let magic_version = "%%VERSION%%" (** Exceptions while saving are allowed to leak. *) let save_ file f = + let len = String.length magic_version in + (* Sanity check, see similar check in load_ *) + if len > 255 then + failwith + (Printf.sprintf + "Magic version string %S is too long, must be <= 255 characters" magic); + Fs.Directory.mkdir_p (Fs.File.dirname file); Io_utils.with_open_out_bin (Fs.File.to_string file) (fun oc -> output_string oc magic; + output_binary_int oc len; + output_string oc magic_version; f oc) let save_unit file (root : Root.t) (t : t) = @@ -78,19 +89,41 @@ let save_unit file ~warnings m = let load_ file f = let file = Fs.File.to_string file in - (if Sys.file_exists file then Ok file - else Error (`Msg (Printf.sprintf "File does not exist"))) - >>= fun file -> - Io_utils.with_open_in_bin file @@ fun ic -> - try + + let check_exists () = + if Sys.file_exists file then Ok () + else Error (`Msg (Printf.sprintf "File %s does not exist" file)) + in + + let check_magic ic = let actual_magic = really_input_string ic (String.length magic) in - if actual_magic = magic then f ic + if actual_magic = magic then Ok () + else + Error + (`Msg + (Printf.sprintf "%s has invalid magic %S, expected %S\n%!" file + actual_magic magic)) + in + let version_length ic () = + let len = input_binary_int ic in + if len > 0 && len <= 255 then Ok len + else Error (`Msg (Printf.sprintf "%s has invalid version length" file)) + in + let check_version ic len = + let actual_magic = really_input_string ic len in + if actual_magic = magic_version then Ok () else let msg = - Printf.sprintf "%s: invalid magic number %S, expected %S\n%!" file - actual_magic magic + Printf.sprintf "%s has invalid version %S, expected %S\n%!" file + actual_magic magic_version in Error (`Msg msg) + in + + check_exists () >>= fun () -> + Io_utils.with_open_in_bin file @@ fun ic -> + try + check_magic ic >>= version_length ic >>= check_version ic >>= fun () -> f ic with exn -> let msg = Printf.sprintf "Error while unmarshalling %S: %s\n%!" file