Skip to content

Commit 739e211

Browse files
committed
Add support for partial address search
1 parent 32460c1 commit 739e211

File tree

3 files changed

+148
-7
lines changed

3 files changed

+148
-7
lines changed

backend/infrahub/core/utils.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,66 @@ def convert_ip_to_binary_str(
191191
return ip_bin.zfill(obj.max_prefixlen)
192192

193193

194-
def collapse_ipv6_address_or_network(address_or_network: str) -> str:
195-
if "/" in address_or_network:
196-
return ipaddress.IPv6Network(address_or_network).with_prefixlen
197-
return str(ipaddress.IPv6Address(address_or_network))
194+
def collapse_ipv6(s: str) -> str:
195+
"""Collapse an ipv6 address, ipv6 network, or a partial ipv6 address. Raises an error if input does not resemble an IPv6 address."""
196+
197+
try:
198+
return str(ipaddress.IPv6Address(s))
199+
except ipaddress.AddressValueError:
200+
pass
201+
202+
try:
203+
return ipaddress.IPv6Network(s).with_prefixlen
204+
except ipaddress.AddressValueError:
205+
pass
206+
207+
# Input string might be an incomplete address in IPv6 format,
208+
# in which case we would like the collapsed form equivalent of this incomplete address for matching purposes.
209+
# To get it, we first try to pad the incomplete address with zeros, then we retrieve the collapsed form
210+
# of the full address, and we remove extra "::" or ":0" at the end of it.
211+
212+
error_message = "Input string does not match IPv6 format"
213+
214+
if "::" in s:
215+
raise ValueError(error_message)
216+
217+
# Add padding to complete the address if needed
218+
segments = s.split(":")
219+
220+
if len(segments) == 0:
221+
raise ValueError(error_message)
222+
223+
# If any of the non-last segments has less than 4 characters it means we deal with
224+
# a IPv6 collapsed form or an invalid address
225+
for segment in segments[:-1]:
226+
if len(segment) != 4:
227+
raise ValueError(error_message)
228+
229+
# Add 0 padding to last segment
230+
if len(segments[-1]) > 4:
231+
raise ValueError(error_message)
232+
233+
segments[-1] += "0" * (4 - len(segments[-1]))
234+
235+
# Complete the address to have 8 segments by padding with zeros
236+
while len(segments) < 8:
237+
segments.append("0000")
238+
239+
# Create a full IPv6 address from the partial input
240+
full_address = ":".join(segments)
241+
242+
# Create an IPv6Address object for validation and to build IPv6 collapsed form.
243+
ipv6_address = ipaddress.IPv6Address(full_address)
244+
245+
compressed_address = ipv6_address.compressed
246+
247+
# We padded with zeros so address might endswith "::" or ":0".
248+
if compressed_address.endswith(("::", ":0")):
249+
return compressed_address[:-2]
250+
251+
# Otherwise, it means 8th segment of ipv6 address was not full and not composed of 0 only
252+
# e.g. 2001:0db8:0000:0000:0000:0000:03
253+
return compressed_address
198254

199255

200256
# --------------------------------------------------------------------------------

backend/infrahub/graphql/queries/search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from infrahub.core.constants import InfrahubKind
99
from infrahub.core.manager import NodeManager
10-
from infrahub.core.utils import collapse_ipv6_address_or_network
10+
from infrahub.core.utils import collapse_ipv6
1111

1212
if TYPE_CHECKING:
1313
from graphql import GraphQLResolveInfo
@@ -51,8 +51,8 @@ async def search_resolver(
5151
result.append(matching)
5252
else:
5353
try:
54-
# Convert any IPv6 address/network to collapsed format as it might be stored in db.
55-
q = collapse_ipv6_address_or_network(q)
54+
# Convert any IPv6 address, network or partial address to collapsed format as it might be stored in db.
55+
q = collapse_ipv6(q)
5656
except ValueError:
5757
pass
5858

backend/tests/unit/graphql/queries/test_search.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import pytest
12
from graphql import graphql
23

34
from infrahub.core.branch import Branch
45
from infrahub.core.node import Node
6+
from infrahub.core.utils import collapse_ipv6
57
from infrahub.database import InfrahubDatabase
68
from infrahub.graphql.initialization import prepare_graphql_params
79

@@ -170,11 +172,60 @@ async def test_search_ipv6_network_extended_format(
170172
)
171173

172174

175+
async def test_search_ipv6_partial_address(
176+
db: InfrahubDatabase,
177+
ip_dataset_01,
178+
branch: Branch,
179+
):
180+
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)
181+
182+
res_two_segments = await graphql(
183+
schema=gql_params.schema,
184+
source=SEARCH_QUERY,
185+
context_value=gql_params.context,
186+
root_value=None,
187+
variable_values={"search": "2001:0db8"},
188+
)
189+
190+
res_partial_segment_1 = await graphql(
191+
schema=gql_params.schema,
192+
source=SEARCH_QUERY,
193+
context_value=gql_params.context,
194+
root_value=None,
195+
variable_values={"search": "2001:0db8:0"},
196+
)
197+
198+
res_partial_segment_2 = await graphql(
199+
schema=gql_params.schema,
200+
source=SEARCH_QUERY,
201+
context_value=gql_params.context,
202+
root_value=None,
203+
variable_values={"search": "2001:0db8:0000:0"},
204+
)
205+
206+
assert (
207+
res_two_segments.data["InfrahubSearchAnywhere"]["count"]
208+
== res_partial_segment_1.data["InfrahubSearchAnywhere"]["count"]
209+
== res_partial_segment_2.data["InfrahubSearchAnywhere"]["count"]
210+
== 2
211+
)
212+
213+
assert (
214+
res_two_segments.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
215+
== res_partial_segment_1.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
216+
== res_partial_segment_2.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
217+
)
218+
219+
173220
async def test_search_ipv4(
174221
db: InfrahubDatabase,
175222
ip_dataset_01,
176223
branch: Branch,
177224
):
225+
"""
226+
This only tests that ipv6 search specific behavior does not break ipv4 search.
227+
"""
228+
178229
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)
179230

