Skip to content

Commit eb6a83d

Browse files
committed
Environment - fix find_file and extend type hints
1 parent d8e8f27 commit eb6a83d

File tree

2 files changed

+49
-29
lines changed

2 files changed

+49
-29
lines changed

UnityPy/environment.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,34 @@
11
import io
2-
import os
32
import ntpath
3+
import os
44
import re
5-
from typing import List, Callable, Dict, Union
5+
from typing import Callable, Dict, List, Optional, Union
66
from zipfile import ZipFile
77

88
from fsspec import AbstractFileSystem
99
from fsspec.implementations.local import LocalFileSystem
1010

11-
12-
from .files import File, ObjectReader, SerializedFile
1311
from .enums import FileType
14-
from .helpers import ImportHelper
12+
from .files import BundleFile, File, ObjectReader, SerializedFile, WebFile
13+
from .helpers.ImportHelper import (
14+
FileSourceType,
15+
check_file_type,
16+
find_sensitive_path,
17+
parse_file,
18+
)
1519
from .streams import EndianBinaryReader
1620

1721
reSplit = re.compile(r"(.*?([^\/\\]+?))\.split\d+")
1822

1923

2024
class Environment:
21-
files: dict
22-
cabs: dict
25+
files: Dict[str, Union[SerializedFile, BundleFile, WebFile, EndianBinaryReader]]
26+
cabs: Dict[str, Union[SerializedFile, EndianBinaryReader]]
2327
path: str
2428
local_files: List[str]
2529
local_files_simple: List[str]
2630

27-
def __init__(self, *args, fs: AbstractFileSystem = None):
31+
def __init__(self, *args: FileSourceType, fs: Optional[AbstractFileSystem] = None):
2832
self.files = {}
2933
self.cabs = {}
3034
self.path = None
@@ -71,7 +75,7 @@ def load_folder(self, path: str):
7175
]
7276
)
7377

74-
def load(self, files: list):
78+
def load(self, files: List[str]):
7579
"""Loads all files into the Environment."""
7680
self.files.update(
7781
{
@@ -81,8 +85,8 @@ def load(self, files: list):
8185
}
8286
)
8387

84-
def _load_split_file(self, basename):
85-
file = []
88+
def _load_split_file(self, basename: str) -> bytes:
89+
file: list[str] = []
8690
for i in range(0, 999):
8791
item = f"{basename}.split{i}"
8892
if self.fs.exists(item):
@@ -94,9 +98,9 @@ def _load_split_file(self, basename):
9498

