Skip to content

Commit e8ca4cb

Browse files
authored
Fix some type hints (EndianBinaryReader, EndianBinaryWriter) (#293)
* fix: optimize type hint in EndianBinaryReader * fix: optimize type hint in EndianBinaryWriter * fix: int conversion and legacy code in EndianBinaryWriter
1 parent 7886686 commit e8ca4cb

File tree

2 files changed

+53
-46
lines changed

2 files changed

+53
-46
lines changed

UnityPy/streams/EndianBinaryReader.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import sys
2-
from struct import Struct, unpack
32
import re
4-
from typing import List, Union
5-
from io import BytesIO, BufferedIOBase, IOBase, BufferedReader
3+
from struct import Struct, unpack
4+
from io import IOBase, BufferedReader
5+
6+
import builtins
7+
from typing import Callable, List, Optional, Tuple, Union
68

79
reNot0 = re.compile(b"(.*?)\x00", re.S)
810

@@ -43,13 +45,13 @@ class EndianBinaryReader:
4345

4446
def __new__(
4547
cls,
46-
item: Union[bytes, bytearray, memoryview, BytesIO, str],
48+
item: Union[bytes, bytearray, memoryview, IOBase, str],
4749
endian: str = ">",
4850
offset: int = 0,
4951
):
5052
if isinstance(item, (bytes, bytearray, memoryview)):
5153
obj = super(EndianBinaryReader, cls).__new__(EndianBinaryReader_Memoryview)
52-
elif isinstance(item, (IOBase, BufferedIOBase)):
54+
elif isinstance(item, IOBase):
5355
obj = super(EndianBinaryReader, cls).__new__(EndianBinaryReader_Streamable)
5456
elif isinstance(item, str):
5557
item = open(item, "rb")
@@ -75,17 +77,17 @@ def __new__(
7577
obj.__init__(item, endian)
7678
return obj
7779

78-
def __init__(self, item, endian=">", offset=0):
80+
def __init__(self, item, endian: str = ">", offset: int = 0):
7981
self.endian = endian
8082
self.BaseOffset = offset
8183
self.Position = 0
8284

8385
@property
84-
def bytes(self):
86+
def bytes(self) -> builtins.bytes:
8587
# implemented by Streamable and Memoryview versions
8688
return b""
8789

88-
def read(self, *args):
90+
def read(self, *args) -> builtins.bytes:
8991
# implemented by Streamable and Memoryview versions
9092
return b""
9193

@@ -95,7 +97,7 @@ def read_byte(self) -> int:
9597
def read_u_byte(self) -> int:
9698
return unpack(self.endian + "B", self.read(1))[0]
9799

98-
def read_bytes(self, num) -> bytes:
100+
def read_bytes(self, num) -> builtins.bytes:
99101
return self.read(num)
100102

101103
def read_short(self) -> int:
@@ -191,54 +193,54 @@ def read_color4(self) -> Color:
191193
self.read_float(), self.read_float(), self.read_float(), self.read_float()
192194
)
193195

194-
def read_byte_array(self) -> bytes:
196+
def read_byte_array(self) -> builtins.bytes:
195197
return self.read(self.read_int())
196198

197199
def read_matrix(self) -> Matrix4x4:
198200
return Matrix4x4(self.read_float_array(16))
199201

200-
def read_array(self, command, length: int) -> list:
202+
def read_array(self, command: Callable, length: int) -> list:
201203
return [command() for _ in range(length)]
202204

203-
def read_array_struct(self, param: str, length: int = None) -> list:
205+
def read_array_struct(self, param: str, length: Optional[int] = None) -> tuple:
204206
if length is None:
205207
length = self.read_int()
206208
struct = Struct(f"{self.endian}{length}{param}")
207209
return struct.unpack(self.read(struct.size))
208210

209-
def read_boolean_array(self, length: int = None) -> List[bool]:
211+
def read_boolean_array(self, length: Optional[int] = None) -> Tuple[bool]:
210212
return self.read_array_struct("?", length)
211213

212-
def read_u_byte_array(self, length: int = None) -> List[int]:
214+
def read_u_byte_array(self, length: Optional[int] = None) -> Tuple[int]:
213215
return self.read_array_struct("B", length)
214216

215-
def read_u_short_array(self, length: int = None) -> List[int]:
217+
def read_u_short_array(self, length: Optional[int] = None) -> Tuple[int]:
216218
return self.read_array_struct("h", length)
217219

218-
def read_short_array(self, length: int = None) -> List[int]:
220+
def read_short_array(self, length: Optional[int] = None) -> Tuple[int]:
219221
return self.read_array_struct("H", length)
220222

221-
def read_int_array(self, length: int = None) -> List[int]:
223+
def read_int_array(self, length: Optional[int] = None) -> Tuple[int]:
222224
return self.read_array_struct("i", length)
223225

224-
def read_u_int_array(self, length: int = None) -> List[int]:
226+
def read_u_int_array(self, length: Optional[int] = None) -> Tuple[int]:
225227
return self.read_array_struct("I", length)
226228

227-
def read_long_array(self, length: int = None) -> List[int]:
229+
def read_long_array(self, length: Optional[int] = None) -> Tuple[int]:
228230
return self.read_array_struct("q", length)
229231

230-
def read_u_long_array(self, length: int = None) -> List[int]:
232+
def read_u_long_array(self, length: Optional[int] = None) -> Tuple[int]:
231233
return self.read_array_struct("Q", length)
232234

233-
def read_u_int_array_array(self, length: int = None) -> List[List[int]]:
235+
def read_u_int_array_array(self, length: Optional[int] = None) -> List[Tuple[int]]:
234236
return self.read_array(
235237
self.read_u_int_array, length if length is not None else self.read_int()
236238
)
237239

238-
def read_float_array(self, length: int = None) -> List[float]:
240+
def read_float_array(self, length: Optional[int] = None) -> Tuple[float]:
239241
return self.read_array_struct("f", length)
240242

241-
def read_double_array(self, length: int = None) -> List[float]:
243+
def read_double_array(self, length: Optional[int] = None) -> Tuple[float]:
242244
return self.read_array_struct("d", length)
243245

244246
def read_string_array(self) -> List[str]:
@@ -259,7 +261,7 @@ def real_offset(self) -> int:
259261
"""
260262
return self.BaseOffset + self.Position
261263

262-
def read_the_rest(self, obj_start: int, obj_size: int) -> bytes:
264+
def read_the_rest(self, obj_start: int, obj_size: int) -> builtins.bytes:
263265
"""Returns the rest of the current reader bytes."""
264266
return self.read_bytes(obj_size - (self.Position - obj_start))
265267

@@ -268,7 +270,7 @@ class EndianBinaryReader_Memoryview(EndianBinaryReader):
268270
__slots__ = ("view", "_endian", "BaseOffset", "Position", "Length")
269271
view: memoryview
270272

271-
def __init__(self, view, endian=">", offset=0):
273+
def __init__(self, view, endian: str = ">", offset: int = 0):
272274
self._endian = ""
273275
super().__init__(view, endian=endian, offset=offset)
274276
self.view = memoryview(view)
@@ -293,20 +295,20 @@ def endian(self, value: str):
293295
self._endian = value
294296

295297
@property
296-
def bytes(self):
298+
def bytes(self) -> memoryview:
297299
return self.view
298300

299-
def dispose(self):
301+
def dispose(self) -> None:
300302
self.view.release()
301303

302-
def read(self, length: int):
304+
def read(self, length: int) -> memoryview:
303305
if not length:
304-
return b""
306+
return memoryview(b"")
305307
ret = self.view[self.Position : self.Position + length]
306308
self.Position += length
307309
return ret
308310

309-
def read_aligned_string(self):
311+
def read_aligned_string(self) -> str:
310312
length = self.read_int()
311313
if 0 < length <= self.Length - self.Position:
312314
string_data = self.read_bytes(length)
@@ -315,7 +317,7 @@ def read_aligned_string(self):
315317
return result
316318
return ""
317319

318-
def read_string_to_null(self, max_length=32767) -> str:
320+
def read_string_to_null(self, max_length: int = 32767) -> str:
319321
match = reNot0.search(self.view, self.Position, self.Position + max_length)
320322
if not match:
321323
if self.Position + max_length >= self.Length:

UnityPy/streams/EndianBinaryWriter.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
1-
import io
21
from struct import pack
2+
from io import BytesIO, IOBase
3+
4+
import builtins
5+
from typing import Callable, Union
36

47
from ..math import Color, Matrix4x4, Quaternion, Vector2, Vector3, Vector4, Rectangle
58

69

710
class EndianBinaryWriter:
811
endian: str
9-
Length: int
1012
Position: int
11-
stream: io.BufferedReader
13+
stream: IOBase
1214

13-
def __init__(self, input_=b"", endian=">"):
15+
def __init__(
16+
self,
17+
input_: Union[bytes, bytearray, IOBase] = b"",
18+
endian: str = ">"
19+
):
1420
if isinstance(input_, (bytes, bytearray)):
15-
self.stream = io.BytesIO(input_)
21+
self.stream = BytesIO(input_)
1622
self.stream.seek(0, 2)
17-
elif isinstance(input_, io.IOBase):
23+
elif isinstance(input_, IOBase):
1824
self.stream = input_
1925
else:
2026
raise ValueError("Invalid input type - %s." % type(input_))
@@ -36,7 +42,6 @@ def Length(self) -> int:
3642

3743
def dispose(self):
3844
self.stream.close()
39-
pass
4045

4146
def write(self, *args):
4247
if self.Position != self.stream.tell():
@@ -51,7 +56,7 @@ def write_byte(self, value: int):
5156
def write_u_byte(self, value: int):
5257
self.write(pack(self.endian + "B", value))
5358

54-
def write_bytes(self, value: bytes):
59+
def write_bytes(self, value: builtins.bytes):
5560
return self.write(value)
5661

5762
def write_short(self, value: int):
@@ -91,7 +96,7 @@ def write_aligned_string(self, value: str):
9196
self.write(bstring)
9297
self.align_stream(4)
9398

94-
def align_stream(self, alignment=4):
99+
def align_stream(self, alignment:int = 4):
95100
pos = self.stream.tell()
96101
align = (alignment - pos % alignment) % alignment
97102
self.write(b"\0" * align)
@@ -124,10 +129,10 @@ def write_rectangle_f(self, value: Rectangle):
124129
self.write_float(value.height)
125130

126131
def write_color_uint(self, value: Color):
127-
self.write_u_byte(value.R * 255)
128-
self.write_u_byte(value.G * 255)
129-
self.write_u_byte(value.B * 255)
130-
self.write_u_byte(value.A * 255)
132+
self.write_u_byte(int(value.R * 255))
133+
self.write_u_byte(int(value.G * 255))
134+
self.write_u_byte(int(value.B * 255))
135+
self.write_u_byte(int(value.A * 255))
131136

132137
def write_color4(self, value: Color):
133138
self.write_float(value.R)
@@ -139,13 +144,13 @@ def write_matrix(self, value: Matrix4x4):
139144
for val in value.M:
140145
self.write_float(val)
141146

142-
def write_array(self, command, value: list, write_length: bool = True):
147+
def write_array(self, command: Callable, value: list, write_length: bool = True):
143148
if write_length:
144149
self.write_int(len(value))
145150
for val in value:
146151
command(val)
147152

148-
def write_byte_array(self, value: bytes):
153+
def write_byte_array(self, value: builtins.bytes):
149154
self.write_int(len(value))
150155
self.write(value)
151156

0 commit comments

Comments
 (0)