Skip to content

Commit b032ee7

Browse files
committed
support indexing of Argument
1 parent 1e9d64d commit b032ee7

File tree

3 files changed

+157
-9
lines changed

3 files changed

+157
-9
lines changed

README.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,3 @@ Set `make_anchor=True` when calling `gendoc` function and use standard ref synta
2020
The id is the same as the argument path. Variant types would be in square brackets.
2121

2222
Please refer to test files for detailed usage.
23-
24-
25-
## TODO
26-
27-
- [ ] possibly support of indexing by keys

dargs/dargs.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ def __eq__(self, other: "Argument") -> bool:
7474
def __repr__(self) -> str:
7575
return f"<Argument {self.name}: {' | '.join(dd.__name__ for dd in self.dtype)}>"
7676

77+
def __getitem__(self, key: str) -> "Argument":
78+
key = key.lstrip("/")
79+
if key in ("", "."):
80+
return self
81+
if key.startswith("["):
82+
vkey, rkey = key[1:].split("]", 1)
83+
if vkey.count("=") == 1:
84+
fkey, ckey = vkey.split("=")
85+
else:
86+
[fkey] = self.sub_variants.keys()
87+
ckey = vkey
88+
return self.sub_variants[fkey][ckey][rkey]
89+
p1, p2 = key.find("/"), key.find("[")
90+
if max(p1, p2) < 0: # not found
91+
return self.sub_fields[key]
92+
else: # at least one found
93+
p = p1 if p2 < 0 or 0 < p1 < p2 else p2
94+
skey, rkey = key[:p], key[p:]
95+
return self[skey][rkey]
96+
97+
@property
98+
def I(self):
99+
# return a dummy argument that only has self as a sub field
100+
# can be used in indexing
101+
return Argument("_", dict, [self])
102+
77103
def _reorg_dtype(self):
78104
if isinstance(self.dtype, type) or self.dtype is None:
79105
self.dtype = [self.dtype]
@@ -111,7 +137,7 @@ def add_subfield(self, name: Union[str, "Argument"],
111137
newarg = Argument(name, *args, **kwargs)
112138
self.extend_subfields([newarg])
113139
return newarg
114-
140+
115141
def extend_subvariants(self, sub_variants: Optional[Iterable["Variant"]]):
116142
if sub_variants is None:
117143
return
@@ -223,7 +249,7 @@ def _check_strict(self, value: dict):
223249
if name not in allowed_keys:
224250
raise KeyError(f"undefined key `{name}` is "
225251
"not allowed in strict mode")
226-
252+
227253
# above are type checking part
228254
# below are normalizing part
229255

@@ -371,6 +397,9 @@ def __eq__(self, other: "Variant") -> bool:
371397
def __repr__(self) -> str:
372398
return f"<Variant {self.flag_name} in {{ {', '.join(self.choice_dict.keys())} }}>"
373399

400+
def __getitem__(self, key: str) -> "Argument":
401+
return self.choice_dict[key]
402+
374403
def set_default(self, default_tag : Union[bool, str]):
375404
if not default_tag:
376405
self.optional = False
@@ -424,7 +453,7 @@ def get_choice(self, argdict: dict) -> "Argument":
424453
else:
425454
raise KeyError(f"key `{self.flag_name}` is required "
426455
"to choose variant but not found.")
427-
456+
428457
def flatten_sub(self, argdict: dict) -> Dict[str, "Argument"]:
429458
choice = self.get_choice(argdict)
430459
fields = {self.flag_name: self.dummy_argument(), # as a placeholder

tests/test_creation.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,34 @@ def test_sub_fields(self):
3232
ca.set_repeat(True)
3333
self.assertTrue(ca == ref)
3434

35+
def test_idx_fields(self):
36+
s1 = Argument("sub1", int)
37+
vt1 = Argument("type1", dict, [
38+
Argument("shared", str),
39+
Argument("vnt1_1", dict, [
40+
Argument("vnt1_1_1", int)
41+
])
42+
])
43+
vt2 = Argument("type2", dict, [
44+
Argument("shared", int),
45+
])
46+
v1 = Variant("vnt_flag", [vt1, vt2])
47+
ca = Argument("base", dict, [s1], [v1])
48+
self.assertTrue(ca[''] is ca)
49+
self.assertTrue(ca['.'] is ca)
50+
self.assertTrue(ca['sub1'] == ca["./sub1"] == s1)
51+
with self.assertRaises(KeyError):
52+
ca["sub2"]
53+
self.assertTrue(ca['[type1]'] is vt1)
54+
self.assertTrue(ca['[type1]///'] is vt1)
55+
self.assertTrue(ca['[type1]/vnt1_1/vnt1_1_1'] == Argument("vnt1_1_1", int))
56+
self.assertTrue(ca['[type2]//shared'] == Argument("shared", int))
57+
with self.assertRaises(KeyError):
58+
s1["sub1"]
59+
self.assertTrue(s1.I["sub1"] is s1)
60+
self.assertTrue(ca.I["base[type1]"] is vt1)
61+
self.assertTrue(ca.I['base[type2]//shared'] == Argument("shared", int))
62+
3563
def test_sub_variants(self):
3664
ref = Argument("base", dict, [
3765
Argument("sub1", int),
@@ -64,7 +92,7 @@ def test_sub_variants(self):
6492
vt2s0 = vt2.add_subfield("shared", int)
6593
vt2s1 = vt2.add_subfield("vnt2_1", int)
6694
self.assertTrue(ca == ref)
67-
95+
# make sure we can modify the reference
6896
ref1 = Argument("base", dict, [
6997
Argument("sub1", int),
7098
Argument("sub2", str)
@@ -88,6 +116,102 @@ def test_sub_variants(self):
88116
v1.set_default(False)
89117
self.assertTrue(ca == ref)
90118

119+
def test_idx_variants(self):
120+
vt1 = Argument("type1", dict, [
121+
Argument("shared", int),
122+
Argument("vnt1_1", int),
123+
Argument("vnt1_2", dict, [
124+
Argument("vnt1_1_1", int)
125+
])
126+
])
127+
vt2 = Argument("type2", dict, [
128+
Argument("shared", int),
129+
Argument("vnt2_1", int),
130+
])
131+
vnt = Variant("vnt_flag", [vt1, vt2])
132+
self.assertTrue(vnt["type1"] is vt1)
133+
self.assertTrue(vnt["type2"] is vt2)
134+
with self.assertRaises(KeyError):
135+
vnt["type3"]
136+
137+
def test_complicated(self):
138+
ref = Argument("base", dict, [
139+
Argument("sub1", int),
140+
Argument("sub2", str)
141+
], [
142+
Variant("vnt_flag", [
143+
Argument("type1", dict, [
144+
Argument("shared", int),
145+
Argument("vnt1_1", int),
146+
Argument("vnt1_2", dict, [
147+
Argument("vnt1_1_1", int)
148+
])
149+
]),
150+
Argument("type2", dict, [
151+
Argument("shared", int),
152+
Argument("vnt2_1", int),
153+
]),
154+
Argument("type3", dict, [
155+
Argument("vnt3_1", int)
156+
], [ # testing cascade variants here
157+
Variant("vnt3_flag1", [
158+
Argument("v3f1t1", dict, [
159+
Argument('v3f1t1_1', int),
160+
Argument('v3f1t1_2', int)
161+
]),
162+
Argument("v3f1t2", dict, [
163+
Argument('v3f1t2_1', int)
164+
])
165+
]),
166+
Variant("vnt3_flag2", [
167+
Argument("v3f2t1", dict, [
168+
Argument('v3f2t1_1', int),
169+
Argument('v3f2t1_2', int)
170+
]),
171+
Argument("v3f2t2", dict, [
172+
Argument('v3f2t2_1', int)
173+
])
174+
])
175+
])
176+
])
177+
])
178+
ca = Argument("base", dict)
179+
s1 = ca.add_subfield("sub1", int)
180+
s2 = ca.add_subfield("sub2", str)
181+
v1 = ca.add_subvariant("vnt_flag")
182+
vt1 = v1.add_choice("type1", dict)
183+
vt1s0 = vt1.add_subfield("shared", int)
184+
vt1s1 = vt1.add_subfield("vnt1_1", int)
185+
vt1s2 = vt1.add_subfield("vnt1_2", dict)
186+
vt1ss = vt1s2.add_subfield("vnt1_1_1", int)
187+
vt2 = v1.add_choice("type2")
188+
vt2s0 = vt2.add_subfield("shared", int)
189+
vt2s1 = vt2.add_subfield("vnt2_1", int)
190+
vt3 = v1.add_choice("type3")
191+
vt3s1 = vt3.add_subfield("vnt3_1", int)
192+
vt3f1 = vt3.add_subvariant('vnt3_flag1')
193+
vt3f1t1 = vt3f1.add_choice("v3f1t1")
194+
vt3f1t1s1 = vt3f1t1.add_subfield("v3f1t1_1", int)
195+
vt3f1t1s2 = vt3f1t1.add_subfield("v3f1t1_2", int)
196+
vt3f1t2 = vt3f1.add_choice("v3f1t2")
197+
vt3f1t2s1 = vt3f1t2.add_subfield("v3f1t2_1", int)
198+
vt3f2 = vt3.add_subvariant('vnt3_flag2')
199+
vt3f2t1 = vt3f2.add_choice("v3f2t1")
200+
vt3f2t1s1 = vt3f2t1.add_subfield("v3f2t1_1", int)
201+
vt3f2t1s2 = vt3f2t1.add_subfield("v3f2t1_2", int)
202+
vt3f2t2 = vt3f2.add_choice("v3f2t2")
203+
vt3f2t2s1 = vt3f2t2.add_subfield("v3f2t2_1", int)
204+
self.assertTrue(ca == ref)
205+
self.assertTrue(ca['[type3][vnt3_flag1=v3f1t1]'] is vt3f1t1)
206+
self.assertTrue(ca.I['base[type3][vnt3_flag1=v3f1t1]/v3f1t1_2'] is vt3f1t1s2)
207+
self.assertTrue(ca.I['base[type3][vnt3_flag1=v3f1t2]/v3f1t2_1'] is vt3f1t2s1)
208+
self.assertTrue(ca.I['base[type3][vnt3_flag2=v3f2t1]/v3f2t1_1'] is vt3f2t1s1)
209+
self.assertTrue(ca.I['base[type3][vnt3_flag2=v3f2t2]/v3f2t2_1'] is vt3f2t2s1)
210+
with self.assertRaises((KeyError, ValueError)):
211+
ca.I['base[type3][v3f2t2]']
212+
with self.assertRaises((KeyError, ValueError)):
213+
ca.I['base[type3][vnt3_flag3=v3f2t2]/v3f2t2_1']
214+
91215

92216
if __name__ == "__main__":
93217
unittest.main()

0 commit comments

Comments
 (0)