Skip to content

Commit 6a002d0

Browse files
committed
Add test for dj.U, change make_SQL DISTINCT usage.
1 parent 1928742 commit 6a002d0

File tree

3 files changed

+47
-4
lines changed

3 files changed

+47
-4
lines changed

datajoint/expression.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,13 @@ def where_clause(self):
101101
return '' if not self.restriction else ' WHERE(%s)' % ')AND('.join(
102102
str(s) for s in self.restriction)
103103

104-
def make_sql(self, fields=None):
104+
def make_sql(self, fields=None, distinct=True):
105105
"""
106106
Make the SQL SELECT statement.
107107
:param fields: used to explicitly set the select attributes
108108
"""
109-
return 'SELECT {fields} FROM {from_}{where}'.format(
109+
return 'SELECT {distinct}{fields} FROM {from_}{where}'.format(
110+
distinct="DISTINCT " if distinct else "",
110111
fields=self.heading.as_sql(fields or self.heading.names),
111112
from_=self.from_clause(), where=self.where_clause())
112113

@@ -508,9 +509,11 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
508509
"""
509510
if offset and limit is None:
510511
raise DataJointError('limit is required when offset is set')
511-
sql = self.make_sql()
512512
if order_by is not None:
513+
sql = self.make_sql(distinct=False)
513514
sql += ' ORDER BY ' + ', '.join(order_by)
515+
else:
516+
sql = self.make_sql()
514517
if limit is not None:
515518
sql += ' LIMIT %d' % limit + (' OFFSET %d' % offset if offset else "")
516519
logger.debug(sql)

tests/schema.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,13 @@ class SessionDateA(dj.Lookup):
439439
('mouse1', '2020-12-03'),
440440
('mouse1', '2020-12-04')
441441
]
442+
443+
444+
@schema
445+
class Stimulus(dj.Lookup):
446+
definition = """
447+
id: int
448+
---
449+
contrast: int
450+
brightness: int
451+
"""

tests/test_fetch.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas
77
import warnings
88
from . import schema
9-
from .schema import Parent
9+
from .schema import Parent, Stimulus
1010
import datajoint as dj
1111
import os
1212

@@ -296,3 +296,33 @@ def test_fetch_group_by(self):
296296
fetchedData = Parent().fetch('KEY', order_by='name')
297297
print(fetchedData)
298298
assert fetchedData == expectedData
299+
300+
def test_dj_U_DISTINCT(self):
301+
# Test developed to see if removing DISTINCT from the select statement
302+
# generation breakes the dj.U universal set imlementation
303+
304+
# Contents to be inserted
305+
contents = [
306+
(1,2,3),
307+
(2,2,3),
308+
(3,3,2),
309+
(4,5,5)
310+
]
311+
Stimulus.insert(contents)
312+
313+
# Query the whole table
314+
testQuery = Stimulus()
315+
316+
# Use dj.U to create a list of unique contrast and brightness combinations
317+
result = dj.U('contrast', 'brightness') & testQuery
318+
expectedResult = [{'contrast': 2, 'brightness': 3},
319+
{'contrast': 3, 'brightness': 2},
320+
{'contrast': 5, 'brightness': 5}]
321+
322+
fechedResult = result.fetch(as_dict=True)
323+
324+
# Cleanup table
325+
Stimulus.delete()
326+
print(result.make_sql())
327+
# Test to see if the repeated row was removed in the results
328+
assert fechedResult == expectedResult

0 commit comments

Comments
 (0)