Skip to content

Commit c1cebac

Browse files
committed
fix(TypeTreeHelper C) - TypeTreeNode correctly handle optional kwargs & collapse float reading
1 parent 0a381c9 commit c1cebac

File tree

1 file changed

+78
-82
lines changed

1 file changed

+78
-82
lines changed

UnityPyBoost/TypeTreeHelper.cpp

Lines changed: 78 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -153,22 +153,25 @@ inline PyObject *read_s8_array(ReaderT *reader, int32_t count)
153153
}
154154

155155
template <typename T, bool swap>
156-
inline PyObject *read_int(ReaderT *reader)
156+
inline PyObject *read_num(ReaderT *reader)
157157
{
158-
static_assert(std::is_integral<T>::value, "Unsupported type for read_int");
158+
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value, "Unsupported type for read_num");
159159

160160
if (reader->ptr + sizeof(T) > reader->end)
161161
{
162-
PyErr_SetString(PyExc_ValueError, "read_int out of bounds");
163-
return NULL;
162+
return PyErr_Format(PyExc_ValueError, "read_%s out of bounds", typeid(T).name());
164163
}
165164
T value = *(T *)reader->ptr;
166165
if constexpr (swap)
167166
{
168167
swap_any_inplace(&value);
169168
}
170169
reader->ptr += sizeof(T);
171-
if constexpr (std::is_signed<T>::value)
170+
if constexpr (std::is_floating_point<T>::value)
171+
{
172+
return PyFloat_FromDouble(value);
173+
}
174+
else if constexpr (std::is_signed<T>::value)
172175
{
173176
if constexpr (std::is_same<T, int64_t>::value)
174177
{
@@ -179,7 +182,7 @@ inline PyObject *read_int(ReaderT *reader)
179182
return PyLong_FromLong((int32_t)value);
180183
}
181184
}
182-
else
185+
else if constexpr (std::is_unsigned<T>::value)
183186
{
184187
if constexpr (std::is_same<T, uint64_t>::value)
185188
{
@@ -190,17 +193,20 @@ inline PyObject *read_int(ReaderT *reader)
190193
return PyLong_FromUnsignedLong((uint32_t)value);
191194
}
192195
}
196+
else
197+
{
198+
return PyErr_Format(PyExc_TypeError, "Unsupported type for read_num: %s", typeid(T).name());
199+
}
193200
}
194201

195202
template <typename T, bool swap>
196-
inline PyObject *read_int_array(ReaderT *reader, int32_t count)
203+
inline PyObject *read_num_array(ReaderT *reader, int32_t count)
197204
{
198-
static_assert(std::is_integral<T>::value, "Unsupported type for read_int_array");
205+
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value, "Unsupported type for read_num_array");
199206

