Skip to content

Commit a91b86d

Browse files
authored
Merge pull request #4775 from opsmill/lgu-fix-search-ipv6
Search anywhere supports IPV6 extended format
2 parents 93c4695 + 56378c5 commit a91b86d

File tree

4 files changed

+265
-2
lines changed

4 files changed

+265
-2
lines changed

backend/infrahub/core/attribute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def validate_format(cls, value: Any, name: str, schema: AttributeSchema) -> None
872872
raise ValidationError({name: f"{value} is not a valid {schema.kind}"}) from exc
873873

874874
def serialize_value(self) -> str:
875-
"""Serialize the value before storing it in the database."""
875+
"""Serialize the value before storing it in the database. If network is an IPv6 network, it is converted to collapsed form."""
876876

877877
return ipaddress.ip_network(self.value).with_prefixlen
878878

@@ -998,7 +998,7 @@ def validate_format(cls, value: Any, name: str, schema: AttributeSchema) -> None
998998
raise ValidationError({name: f"{value} is not a valid {schema.kind}"}) from exc
999999

10001000
def serialize_value(self) -> str:
1001-
"""Serialize the value before storing it in the database."""
1001+
"""Adds a prefix to address before storing it in the database. If address in an IPv6 address, it is converted to collapsed form."""
10021002

10031003
return ipaddress.ip_interface(self.value).with_prefixlen
10041004

backend/infrahub/graphql/queries/search.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import ipaddress
34
from typing import TYPE_CHECKING, Any, Optional
45

56
from graphene import Boolean, Field, Int, List, ObjectType, String
@@ -29,6 +30,72 @@ class NodeEdges(ObjectType):
2930
edges = Field(List(of_type=NodeEdge, required=True), required=False)
3031

3132

