Skip to content

Commit ed4a905

Browse files
committed
fix bug add testcases
1 parent 110046b commit ed4a905

File tree

5 files changed

+174
-3
lines changed

5 files changed

+174
-3
lines changed

src/python/qubed/set_operations.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ def node_intersection(
7575
return QEnum_intersection(A, B)
7676

7777
if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
78-
return A, ValuesMetadata(WildcardGroup(), {}), B
78+
return (
79+
ValuesMetadata(QEnum([]), {}),
80+
ValuesMetadata(WildcardGroup(), {}),
81+
ValuesMetadata(QEnum([]), {}),
82+
)
7983

8084
# If A is a wildcard matcher then the intersection is everything
8185
# just_A is still *
@@ -92,7 +96,7 @@ def node_intersection(
9296
)
9397

9498

95-
def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube:
99+
def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube | None:
96100
assert A.key == B.key, (
97101
"The two Qube root nodes must have the same key to perform set operations,"
98102
f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
@@ -118,6 +122,18 @@ def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube
118122
output = list(_operation(key, A_nodes, B_nodes, operation_type, node_type))
119123
new_children.extend(output)
120124

125+
# print(f"operation {operation_type}: {A}, {B} {new_children = }")
126+
# print(f"{A.children = }")
127+
# print(f"{B.children = }")
128+
# print(f"{new_children = }")
129+
130+
# If there are now no children as a result of the operation, return nothing.
131+
if (A.children or B.children) and not new_children:
132+
if A.key == "root":
133+
return A.replace(children=())
134+
else:
135+
return None
136+
121137
# Whenever we modify children we should recompress them
122138
# But since `operation` is already recursive, we only need to compress this level not all levels
123139
# Hence we use the non-recursive _compress method
@@ -161,7 +177,14 @@ def _operation(
161177
values=intersection.values,
162178
metadata=intersection.metadata,
163179
)
164-
yield operation(new_node_a, new_node_b, operation_type, node_type)
180+
# print(f"{node_a = }")
181+
# print(f"{node_b = }")
182+
# print(f"{intersection.values =}")
183+
result = operation(
184+
new_node_a, new_node_b, operation_type, node_type
185+
)
186+
if result is not None:
187+
yield result
165188

166189
# Now we've removed all the intersections we can yield the just_A and just_B parts if needed
167190
if keep_just_A:

src/python/qubed/value_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def __len__(self):
119119
def __iter__(self):
120120
return ["*"]
121121

122+
def __bool__(self):
123+
return True
124+
122125
@classmethod
123126
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
124127
return [WildcardGroup()]

src/rust/lib.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// #![allow(unused_variables)]
44

55

6+
use std::collections::HashMap;
7+
68
use pyo3::prelude::*;
79
use pyo3::wrap_pyfunction;
810
use pyo3::types::{PyDict, PyInt, PyList, PyString};
@@ -18,6 +20,47 @@ fn rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
1820
Ok(())
1921
}
2022

23+
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
24+
struct NodeId(usize);
25+
26+
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
27+
struct StringId(usize);
28+
29+
struct Node {
30+
key: StringId,
31+
metadata: HashMap<StringId, Vec<String>>,
32+
parent: NodeId,
33+
values: Vec<String>,
34+
children: HashMap<StringId, Vec<NodeId>>,
35+
}
36+
37+
38+
39+
struct Qube {
40+
root: NodeId,
41+
nodes: Vec<Node>,
42+
strings: Vec<String>,
43+
}
44+
45+
use std::ops;
46+
47+
impl ops::Index<StringId> for Qube {
48+
type Output = str;
49+
50+
fn index(&self, index: StringId) -> &str {
51+
&self.strings[index.0]
52+
}
53+
54+
}
55+
56+
impl ops::Index<NodeId> for Qube {
57+
type Output = Node;
58+
59+
fn index(&self, index: NodeId) -> &Node {
60+
&self.nodes[index.0]
61+
}
62+
63+
}
2164

2265
// use rsfdb::listiterator::KeyValueLevel;
2366
// use rsfdb::request::Request;

tests/test_set_operations.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,95 @@
11
from qubed import Qube
22

33

4+
def set_operation_testcase(testcase):
5+
q1 = Qube.from_tree(testcase["q1"])
6+
q2 = Qube.from_tree(testcase["q2"])
7+
assert q1 | q2 == Qube.from_tree(testcase["union"])
8+
assert q1 & q2 == Qube.from_tree(testcase["intersection"])
9+
assert q1 - q2 == Qube.from_tree(testcase["q1 - q2"])
10+
11+
12+
# These are a bunch of testcases where q1 and q2 are specified and then their union/intersection/difference are checked
13+
# Generate them with code like:
14+
# q1 = Qube.from_tree("root, frequency=*, levtype=*, param=*, levelist=*, domain=a/b/c/d")
15+
# q2 = Qube.from_tree("root, frequency=*, levtype=*, param=*, domain=a/b/c/d")
16+
17+
# test = {
18+
# "q1": str(q1),
19+
# "q2": str(q2),
20+
# "union": str(q1 | q2),
21+
# "intersection": str(q1 & q2),
22+
# "q1 - q2": str(q1 - q2),
23+
# }
24+
# BUT MANUALLY CHECK THE OUTPUT BEFORE ADDING IT AS A TEST CASE!
25+
26+
27+
testcases = [
28+
# Simplest case, only leaves differ
29+
{
30+
"q1": "root, a=1, b=1, c=1",
31+
"q2": "root, a=1, b=1, c=2",
32+
"union": "root, a=1, b=1, c=1/2",
33+
"intersection": "root",
34+
"q1 - q2": "root",
35+
},
36+
# Some overlap but also each tree has unique items
37+
{
38+
"q1": "root, a=1, b=1, c=1/2/3",
39+
"q2": "root, a=1, b=1, c=2/3/4",
40+
"union": "root, a=1, b=1, c=1/2/3/4",
41+
"intersection": "root, a=1, b=1, c=2/3",
42+
"q1 - q2": "root",
43+
},
44+
# Overlap at two levels
45+
{
46+
"q1": "root, a=1, b=1/2, c=1/2/3",
47+
"q2": "root, a=1, b=2/3, c=2/3/4",
48+
"union": """
49+
root, a=1
50+
├── b=1, c=1/2/3
51+
├── b=2, c=1/2/3/4
52+
└── b=3, c=2/3/4
53+
""",
54+
"intersection": "root, a=1, b=2, c=2/3",
55+
"q1 - q2": "root",
56+
},
57+
# Check that we can merge even if the divergence point is higher
58+
{
59+
"q1": "root, a=1, b=1, c=1",
60+
"q2": "root, a=2, b=1, c=1",
61+
"union": "root, a=1/2, b=1, c=1",
62+
"intersection": "root",
63+
"q1 - q2": "root, a=1, b=1, c=1",
64+
},
65+
# Two equal qubes
66+
{
67+
"q1": "root, a=1, b=1, c=1",
68+
"q2": "root, a=1, b=1, c=1",
69+
"union": "root, a=1, b=1, c=1",
70+
"intersection": "root, a=1, b=1, c=1",
71+
"q1 - q2": "root",
72+
},
73+
# With wildcards
74+
{
75+
"q1": "root, frequency=*, levtype=*, param=*, levelist=*, domain=a/b/c/d",
76+
"q2": "root, frequency=*, levtype=*, param=*, domain=a/b/c/d",
77+
"union": """
78+
root, frequency=*, levtype=*, param=*
79+
├── domain=a/b/c/d
80+
└── levelist=*, domain=a/b/c/d
81+
""",
82+
"intersection": "root",
83+
"q1 - q2": "root",
84+
},
85+
]
86+
87+
88+
def test_cases():
89+
for case in testcases:
90+
set_operation_testcase(case)
91+
92+
493
def test_leaf_conservation():
594
q = Qube.from_dict(
695
{

tests/test_wildcard.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,16 @@ def test_intersection():
3434
},
3535
}
3636
)
37+
38+
39+
def test_wildcard_union():
40+
q1 = Qube.from_tree(
41+
"root, frequency=*, levtype=*, param=*, levelist=*, domain=a/b/c/d"
42+
)
43+
q2 = Qube.from_tree("root, frequency=*, levtype=*, param=*, domain=a/b/c/d")
44+
expected = Qube.from_tree("""
45+
root, frequency=*, levtype=*, param=*
46+
├── domain=a/b/c/d
47+
└── levelist=*, domain=a/b/c/d
48+
""")
49+
assert (q1 | q2) == expected

0 commit comments

Comments
 (0)