200207
if (reader->ptr + sizeof(T) * count > reader->end)
201208
{
202-
PyErr_SetString(PyExc_ValueError, "read_int_array out of bounds");
203-
return NULL;
209+
return PyErr_Format(PyExc_ValueError, "read_%s_array out of bounds", typeid(T).name());
204210
}
205211
PyObject *list = PyList_New(count);
206212
T *ptr = (T *)reader->ptr;
@@ -211,72 +217,39 @@ inline PyObject *read_int_array(ReaderT *reader, int32_t count)
211217
{
212218
swap_any_inplace(&value);
213219
}
214-
if constexpr (std::is_signed<T>::value)
220+
PyObject *item;
221+
if constexpr (std::is_floating_point<T>::value)
222+
{
223+
item = PyFloat_FromDouble(value);
224+
}
225+
else if constexpr (std::is_signed<T>::value)
215226
{
216227
if constexpr (std::is_same<T, int64_t>::value)
217228
{
218-
PyList_SET_ITEM(list, i, PyLong_FromLongLong(value));
229+
item = PyLong_FromLongLong(value);
219230
}
220231
else
221232
{
222-
PyList_SET_ITEM(list, i, PyLong_FromLong((int32_t)value));
233+
item = PyLong_FromLong((int32_t)value);
223234
}
224235
}
225-
else
236+
else if constexpr (std::is_unsigned<T>::value)
226237
{
227238
if constexpr (std::is_same<T, uint64_t>::value)
228239
{
229-
PyList_SET_ITEM(list, i, PyLong_FromUnsignedLongLong(value));
240+
item = PyLong_FromUnsignedLongLong(value);
230241
}
231242
else
232243
{
233-
PyList_SET_ITEM(list, i, PyLong_FromUnsignedLong((uint32_t)value));
244+
item = PyLong_FromUnsignedLong((uint32_t)value);
234245
}
235246
}
236-
}
237-
reader->ptr = (uint8_t *)ptr;
238-
return list;
239-
}
240-
241-
template <typename T, bool swap>
242-
inline PyObject *read_float(ReaderT *reader)
243-
{
244-
static_assert(std::is_floating_point<T>::value, "Unsupported type for read_float");
245-
246-
if (reader->ptr + sizeof(T) > reader->end)
247-
{
248-
PyErr_SetString(PyExc_ValueError, "read_float out of bounds");
249-
return NULL;
250-
}
251-
T value = *(T *)reader->ptr;
252-
if constexpr (swap)
253-
{
254-
swap_any_inplace(&value);
255-
}
256-
reader->ptr += sizeof(T);
257-
return PyFloat_FromDouble(value);
258-
}
259-
260-
template <typename T, bool swap>
261-
inline PyObject *read_float_array(ReaderT *reader, int32_t count)
262-
{
263-
static_assert(std::is_floating_point<T>::value, "Unsupported type for read_float_array");
264-
265-
if (reader->ptr + sizeof(T) * count > reader->end)
266-
{
267-
PyErr_SetString(PyExc_ValueError, "read_float_array out of bounds");
268-
return NULL;
269-
}
270-
T *ptr = (T *)reader->ptr;
271-
PyObject *list = PyList_New(count);
272-
for (auto i = 0; i < count; i++)
273-
{
274-
T value = *ptr++;
275-
if constexpr (swap)
247+
else
276248
{
277-
swap_any_inplace(&value);
249+
Py_DECREF(list);
250+
return PyErr_Format(PyExc_TypeError, "Unsupported type for read_num_array: %s", typeid(T).name());
278251
}
279-
PyList_SET_ITEM(list, i, PyFloat_FromDouble(value));
252+
PyList_SET_ITEM(list, i, item);
280253
}
281254
reader->ptr = (uint8_t *)ptr;
282255
return list;
@@ -735,31 +708,31 @@ PyObject *read_typetree_value(ReaderT *reader, TypeTreeNodeObject *node, TypeTre
735708
value = read_u8(reader);
736709
break;
737710
case NodeDataType::u16:
738-
value = read_int<uint16_t, swap>(reader);
711+
value = read_num<uint16_t, swap>(reader);
739712
break;
740713
case NodeDataType::u32:
741-
value = read_int<uint32_t, swap>(reader);
714+
value = read_num<uint32_t, swap>(reader);
742715
break;
743716
case NodeDataType::u64:
744-
value = read_int<uint64_t, swap>(reader);
717+
value = read_num<uint64_t, swap>(reader);
745718
break;
746719
case NodeDataType::s8:
747720
value = read_s8(reader);
748721
break;
749722
case NodeDataType::s16:
750-
value = read_int<int16_t, swap>(reader);
723+
value = read_num<int16_t, swap>(reader);
751724
break;
752725
case NodeDataType::s32:
753-
value = read_int<int32_t, swap>(reader);
726+
value = read_num<int32_t, swap>(reader);
754727
break;
755728
case NodeDataType::s64:
756-
value = read_int<int64_t, swap>(reader);
729+
value = read_num<int64_t, swap>(reader);
757730
break;
758731
case NodeDataType::f32:
759-
value = read_float<float, swap>(reader);
732+
value = read_num<float, swap>(reader);
760733
break;
761734
case NodeDataType::f64:
762-
value = read_float<double, swap>(reader);
735+
value = read_num<double, swap>(reader);
763736
break;
764737
case NodeDataType::boolean:
765738
value = read_bool(reader);
@@ -915,31 +888,31 @@ PyObject *read_typetree_value_array(ReaderT *reader, TypeTreeNodeObject *node, T
915888
value = read_u8_array(reader, count);
916889
break;
917890
case NodeDataType::u16:
918-
value = read_int_array<uint16_t, swap>(reader, count);
891+
value = read_num_array<uint16_t, swap>(reader, count);
919892
break;
920893
case NodeDataType::u32:
921-
value = read_int_array<uint32_t, swap>(reader, count);
894+
value = read_num_array<uint32_t, swap>(reader, count);
922895
break;
923896
case NodeDataType::u64:
924-
value = read_int_array<uint64_t, swap>(reader, count);
897+
value = read_num_array<uint64_t, swap>(reader, count);
925898
break;
926899
case NodeDataType::s8:
927900
value = read_s8_array(reader, count);
928901
break;
929902
case NodeDataType::s16:
930-
value = read_int_array<int16_t, swap>(reader, count);
903+
value = read_num_array<int16_t, swap>(reader, count);
931904
break;
932905
case NodeDataType::s32:
933-
value = read_int_array<int32_t, swap>(reader, count);
906+
value = read_num_array<int32_t, swap>(reader, count);
934907
break;
935908
case NodeDataType::s64:
936-
value = read_int_array<int64_t, swap>(reader, count);
909+
value = read_num_array<int64_t, swap>(reader, count);
937910
break;
938911
case NodeDataType::f32:
939-
value = read_float_array<float, swap>(reader, count);
912+
value = read_num_array<float, swap>(reader, count);
940913
break;
941914
case NodeDataType::f64:
942-
value = read_float_array<double, swap>(reader, count);
915+
value = read_num_array<double, swap>(reader, count);
943916
break;
944917
case NodeDataType::boolean:
945918
value = read_bool_array(reader, count);
@@ -967,6 +940,16 @@ static inline void set_none_if_null_n_incref(PyObject **field)
967940
Py_INCREF(*field);
968941
}
969942

943+
static bool is_null_none_or_type(PyObject *obj, PyTypeObject *type, const char *type_name, const char *field_name)
944+
{
945+
if (obj == nullptr || obj == Py_None || PyObject_TypeCheck(obj, type))
946+
{
947+
return true;
948+
}
949+
PyErr_Format(PyExc_TypeError, "Expected %s or None for %s, got %R", type_name, field_name, obj);
950+
return false;
951+
}
952+
970953
PyObject *read_typetree(PyObject *self, PyObject *args, PyObject *kwargs)
971954
{
972955
const char *kwlist[] = {"data", "node", "endian", "as_dict", "assetsfile", "classes", NULL};
@@ -1166,20 +1149,20 @@ static int TypeTreeNode_init(TypeTreeNodeObject *self, PyObject *args, PyObject
11661149
self->m_RefTypeHash = nullptr;
11671150
self->_clean_name = nullptr;
11681151

1169-
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!O!O!|O!O!O!O!O!O!", (char **)kwlist,
1152+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!O!O!O!O!|OOOOOO", (char **)kwlist,
11701153
// required fields
11711154
&PyLong_Type, &self->m_Level,
11721155
&PyUnicode_Type, &self->m_Type,
11731156
&PyUnicode_Type, &self->m_Name,
11741157
&PyLong_Type, &self->m_ByteSize,
11751158
&PyLong_Type, &self->m_Version,
11761159
// optional fields
1177-
&PyList_Type, &self->m_Children,
1178-
&PyLong_Type, &self->m_TypeFlags,
1179-
&PyLong_Type, &self->m_VariableCount,
1180-
&PyLong_Type, &self->m_Index,
1181-
&PyLong_Type, &self->m_MetaFlag,
1182-
&PyLong_Type, &self->m_RefTypeHash))
1160+
&self->m_Children,
1161+
&self->m_TypeFlags,
1162+
&self->m_VariableCount,
1163+
&self->m_Index,
1164+
&self->m_MetaFlag,
1165+
&self->m_RefTypeHash))
11831166
{
11841167
return -1;
11851168
}
@@ -1192,14 +1175,27 @@ static int TypeTreeNode_init(TypeTreeNodeObject *self, PyObject *args, PyObject
11921175
Py_INCREF(self->m_Version);
11931176

11941177
// optional values - can still be nullptr
1195-
if (self->m_Children == nullptr)
1178+
if (self->m_Children == nullptr || self->m_Children == Py_None)
11961179
{
1180+
if (self->m_Children == Py_None)
1181+
{
1182+
// in older Python's Py_None is not immortal
1183+
Py_DECREF(self->m_Children);
1184+
}
11971185
self->m_Children = PyList_New(0);
11981186
}
11991187
else
12001188
{
12011189
Py_INCREF(self->m_Children);
12021190
}
1191+
1192+
if (!is_null_none_or_type(self->m_TypeFlags, &PyLong_Type, "int", "m_TypeFlags") ||
1193+
!is_null_none_or_type(self->m_VariableCount, &PyLong_Type, "int", "m_VariableCount") ||
1194+
!is_null_none_or_type(self->m_Index, &PyLong_Type, "int", "m_Index") ||
1195+
!is_null_none_or_type(self->m_MetaFlag, &PyLong_Type, "int", "m_MetaFlag") ||
1196+
!is_null_none_or_type(self->m_RefTypeHash, &PyUnicode_Type, "str", "m_RefTypeHash"))
1197+
return -1;
1198+
12031199
set_none_if_null_n_incref(&self->m_TypeFlags);
12041200
set_none_if_null_n_incref(&self->m_VariableCount);
12051201
set_none_if_null_n_incref(&self->m_Index);

0 commit comments

Comments
 (0)