Skip to content

Commit c9a8001

Browse files
committed
keywork pk test cases
1 parent 97d6e55 commit c9a8001

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

datajoint/declare.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,13 @@ def format_attribute(attr):
443443
return f"`{attr}`"
444444
return f"({attr})"
445445

446-
match = re.match(
447-
r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I
448-
).groupdict()
446+
try:
447+
match = re.match(
448+
r"(?P<unique>unique\s+)?index\s*\(\s*(?P<args>.*)\)", line, re.I
449+
).groupdict()
450+
except AttributeError:
451+
raise DataJointError(f'Table definition syntax error in line "{line}"')
452+
449453
attr_list = re.findall(r"(?:[^,(]|\([^)]*\))+", match["args"])
450454
index_sql.append(
451455
"{unique}index ({attrs})".format(

tests_old/schema_simple.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@
1414
schema = dj.Schema(PREFIX + "_relational", locals(), connection=dj.conn(**CONN_INFO))
1515

1616

17+
@schema
18+
class SelectPK(dj.Lookup):
19+
definition = """ # tests sql keyword escaping
20+
id: int
21+
select : int
22+
"""
23+
contents = list(dict(id=i, select=i * j) for i in range(3) for j in range(4, 0, -1))
24+
25+
26+
@schema
27+
class KeyPK(dj.Lookup):
28+
definition = """ # tests sql keyword escaping
29+
id : int
30+
key : int
31+
"""
32+
contents = list(dict(id=i, key=i + j) for i in range(3) for j in range(4, 0, -1))
33+
34+
1735
@schema
1836
class IJ(dj.Lookup):
1937
definition = """ # tests restrictions

tests_old/test_relational_operand.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
L,
2626
DataA,
2727
DataB,
28+
SelectPK,
29+
KeyPK,
2830
TTestUpdate,
2931
IJ,
3032
JI,
@@ -527,6 +529,33 @@ def test_restrictions_by_top():
527529
{"id_l": 20, "cond_in_l": 1},
528530
]
529531

532+
@staticmethod
533+
def test_top_restriction_with_keywords():
534+
select = SelectPK() & dj.Top(limit=9, order_by=["select desc"])
535+
key = KeyPK() & dj.Top(limit=9, order_by="key desc")
536+
assert select.fetch(as_dict=True) == [
537+
{"id": 2, "select": 8},
538+
{"id": 2, "select": 6},
539+
{"id": 1, "select": 4},
540+
{"id": 2, "select": 4},
541+
{"id": 1, "select": 3},
542+
{"id": 1, "select": 2},
543+
{"id": 2, "select": 2},
544+
{"id": 1, "select": 1},
545+
{"id": 0, "select": 0},
546+
]
547+
assert key.fetch(as_dict=True) == [
548+
{"id": 2, "key": 6},
549+
{"id": 2, "key": 5},
550+
{"id": 1, "key": 5},
551+
{"id": 0, "key": 4},
552+
{"id": 1, "key": 4},
553+
{"id": 2, "key": 4},
554+
{"id": 0, "key": 3},
555+
{"id": 1, "key": 3},
556+
{"id": 2, "key": 3},
557+
]
558+
530559
@staticmethod
531560
def test_top_errors():
532561
with assert_raises(DataJointError) as err1:

0 commit comments

Comments
 (0)