9599
def load_file(
96100
self,
97-
file: Union[io.IOBase, str],
98-
parent: Union["Environment", File] = None,
99-
name: str = None,
101+
file: FileSourceType,
102+
parent: Optional[Union["Environment", File]] = None,
103+
name: Optional[str] = None,
100104
is_dependency: bool = False,
101105
):
102106
if not parent:
@@ -105,9 +109,10 @@ def load_file(
105109
if isinstance(file, str):
106110
split_match = reSplit.match(file)
107111
if split_match:
108-
basepath, basename = split_match.groups()
112+
basepath, _basename = split_match.groups()
113+
assert isinstance(basepath, str)
109114
name = basepath
110-
file = self._load_split_file(name)
115+
file = self._load_split_file(basepath)
111116
else:
112117
name = file
113118
if not os.path.exists(file):
@@ -119,14 +124,14 @@ def load_file(
119124
file = self._load_split_file(file)
120125
# Unity paths are case insensitive, so we need to find "Resources/Foo.asset" when the record says "resources/foo.asset"
121126
elif not os.path.exists(file):
122-
file = ImportHelper.find_sensitive_path(self.path, file)
127+
file = find_sensitive_path(self.path, file)
123128
# nonexistent files might be packaging errors or references to Unity's global Library/
124129
if file is None:
125130
return
126131
if type(file) == str:
127132
file = self.fs.open(file, "rb")
128133

129-
typ, reader = ImportHelper.check_file_type(file)
134+
typ, reader = check_file_type(file)
130135

131136
stream_name = (
132137
name
@@ -141,15 +146,15 @@ def load_file(
141146
if typ == FileType.ZIP:
142147
f = self.load_zip_file(file)
143148
else:
144-
f = ImportHelper.parse_file(
149+
f = parse_file(
145150
reader, self, name=stream_name, typ=typ, is_dependency=is_dependency
146151
)
147-
152+
148153
if isinstance(f, (SerializedFile, EndianBinaryReader)):
149154
self.register_cab(stream_name, f)
150155

151156
self.files[stream_name] = f
152-
157+
return f
153158

154159
def load_zip_file(self, value):
155160
buffer = None
@@ -274,7 +279,7 @@ def load_assets(self, assets: List[str], open_f: Callable[[str], io.IOBase]):
274279
for path in assets:
275280
splitMatch = reSplit.match(path)
276281
if splitMatch:
277-
basepath, basename = splitMatch.groups()
282+
basepath, _basename = splitMatch.groups()
278283

279284
if basepath in split_files:
280285
continue
@@ -306,21 +311,33 @@ def find_file(self, name: str, is_dependency: bool = True) -> Union[File, None]:
306311
cab = self.get_cab(simple_name)
307312
if cab:
308313
return cab
314+
fp = self.fs.sep.join([self.path, name])
315+
if self.fs.exists(fp):
316+
return self.load_file(fp, name=name, is_dependency=is_dependency)
309317

310318
if len(self.local_files) == 0 and self.path:
311319
for root, _, files in self.fs.walk(self.path):
312-
for name in files:
313-
self.local_files.append(self.fs.sep.join([root, name]))
320+
for f in files:
321+
self.local_files.append(self.fs.sep.join([root, f]))
322+
self.local_files_simple.append(
323+
self.fs.sep.join([root, simplify_name(f)])
324+
)
314325

315326
if name in self.local_files:
316327
fp = name
317328
elif simple_name in self.local_files_simple:
318329
fp = self.local_files[self.local_files_simple.index(simple_name)]
319330
else:
320-
raise FileNotFoundError(f"File {name} not found in {self.path}")
321-
322-
f = self.load_file(fp, name=name, is_dependency=is_dependency)
323-
return f
331+
fp = next((f for f in self.local_files if f.endswith(name)), None)
332+
if not fp:
333+
fp = next(
334+
(f for f in self.local_files_simple if f.endswith(simple_name)),
335+
None,
336+
)
337+
if not fp:
338+
raise FileNotFoundError(f"File {name} not found in {self.path}")
339+
340+
return self.load_file(fp, name=name, is_dependency=is_dependency)
324341

325342

326343
def simplify_name(name: str) -> str:

UnityPy/helpers/ImportHelper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import io
34
import os
45
from typing import Union, List, Optional, Tuple
56
from .CompressionHelper import BROTLI_MAGIC, GZIP_MAGIC
@@ -8,6 +9,8 @@
89
from .. import files
910

1011

12+
FileSourceType = Union[str, bytes, bytearray, io.IOBase]
13+
1114
def file_name_without_extension(file_name: str) -> str:
1215
return os.path.join(
1316
os.path.dirname(file_name), os.path.splitext(os.path.basename(file_name))[0]
@@ -42,7 +45,7 @@ def find_all_files(directory: str, search_str: str) -> List[str]:
4245
]
4346

4447

45-
def check_file_type(input_) -> Tuple[Optional[FileType], Optional[EndianBinaryReader]]:
48+
def check_file_type(input_: FileSourceType) -> Tuple[Optional[FileType], Optional[EndianBinaryReader]]:
4649
if isinstance(input_, str) and os.path.isfile(input_):
4750
reader = EndianBinaryReader(open(input_, "rb"))
4851
elif isinstance(input_, EndianBinaryReader):

0 commit comments

Comments
 (0)