Skip to content

Commit a542208

Browse files
committed
SpriteHelper - code cleanup
1 parent bed31f2 commit a542208

File tree

1 file changed

+51
-27
lines changed

1 file changed

+51
-27
lines changed

UnityPy/export/SpriteHelper.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Iterable, cast, Dict, Any, Union
44

55
from PIL import Image, ImageDraw
6+
from PIL.Image import Transform, Transpose
67

8+
from ..classes import SpriteAtlasData
79
from ..enums import (
810
ClassIDType,
911
SpriteMeshType,
@@ -42,21 +44,25 @@ def __init__(self, settings_raw):
4244
def get_image(
4345
sprite: Sprite, texture: PPtr[Texture2D], alpha_texture: Optional[PPtr[Texture2D]]
4446
) -> Image.Image:
47+
assert sprite.assets_file, "Sprite assets file is not set!"
48+
cache = cast(
49+
Dict[Any, Any], sprite.assets_file._cache
50+
) # TODO: edit in SerializibleFile
4551
if alpha_texture:
4652
cache_id = (texture.path_id, alpha_texture.path_id)
47-
if cache_id not in sprite.assets_file._cache:
53+
if cache_id not in cache:
4854
original_image = get_image_from_texture2d(texture.read(), False)
4955
alpha_image = get_image_from_texture2d(alpha_texture.read(), False)
5056
original_image = Image.merge(
5157
"RGBA", (*original_image.split()[:3], alpha_image.split()[0])
5258
)
53-
sprite.assets_file._cache[cache_id] = original_image
59+
cache[cache_id] = original_image
5460
else:
5561
cache_id = texture.path_id
56-
if cache_id not in sprite.assets_file._cache:
62+
if cache_id not in cache:
5763
original_image = get_image_from_texture2d(texture.read(), False)
58-
sprite.assets_file._cache[cache_id] = original_image
59-
return sprite.assets_file._cache[cache_id]
64+
cache[cache_id] = original_image
65+
return cache[cache_id]
6066

6167

6268
def get_image_from_sprite(m_Sprite: Sprite) -> Image.Image:
@@ -65,6 +71,7 @@ def get_image_from_sprite(m_Sprite: Sprite) -> Image.Image:
6571
atlas = m_Sprite.m_SpriteAtlas.read()
6672
elif m_Sprite.m_AtlasTags:
6773
# looks like the direct pointer is empty, let's try to find the Atlas via its name
74+
assert m_Sprite.assets_file, "Sprite assets file is not set!"
6875
for obj in m_Sprite.assets_file.objects.values():
6976
if obj.type == ClassIDType.SpriteAtlas:
7077
atlas = obj.read()
@@ -78,6 +85,9 @@ def get_image_from_sprite(m_Sprite: Sprite) -> Image.Image:
7885
for key, value in atlas.m_RenderDataMap
7986
if key == m_Sprite.m_RenderDataKey
8087
)
88+
assert isinstance(sprite_atlas_data, SpriteAtlasData), (
89+
"SpriteAtlasData not found!"
90+
)
8191
else:
8292
sprite_atlas_data = m_Sprite.m_RD
8393

@@ -88,41 +98,44 @@ def get_image_from_sprite(m_Sprite: Sprite) -> Image.Image:
8898

8999
original_image = get_image(m_Sprite, m_Texture2D, alpha_texture)
90100

91-
sprite_image = original_image.crop((
92-
texture_rect.x,
93-
texture_rect.y,
94-
texture_rect.x + texture_rect.width,
95-
texture_rect.y + texture_rect.height,
96-
))
101+
sprite_image = original_image.crop(
102+
(
103+
texture_rect.x,
104+
texture_rect.y,
105+
texture_rect.x + texture_rect.width,
106+
texture_rect.y + texture_rect.height,
107+
)
108+
)
97109

98110
settings_raw = SpriteSettings(settings_raw)
99111
if settings_raw.packed == 1:
100112
rotation = settings_raw.packingRotation
101113
if rotation == SpritePackingRotation.kSPRFlipHorizontal:
102-
sprite_image = sprite_image.transpose(Image.FLIP_LEFT_RIGHT)
114+
sprite_image = sprite_image.transpose(Transpose.FLIP_LEFT_RIGHT)
103115
# spriteImage = RotateFlip(RotateFlipType.RotateNoneFlipX)
104116
elif rotation == SpritePackingRotation.kSPRFlipVertical:
105-
sprite_image = sprite_image.transpose(Image.FLIP_TOP_BOTTOM)
117+
sprite_image = sprite_image.transpose(Transpose.FLIP_TOP_BOTTOM)
106118
# spriteImage.RotateFlip(RotateFlipType.RotateNoneFlipY)
107119
elif rotation == SpritePackingRotation.kSPRRotate180:
108-
sprite_image = sprite_image.transpose(Image.ROTATE_180)
120+
sprite_image = sprite_image.transpose(Transpose.ROTATE_180)
109121
# spriteImage.RotateFlip(RotateFlipType.Rotate180FlipNone)
110122
elif rotation == SpritePackingRotation.kSPRRotate90:
111-
sprite_image = sprite_image.transpose(Image.ROTATE_270)
123+
sprite_image = sprite_image.transpose(Transpose.ROTATE_270)
112124
# spriteImage.RotateFlip(RotateFlipType.Rotate270FlipNone)
113125