33+
def _collapse_ipv6(s: str) -> str:
34+
"""Collapse an ipv6 address, ipv6 network, or a partial ipv6 address in extended format, into its collapsed form.
35+
Raises an error if input does not resemble an IPv6 address in extended format. It means this function also raises
36+
an error if input string is the start of an IPv6 address in collapsed format.
37+
"""
38+
39+
try:
40+
return str(ipaddress.IPv6Address(s))
41+
except ipaddress.AddressValueError:
42+
pass
43+
44+
try:
45+
return ipaddress.IPv6Network(s).with_prefixlen
46+
except ipaddress.AddressValueError:
47+
pass
48+
49+
# Input string might be an incomplete address in IPv6 format,
50+
# in which case we would like the collapsed form equivalent of this incomplete address for matching purposes.
51+
# To get it, we first try to pad the incomplete address with zeros, then we retrieve the collapsed form
52+
# of the full address, and we remove extra "::" or ":0" at the end of it.
53+
54+
error_message = "Input string does not match IPv6 extended format"
55+
56+
# Input string cannot be an IPv6 in extended format if it contains ":"
57+
if "::" in s:
58+
raise ValueError(error_message)
59+
60+
# Add padding to complete the address if needed
61+
segments = s.split(":")
62+
63+
if len(segments) == 0:
64+
raise ValueError(error_message)
65+
66+
# If any of the non-last segments has less than 4 characters it means we deal with
67+
# a IPv6 collapsed form or an invalid address
68+
for segment in segments[:-1]:
69+
if len(segment) != 4:
70+
raise ValueError(error_message)
71+
72+
# Add 0 padding to last segment
73+
if len(segments[-1]) > 4:
74+
raise ValueError(error_message)
75+
76+
segments[-1] += "0" * (4 - len(segments[-1]))
77+
78+
# Complete the address to have 8 segments by padding with zeros
79+
while len(segments) < 8:
80+
segments.append("0000")
81+
82+
# Create a full IPv6 address from the partial input
83+
full_address = ":".join(segments)
84+
85+
# Create an IPv6Address object for validation and to build IPv6 collapsed form.
86+
ipv6_address = ipaddress.IPv6Address(full_address)
87+
88+
compressed_address = ipv6_address.compressed
89+
90+
# We padded with zeros so address might endswith "::" or ":0".
91+
if compressed_address.endswith(("::", ":0")):
92+
return compressed_address[:-2]
93+
94+
# Otherwise, it means 8th segment of ipv6 address was not full and not composed of 0 only
95+
# e.g. 2001:0db8:0000:0000:0000:0000:03
96+
return compressed_address
97+
98+
3299
async def search_resolver(
33100
root: dict, # pylint: disable=unused-argument
34101
info: GraphQLResolveInfo,
@@ -49,6 +116,12 @@ async def search_resolver(
49116
if matching:
50117
result.append(matching)
51118
else:
119+
try:
120+
# Convert any IPv6 address, network or partial address to collapsed format as it might be stored in db.
121+
q = _collapse_ipv6(q)
122+
except (ValueError, ipaddress.AddressValueError):
123+
pass
124+
52125
result.extend(
53126
await NodeManager.query(
54127
db=context.db,

backend/tests/unit/graphql/queries/test_search.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import pytest
12
from graphql import graphql
23

34
from infrahub.core.branch import Branch
45
from infrahub.core.node import Node
56
from infrahub.database import InfrahubDatabase
67
from infrahub.graphql.initialization import prepare_graphql_params
8+
from infrahub.graphql.queries.search import _collapse_ipv6
79

810
SEARCH_QUERY = """
911
query ($search: String!) {
@@ -93,3 +95,190 @@ async def test_search_anywhere_by_string(
9395

9496
assert sorted(node_ids) == sorted([person_john_main.id, person_jane_main.id])
9597
assert sorted(node_kinds) == sorted([person_john_main.get_kind(), person_jane_main.get_kind()])
98+
99+
100+
async def test_search_ipv6_address_extended_format(
101+
db: InfrahubDatabase,
102+
ip_dataset_01,
103+
branch: Branch,
104+
):
105+
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)
106+
107+
res_collapsed = await graphql(
108+
schema=gql_params.schema,
109+
source=SEARCH_QUERY,
110+
context_value=gql_params.context,
111+
root_value=None,
112+
variable_values={"search": "2001:db8::"},
113+
)
114+
115+
res_extended = await graphql(
116+
schema=gql_params.schema,
117+
source=SEARCH_QUERY,
118+
context_value=gql_params.context,
119+
root_value=None,
120+
variable_values={"search": "2001:0db8:0000:0000:0000:0000:0000:0000"},
121+
)
122+
123+
assert (
124+
res_extended.data["InfrahubSearchAnywhere"]["count"]
125+
== res_collapsed.data["InfrahubSearchAnywhere"]["count"]
126+
== 2
127+
)
128+
129+
assert (
130+
res_extended.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
131+
== res_collapsed.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
132+
)
133+
134+
assert (
135+
res_extended.data["InfrahubSearchAnywhere"]["edges"][1]["node"]["id"]
136+
== res_collapsed.data["InfrahubSearchAnywhere"]["edges"][1]["node"]["id"]
137+
)
138+
139+
140+
async def test_search_ipv6_network_extended_format(
141+
db: InfrahubDatabase,
142+
ip_dataset_01,
143+
branch: Branch,
144+
):
145+
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)
146+
147+
res_collapsed = await graphql(
148+
schema=gql_params.schema,
149+
source=SEARCH_QUERY,
150+
context_value=gql_params.context,
151+
root_value=None,
152+
variable_values={"search": "2001:db8::/48"},
153+
)
154+
155+
res_extended = await graphql(
156+
schema=gql_params.schema,
157+
source=SEARCH_QUERY,
158+
context_value=gql_params.context,
159+
root_value=None,
160+
variable_values={"search": "2001:0db8:0000:0000:0000:0000:0000:0000/48"},
161+
)
162+
163+
assert (
164+
res_extended.data["InfrahubSearchAnywhere"]["count"]
165+
== res_collapsed.data["InfrahubSearchAnywhere"]["count"]
166+
== 1
167+
)
168+
169+
assert (
170+
res_extended.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
171+
== res_collapsed.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
172+
)
173+
174+
175+
async def test_search_ipv6_partial_address(
176+
db: InfrahubDatabase,
177+
ip_dataset_01,
178+
branch: Branch,
179+
):
180+
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)
181+
182+
res_two_segments = await graphql(
183+
schema=gql_params.schema,
184+
source=SEARCH_QUERY,
185+
context_value=gql_params.context,
186+
root_value=None,
187+
variable_values={"search": "2001:0db8"},
188+
)
189+
190+
res_partial_segment_1 = await graphql(
191+
schema=gql_params.schema,
192+
source=SEARCH_QUERY,
193+
context_value=gql_params.context,
194+
root_value=None,
195+
variable_values={"search": "2001:0db8:0"},
196+
)
197+
198+
res_partial_segment_2 = await graphql(
199+
schema=gql_params.schema,
200+
source=SEARCH_QUERY,
201+
context_value=gql_params.context,
202+
root_value=None,
203+
variable_values={"search": "2001:0db8:0000:0"},
204+
)
205+
206+
assert (
207+
res_two_segments.data["InfrahubSearchAnywhere"]["count"]
208+
== res_partial_segment_1.data["InfrahubSearchAnywhere"]["count"]
209+
== res_partial_segment_2.data["InfrahubSearchAnywhere"]["count"]
210+
== 2
211+
)
212+
213+
assert (
214+
res_two_segments.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
215+
== res_partial_segment_1.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
216+
== res_partial_segment_2.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
217+
)
218+
219+
220+
async def test_search_ipv4(
221+
db: InfrahubDatabase,
222+
ip_dataset_01,
223+
branch: Branch,
224+
):
225+
"""
226+
This only tests that ipv6 search specific behavior does not break ipv4 search.
227+
"""
228+
229+
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)
230+
231+
result_address = await graphql(
232+
schema=gql_params.schema,
233+
source=SEARCH_QUERY,
234+
context_value=gql_params.context,
235+
root_value=None,
236+
variable_values={"search": "10.0.0.0"},
237+
)
238+
239+
result_network = await graphql(
240+
schema=gql_params.schema,
241+
source=SEARCH_QUERY,
242+
context_value=gql_params.context,
243+
root_value=None,
244+
variable_values={"search": "10.0.0.0/8"},
245+
)
246+
247+
assert (
248+
result_address.data["InfrahubSearchAnywhere"]["count"]
249+
== result_network.data["InfrahubSearchAnywhere"]["count"]
250+
== 1
251+
)
252+
253+
assert (
254+
result_address.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
255+
== result_network.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
256+
)
257+
258+
259+
@pytest.mark.parametrize(
260+
"query,expected",
261+
[
262+
("2001:0db8:0000:0000:0000:0000:0000:0000/48", "2001:db8::/48"),
263+
("2001:0db8:0000:0000:0000:0000:0000:0000", "2001:db8::"),
264+
("2001:0db8", "2001:db8"),
265+
("2001:0db8:0", "2001:db8"),
266+
("2001:0db8:0000", "2001:db8"),
267+
("2001:0db8:0000:0", "2001:db8"),
268+
("2001:0db8:0000:0000:00", "2001:db8"),
269+
("2001:0db8:0000:0001:00", "2001:db8:0:1"),
270+
("2001:0db8:0001:0002:00", "2001:db8:1:2"),
271+
("2001:0db8:0001:0000:0002:0000:0003", "2001:db8:1:0:2:0:3"),
272+
],
273+
)
274+
def test_collapse_ipv6_address_or_network(query, expected):
275+
assert _collapse_ipv6(query) == expected
276+
277+
278+
@pytest.mark.parametrize(
279+
"query",
280+
["invalid", "invalid:case", "2001:invalid", "2001:0db81:0000", "10.0.0.0", "2001:db8:1"],
281+
)
282+
def test_collapse_ipv6_address_or_network_invalid_cases(query):
283+
with pytest.raises(ValueError):
284+
_collapse_ipv6(query)

changelog/4613.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Search anywhere now supports IPv6 extended format

0 commit comments

Comments
 (0)