|
12 | 12 | import torch |
13 | 13 | import torch.utils._pytree as pytree |
14 | 14 | from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact |
| 15 | +from torch._inductor.cpp_builder import normalize_path_separator |
15 | 16 | from torch.export import ExportedProgram |
16 | 17 | from torch.export._tree_utils import reorder_kwargs |
17 | 18 | from torch.export.pt2_archive._package_weights import ( |
@@ -75,6 +76,8 @@ class PT2ArchiveWriter: |
75 | 76 | """ |
76 | 77 |
|
77 | 78 | def __init__(self, archive_path_or_buffer: FileLike): |
| 79 | + if isinstance(archive_path_or_buffer, str): |
| 80 | + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) |
78 | 81 | self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] |
79 | 82 | # NOTICE: version here is different from the archive_version |
80 | 83 | # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version |
@@ -169,6 +172,8 @@ class PT2ArchiveReader: |
169 | 172 | """ |
170 | 173 |
|
171 | 174 | def __init__(self, archive_path_or_buffer: FileLike): |
| 175 | + if isinstance(archive_path_or_buffer, str): |
| 176 | + archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) |
172 | 177 | self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] |
173 | 178 | assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( |
174 | 179 | "Invalid archive format" |
|
0 commit comments