Commit cb3600b

Fix #1103, fix test_blob (for some numpy versions)
Parent: 10ca20b

2 files changed: 30 additions & 16 deletions

datajoint/table.py

Lines changed: 7 additions & 5 deletions
@@ -196,7 +196,6 @@ def parents(self, primary=None, as_objects=False, foreign_key_info=False):

     def children(self, primary=None, as_objects=False, foreign_key_info=False):
         """
-
         :param primary: if None, then all children are returned. If True, then only foreign keys composed of
             primary key attributes are considered. If False, return foreign keys including at least one
             secondary attribute.
@@ -230,7 +229,6 @@ def descendants(self, as_objects=False):

     def ancestors(self, as_objects=False):
         """
-
         :param as_objects: False - a list of table names; True - a list of table objects.
         :return: list of tables ancestors in topological order.
         """
@@ -246,6 +244,7 @@ def parts(self, as_objects=False):

         :param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects.
         """
+        self.connection.dependencies.load(force=False)
         nodes = [
             node
             for node in self.connection.dependencies.nodes
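The added self.connection.dependencies.load(force=False) call is the substantive fix for #1103: parts() scans self.connection.dependencies.nodes, and that graph is only populated once it has been loaded, which parents() and children() already do but parts() did not. A minimal usage sketch follows; the schema name and the Session/Session.Trial master/part pair are hypothetical, and a configured DataJoint connection is assumed.

import datajoint as dj

schema = dj.schema("tutorial_sessions")  # hypothetical schema name


@schema
class Session(dj.Manual):
    definition = """
    session_id : int
    """

    class Trial(dj.Part):
        definition = """
        -> master
        trial_id : int
        """


# With this commit, parts() loads the dependency graph itself before
# scanning connection.dependencies.nodes for part tables of Session.
print(Session().parts())                 # per the docstring: dict of foreign-key info
print(Session().parts(as_objects=True))  # list of part table objects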
@@ -427,7 +426,8 @@ def insert(
             self.connection.query(query)
             return

-        field_list = []  # collects the field list from first row (passed by reference)
+        # collects the field list from first row (passed by reference)
+        field_list = []
         rows = list(
             self.__make_row_to_insert(row, field_list, ignore_extra_fields)
             for row in rows
@@ -520,7 +520,8 @@ def cascade(table):
                     delete_count = table.delete_quick(get_count=True)
                 except IntegrityError as error:
                     match = foreign_key_error_regexp.match(error.args[0]).groupdict()
-                    if "`.`" not in match["child"]:  # if schema name missing, use table
+                    # if schema name missing, use table
+                    if "`.`" not in match["child"]:
                         match["child"] = "{}.{}".format(
                             table.full_table_name.split(".")[0], match["child"]
                         )
@@ -962,7 +963,8 @@ def lookup_class_name(name, context, depth=3):
     while nodes:
         node = nodes.pop(0)
         for member_name, member in node["context"].items():
-            if not member_name.startswith("_"):  # skip IPython's implicit variables
+            # skip IPython's implicit variables
+            if not member_name.startswith("_"):
                 if inspect.isclass(member) and issubclass(member, Table):
                     if member.full_table_name == name:  # found it!
                         return ".".join([node["context_name"], member_name]).lstrip(".")

tests/test_blob.py

Lines changed: 23 additions & 11 deletions
@@ -124,16 +124,20 @@ def test_pack():
     assert x == unpack(pack(x)), "Set did not pack/unpack correctly"

     x = tuple(range(10))
-    assert x == unpack(pack(range(10))), "Iterator did not pack/unpack correctly"
+    assert x == unpack(
+        pack(range(10))), "Iterator did not pack/unpack correctly"

     x = Decimal("1.24")
-    assert x == approx(unpack(pack(x))), "Decimal object did not pack/unpack correctly"
+    assert x == approx(
+        unpack(pack(x))), "Decimal object did not pack/unpack correctly"

     x = datetime.now()
-    assert x == unpack(pack(x)), "Datetime object did not pack/unpack correctly"
+    assert x == unpack(
+        pack(x)), "Datetime object did not pack/unpack correctly"

     x = np.bool_(True)
-    assert x == unpack(pack(x)), "Numpy bool object did not pack/unpack correctly"
+    assert x == unpack(
+        pack(x)), "Numpy bool object did not pack/unpack correctly"

     x = "test"
     assert x == unpack(pack(x)), "String object did not pack/unpack correctly"
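Every change in this hunk is a pure line-wrapping reformat; the assertion pattern itself is unchanged: a value must survive a round trip through the blob serializer, i.e. unpack(pack(x)) == x. A standalone version of that check, assuming pack and unpack are the helpers from datajoint.blob that these tests exercise:

import numpy as np
from datajoint.blob import pack, unpack  # assumed import path for the helpers

x = np.arange(5, dtype=np.float64)
blob = pack(x)                           # serialize to DataJoint's blob format
assert np.array_equal(x, unpack(blob))   # the round trip must be lossless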
@@ -154,13 +158,15 @@ def test_recarrays():
     x = x.view(np.recarray)
     assert_array_equal(x, unpack(pack(x)))

-    x = np.array([(3, 4)], dtype=[("tmp0", float), ("tmp1", "O")]).view(np.recarray)
+    x = np.array([(3, 4)], dtype=[("tmp0", float),
+                                  ("tmp1", "O")]).view(np.recarray)
     assert_array_equal(x, unpack(pack(x)))


 def test_object_arrays():
     x = np.array(((1, 2, 3), True), dtype="object")
-    assert_array_equal(x, unpack(pack(x)), "Object array did not serialize correctly")
+    assert_array_equal(x, unpack(pack(x)),
+                       "Object array did not serialize correctly")


 def test_complex():
@@ -170,10 +176,12 @@ def test_complex():
     z = np.random.randn(10) + 1j * np.random.randn(10)
     assert_array_equal(z, unpack(pack(z)), "Arrays do not match!")

-    x = np.float32(np.random.randn(3, 4, 5)) + 1j * np.float32(np.random.randn(3, 4, 5))
+    x = np.float32(np.random.randn(3, 4, 5)) + 1j * \
+        np.float32(np.random.randn(3, 4, 5))
     assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")

-    x = np.int16(np.random.randn(1, 2, 3)) + 1j * np.int16(np.random.randn(1, 2, 3))
+    x = np.int16(np.random.randn(1, 2, 3)) + 1j * \
+        np.int16(np.random.randn(1, 2, 3))
     assert_array_equal(x, unpack(pack(x)), "Arrays do not match!")

@@ -185,7 +193,8 @@ def test_insert_longblob(schema_any):

     query_mym_blob = {"id": 1, "data": np.array([1, 2, 3])}
     Longblob.insert1(query_mym_blob)
-    assert (Longblob & "id=1").fetch1()["data"].all() == query_mym_blob["data"].all()
+    assert_array_equal(
+        (Longblob & "id=1").fetch1()["data"], query_mym_blob["data"])
     (Longblob & "id=1").delete()

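The replaced assertion here is a real fix rather than a reformat: comparing fetched["data"].all() == query_mym_blob["data"].all() only checks that both arrays are entirely truthy, so two completely different arrays can pass, whereas numpy.testing.assert_array_equal compares element-wise. A quick standalone illustration (plain numpy, no DataJoint needed):

import numpy as np
from numpy.testing import assert_array_equal

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

# Old-style check: both .all() calls return True, so the comparison
# "passes" even though the arrays share no elements.
assert a.all() == b.all()

# Element-wise check correctly rejects the mismatched pair.
try:
    assert_array_equal(a, b)
except AssertionError:
    print("arrays differ, as expected")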
@@ -214,11 +223,14 @@ def test_insert_longblob_32bit(schema_any, enable_feature_32bit_dims):
                     )
                 ]
             ],
-            dtype=[("hits", "O"), ("sides", "O"), ("tasks", "O"), ("stage", "O")],
+            dtype=[("hits", "O"), ("sides", "O"),
+                   ("tasks", "O"), ("stage", "O")],
         ),
     }
     assert fetched["id"] == expected["id"]
-    assert np.array_equal(fetched["data"], expected["data"])
+    for name in expected['data'][0][0].dtype.names:
+        assert_array_equal(
+            expected['data'][0][0][name], fetched['data'][0][0][name])
     (Longblob & "id=1").delete()

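This hunk is the one the commit message's "(for some numpy versions)" points at: calling np.array_equal on the whole nested record is unreliable when the structured array holds object-dtype fields that themselves contain arrays, so the replacement compares each named field separately with assert_array_equal. A small standalone sketch of that pattern, with made-up field names and contents:

import numpy as np
from numpy.testing import assert_array_equal

# A structured array whose object-dtype fields hold nested arrays,
# loosely mimicking the record the test round-trips.
expected = np.array(
    [(np.array([1.0, 2.0]), np.array([[0], [1]]))],
    dtype=[("hits", "O"), ("sides", "O")],
)
fetched = expected.copy()

# Compare field by field instead of comparing the whole structured
# array at once, which behaves differently across numpy versions.
for name in expected.dtype.names:
    assert_array_equal(expected[0][name], fetched[0][name])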