180231
result_address = await graphql(
@@ -203,3 +254,37 @@ async def test_search_ipv4(
203254
result_address.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
204255
== result_network.data["InfrahubSearchAnywhere"]["edges"][0]["node"]["id"]
205256
)
257+
258+
259+
@pytest.mark.parametrize(
260+
"query,expected",
261+
[
262+
("2001:0db8:0000:0000:0000:0000:0000:0000/48", "2001:db8::/48"),
263+
("2001:0db8:0000:0000:0000:0000:0000:0000", "2001:db8::"),
264+
("2001:0db8", "2001:db8"),
265+
("2001:0db8:0", "2001:db8"),
266+
("2001:0db8:0000", "2001:db8"),
267+
("2001:0db8:0000:0", "2001:db8"),
268+
("2001:0db8:0000:0000:00", "2001:db8"),
269+
("2001:0db8:0000:0001:00", "2001:db8:0:1"),
270+
("2001:0db8:0001:0002:00", "2001:db8:1:2"),
271+
("2001:0db8:0001:0000:0002:0000:0003", "2001:db8:1:0:2:0:3"),
272+
],
273+
)
274+
def test_collapse_ipv6_address_or_network(query, expected):
275+
assert collapse_ipv6(query) == expected
276+
277+
278+
@pytest.mark.parametrize(
279+
"query",
280+
[
281+
"invalid",
282+
"invalid:case",
283+
"2001:invalid",
284+
"2001:0db81:0000",
285+
"10.0.0.0",
286+
],
287+
)
288+
def test_collapse_ipv6_address_or_network_invalid_cases(query):
289+
with pytest.raises(ValueError):
290+
collapse_ipv6(query)

0 commit comments

Comments
 (0)