Skip to content

Commit ee97521

Browse files
Merge branch 'master' of github.com:datajoint/datajoint-python into test-bucket-rules
2 parents a87199f + 518b882 commit ee97521

File tree

4 files changed

+28
-7
lines changed

4 files changed

+28
-7
lines changed

datajoint/admin.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def kill(restriction=None, connection=None, order_by=None): # pragma: no cover
2626
view and kill database connections.
2727
:param restriction: restriction to be applied to processlist
2828
:param connection: a datajoint.Connection object. Default calls datajoint.conn()
29-
:param order_by: order by string clause for output ordering. defaults to 'id'.
29+
:param order_by: order by a single attribute or the list of attributes. defaults to 'id'.
3030
3131
Restrictions are specified as strings and can involve any of the attributes of
3232
information_schema.processlist: ID, USER, HOST, DB, COMMAND, TIME, STATE, INFO.
@@ -39,6 +39,9 @@ def kill(restriction=None, connection=None, order_by=None): # pragma: no cover
3939
if connection is None:
4040
connection = conn()
4141

42+
if order_by is not None and not isinstance(order_by, str):
43+
order_by = ','.join(order_by)
44+
4245
query = 'SELECT * FROM information_schema.processlist WHERE id <> CONNECTION_ID()' + (
4346
"" if restriction is None else ' AND (%s)' % restriction) + (
4447
' ORDER BY %s' % (order_by or 'id'))

datajoint/connection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def is_connected(self):
185185
return True
186186

187187
@staticmethod
188-
def __execute_query(cursor, query, args, cursor_class, suppress_warnings):
188+
def _execute_query(cursor, query, args, cursor_class, suppress_warnings):
189189
try:
190190
with warnings.catch_warnings():
191191
if suppress_warnings:
@@ -211,7 +211,7 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
211211
cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor
212212
cursor = self._conn.cursor(cursor=cursor_class)
213213
try:
214-
self.__execute_query(cursor, query, args, cursor_class, suppress_warnings)
214+
self._execute_query(cursor, query, args, cursor_class, suppress_warnings)
215215
except errors.LostConnectionError:
216216
if not reconnect:
217217
raise
@@ -222,7 +222,7 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
222222
raise errors.LostConnectionError("Connection was lost during a transaction.") from None
223223
logger.debug("Re-executing")
224224
cursor = self._conn.cursor(cursor=cursor_class)
225-
self.__execute_query(cursor, query, args, cursor_class, suppress_warnings)
225+
self._execute_query(cursor, query, args, cursor_class, suppress_warnings)
226226
return cursor
227227

228228
def get_user(self):

datajoint/table.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,11 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
193193
"""
194194

195195
if isinstance(rows, pandas.DataFrame):
196-
rows = rows.to_records()
196+
# drop 'extra' synthetic index for 1-field index case -
197+
# frames with more advanced indices should be prepared by user.
198+
rows = rows.reset_index(
199+
drop=len(rows.index.names) == 1 and not rows.index.names[0]
200+
).to_records(index=False)
197201

198202
# prohibit direct inserts into auto-populated tables
199203
if not allow_direct_insert and not getattr(self, '_allow_insert', True): # allow_insert is only used in AutoPopulate
@@ -535,7 +539,6 @@ def describe(self, context=None, printout=True):
535539
parent_name = list(self.connection.dependencies.in_edges(parent_name))[0][0]
536540
lst = [(attr, ref) for attr, ref in fk_props['attr_map'].items() if ref != attr]
537541
definition += '->{props} {class_name}.proj({proj_list})\n'.format(
538-
attr_list=', '.join(r[0] for r in lst),
539542
props=index_props,
540543
class_name=lookup_class_name(parent_name, context) or parent_name,
541544
proj_list=','.join('{}="{}"'.format(a, b) for a, b in lst))

tests/test_relation.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def test_insert_select(self):
103103
'real_id', 'date_of_birth', 'subject_notes', subject_id='subject_id+1000', species='"human"'))
104104
assert_equal(len(self.subject), 2*original_length)
105105

106-
def test_insert_pandas(self):
106+
def test_insert_pandas_roundtrip(self):
107+
''' ensure fetched frames can be inserted '''
107108
schema.TTest2.delete()
108109
n = len(schema.TTest())
109110
assert_true(n > 0)
@@ -113,6 +114,20 @@ def test_insert_pandas(self):
113114
schema.TTest2.insert(df)
114115
assert_equal(len(schema.TTest2()), n)
115116

117+
def test_insert_pandas_userframe(self):
118+
'''
119+
ensure simple user-created frames (1 field, non-custom index)
120+
can be inserted without extra index adjustment
121+
'''
122+
schema.TTest2.delete()
123+
n = len(schema.TTest())
124+
assert_true(n > 0)
125+
df = pandas.DataFrame(schema.TTest.fetch())
126+
assert_true(isinstance(df, pandas.DataFrame))
127+
assert_equal(len(df), n)
128+
schema.TTest2.insert(df)
129+
assert_equal(len(schema.TTest2()), n)
130+
116131
@raises(dj.DataJointError)
117132
def test_insert_select_ignore_extra_fields0(self):
118133
""" need ignore extra fields for insert select """

0 commit comments

Comments
 (0)