|
1 | 1 | //! Copyright The KCL Authors. All rights reserved.
|
2 | 2 |
|
| 3 | +use std::net::IpAddr; |
3 | 4 | use std::net::Ipv4Addr;
|
4 | 5 | use std::net::Ipv6Addr;
|
5 | 6 | use std::str::FromStr;
|
@@ -576,35 +577,72 @@ pub extern "C-unwind" fn kclvm_net_is_IP_in_CIDR(
|
576 | 577 | let args = ptr_as_ref(args);
|
577 | 578 | let kwargs = ptr_as_ref(kwargs);
|
578 | 579 |
|
579 |
| - if let Some(ip) = get_call_arg_str(args, kwargs, 0, Some("ip")) { |
580 |
| - if let Some(cidr) = get_call_arg_str(args, kwargs, 1, Some("cidr")) { |
581 |
| - let parts: Vec<&str> = cidr.split('/').collect(); |
582 |
| - if parts.len() == 2 { |
583 |
| - let cidr_ip = parts[0]; |
584 |
| - let mask_bits = parts[1]; |
585 |
| - let ip_addr = match Ipv4Addr::from_str(&ip) { |
586 |
| - Ok(ip_addr) => ip_addr, |
587 |
| - Err(_) => return kclvm_value_False(ctx), |
588 |
| - }; |
589 |
| - let cidr_ip_addr = match Ipv4Addr::from_str(cidr_ip) { |
590 |
| - Ok(cidr_ip_addr) => cidr_ip_addr, |
591 |
| - Err(_) => return kclvm_value_False(ctx), |
592 |
| - }; |
593 |
| - let mask_bits = match mask_bits.parse::<u8>() { |
594 |
| - Ok(mask_bits) if mask_bits <= 32 => mask_bits, |
595 |
| - _ => return kclvm_value_False(ctx), |
596 |
| - }; |
597 |
| - let mask = !((1 << (32 - mask_bits)) - 1); |
598 |
| - let ip_u32 = u32::from_be_bytes(ip_addr.octets()); |
599 |
| - let cidr_ip_u32 = u32::from_be_bytes(cidr_ip_addr.octets()); |
600 |
| - let is_in_cidr = (ip_u32 & mask) == (cidr_ip_u32 & mask); |
601 |
| - return kclvm_value_Bool(ctx, is_in_cidr as i8); |
602 |
| - } |
| 580 | + let ip = match get_call_arg_str(args, kwargs, 0, Some("ip")) { |
| 581 | + Some(ip) => ip, |
| 582 | + None => { |
| 583 | + panic!("is_IP_in_CIDR() missing 2 required positional arguments: 'ip' and 'cidr'"); |
603 | 584 | }
|
604 |
| - return kclvm_value_False(ctx); |
| 585 | + }; |
| 586 | + let cidr = match get_call_arg_str(args, kwargs, 1, Some("cidr")) { |
| 587 | + Some(cidr) => cidr, |
| 588 | + None => { |
| 589 | + panic!("is_IP_in_CIDR() missing 2 required positional arguments: 'ip' and 'cidr'"); |
| 590 | + } |
| 591 | + }; |
| 592 | + |
| 593 | + let (cidr_ip, mask_bits) = match _parse_cidr(&cidr) { |
| 594 | + Some((ip, bits)) => (ip, bits), |
| 595 | + None => return kclvm_value_False(ctx), |
| 596 | + }; |
| 597 | + |
| 598 | + match (_parse_ip(&ip), cidr_ip) { |
| 599 | + (Ok(IpAddr::V4(ip)), IpAddr::V4(cidr_ip)) => { |
| 600 | + let mask_bits = match mask_bits { |
| 601 | + Some(bits) if bits <= 32 => bits, |
| 602 | + _ => return kclvm_value_False(ctx), |
| 603 | + }; |
| 604 | + kclvm_value_Bool(ctx, _check_ipv4_cidr(ip, cidr_ip, mask_bits) as i8) |
| 605 | + } |
| 606 | + (Ok(IpAddr::V6(ip)), IpAddr::V6(cidr_ip)) => { |
| 607 | + let mask_bits = match mask_bits { |
| 608 | + Some(bits) if bits <= 128 => bits, |
| 609 | + _ => return kclvm_value_False(ctx), |
| 610 | + }; |
| 611 | + kclvm_value_Bool(ctx, _check_ipv6_cidr(ip, cidr_ip, mask_bits) as i8) |
| 612 | + } |
| 613 | + _ => kclvm_value_False(ctx), |
605 | 614 | }
|
| 615 | +} |
| 616 | + |
| 617 | +fn _parse_cidr(cidr: &str) -> Option<(IpAddr, Option<u32>)> { |
| 618 | + let parts: Vec<&str> = cidr.split('/').collect(); |
| 619 | + if parts.len() != 2 { |
| 620 | + return None; |
| 621 | + } |
| 622 | + let ip = IpAddr::from_str(parts[0]).ok()?; |
| 623 | + let mask_bits = parts[1].parse::<u32>().ok(); |
| 624 | + Some((ip, mask_bits)) |
| 625 | +} |
606 | 626 |
|
607 |
| - panic!("is_IP_in_CIDR() missing 2 required positional arguments: 'ip' and 'cidr'"); |
| 627 | +fn _parse_ip(ip: &str) -> Result<IpAddr, ()> { |
| 628 | + IpAddr::from_str(ip).map_err(|_| ()) |
| 629 | +} |
| 630 | + |
| 631 | +fn _check_ipv4_cidr(ip: Ipv4Addr, cidr_ip: Ipv4Addr, mask_bits: u32) -> bool { |
| 632 | + let mask = !((1u32 << (32 - mask_bits)) - 1); |
| 633 | + let ip_u32 = u32::from(ip); |
| 634 | + let cidr_u32 = u32::from(cidr_ip); |
| 635 | + (ip_u32 & mask) == (cidr_u32 & mask) |
| 636 | +} |
| 637 | + |
| 638 | +fn _check_ipv6_cidr(ip: Ipv6Addr, cidr_ip: Ipv6Addr, mask_bits: u32) -> bool { |
| 639 | + let mask = match 128 - mask_bits { |
| 640 | + shift @ 0..=128 => !((1u128 << shift) - 1), |
| 641 | + _ => return false, |
| 642 | + }; |
| 643 | + let ip_u128 = u128::from(ip); |
| 644 | + let cidr_u128 = u128::from(cidr_ip); |
| 645 | + (ip_u128 & mask) == (cidr_u128 & mask) |
608 | 646 | }
|
609 | 647 |
|
610 | 648 | #[allow(non_camel_case_types, non_snake_case)]
|
|
0 commit comments