114126
if settings_raw.packingMode == SpritePackingMode.kSPMTight:
127+
assert m_Sprite.object_reader, "Sprite object reader is not set!"
115128
mesh = MeshHandler(m_Sprite.m_RD, m_Sprite.object_reader.version)
116129
mesh.process()
117130

118-
if any(u or v for u, v in mesh.m_UV0):
131+
if mesh.m_UV0 and any(u or v for u, v in mesh.m_UV0):
119132
# copy triangles from mesh
120133
sprite_image = render_sprite_mesh(m_Sprite, mesh, original_image)
121134
else:
122135
# create mask to keep only the polygon
123136
sprite_image = mask_sprite(m_Sprite, mesh, sprite_image)
124137

125-
return sprite_image.transpose(Image.FLIP_TOP_BOTTOM)
138+
return sprite_image.transpose(Transpose.FLIP_TOP_BOTTOM)
126139

127140

128141
def mask_sprite(
@@ -135,6 +148,9 @@ def mask_sprite(
135148
# shift the whole point matrix into the positive space
136149
# multiply them with a factor to scale them to the image
137150
positions = mesh.m_Vertices
151+
assert positions, "No vertices found in sprite mesh!"
152+
# find the axis that has only one value - can be removed
153+
# usually the z axis
138154
min_x = min(x for x, _y, _z in positions)
139155
min_y = min(y for _x, y, _z in positions)
140156
factor = m_Sprite.m_PixelsToUnits
@@ -174,7 +190,11 @@ def render_sprite_mesh(
174190
) -> Image.Image:
175191
for triangles in mesh.get_triangles():
176192
positions = mesh.m_Vertices
193+
if not positions:
194+
continue
177195
uv = mesh.m_UV0
196+
if not uv:
197+
raise ValueError("No UV coordinates found in sprite mesh!")
178198

179199
# 2. patch position data
180200
# 2.1 make positions 2d
@@ -212,19 +232,21 @@ def render_sprite_mesh(
212232
for tri in triangles:
213233
copy_triangle(
214234
texture,
215-
[uv_abs[i] for i in tri],
235+
tuple(uv_abs[i] for i in tri), # type: ignore
216236
sprite,
217-
[positions_abs[i] for i in tri],
237+
tuple(positions_abs[i] for i in tri), # type: ignore
218238
)
219239

220240
return sprite
241+
else:
242+
raise ValueError("No triangles found in mesh!")
221243

222244

223245
def copy_triangle(
224246
src_img: Image.Image,
225-
src_tri: Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]],
247+
src_tri: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
226248
dst_img: Image.Image,
227-
dst_tri: Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]],
249+
dst_tri: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
228250
) -> None:
229251
src_off = (
230252
(src_tri[1][0] - src_tri[0][0], src_tri[1][1] - src_tri[0][1]),
@@ -286,9 +308,9 @@ def copy_triangle(
286308
A = np.linalg.solve(M, y)
287309
else:
288310
# np.lingal.solve - obviously way faster, but numpy will only come with 2.0
289-
A = linalg_solve(M, y)
311+
A = linalg_solve(M, y) # type: ignore
290312

291-
transformed = src_img.transform(dst_img.size, Image.AFFINE, A)
313+
transformed = src_img.transform(dst_img.size, Transform.AFFINE, A)
292314

293315
mask = Image.new("1", dst_img.size)
294316
maskdraw = ImageDraw.Draw(mask)
@@ -297,18 +319,20 @@ def copy_triangle(
297319
dst_img.paste(transformed, mask=mask)
298320

299321

300-
def linalg_solve(M: List[List[float]], y: List[float]) -> List[float]:
322+
def linalg_solve(
323+
M: List[List[Union[float, int]]], y: List[Union[float, int]]
324+
) -> List[float]:
301325
# M^-1 * y
302326
M_i = get_matrix_inverse(M)
303327
return [sum(M_i[i][j] * y[j] for j in range(len(y))) for i in range(len(M_i))]
304328

305329

306-
def transpose_matrix(m: List[List[float]]) -> List[List[float]]:
330+
def transpose_matrix(m: List[List[float]]) -> Iterable[List[float]]:
307331
# https://stackoverflow.com/a/39881366
308332
return map(list, zip(*m))
309333

310334

311-
def get_matrix_minor(m: List[List[float]], i: int, j: int) -> List[float]:
335+
def get_matrix_minor(m: List[List[float]], i: int, j: int) -> List[List[float]]:
312336
# https://stackoverflow.com/a/39881366
313337
return [row[:j] + row[j + 1 :] for row in (m[:i] + m[i + 1 :])]
314338

0 commit comments

Comments
 (0)