|
| 1 | +from collections.abc import Callable |
| 2 | +import operator as py_operator |
| 3 | + |
1 | 4 | from pyspark.sql import Column |
2 | 5 | import pyspark.sql.functions as F |
| 6 | + |
3 | 7 | from databricks.labs.dqx.rule import register_rule |
4 | | -from databricks.labs.dqx.check_funcs import make_condition, _get_normalized_column_and_expr |
| 8 | +from databricks.labs.dqx.check_funcs import make_condition, _get_normalized_column_and_expr, _get_limit_expr |
5 | 9 |
|
6 | 10 | POINT_TYPE = "ST_Point" |
7 | 11 | LINESTRING_TYPE = "ST_LineString" |
|
10 | 14 | MULTILINESTRING_TYPE = "ST_MultiLineString" |
11 | 15 | MULTIPOLYGON_TYPE = "ST_MultiPolygon" |
12 | 16 | GEOMETRYCOLLECTION_TYPE = "ST_GeometryCollection" |
| 17 | +DEFAULT_SRID = 4326 |
13 | 18 |
|
14 | 19 |
|
15 | 20 | @register_rule("row") |
@@ -448,3 +453,335 @@ def has_y_coordinate_between(column: str | Column, min_value: float, max_value: |
448 | 453 | F.concat_ws("", F.lit("value `"), col_expr.cast("string"), F.lit(condition_str)), |
449 | 454 | f"{col_str_norm}_has_y_coordinates_outside_range", |
450 | 455 | ) |
| 456 | + |
| 457 | + |
| 458 | +@register_rule("row") |
| 459 | +def is_area_equal_to( |
| 460 | + column: str | Column, value: int | float | str | Column, srid: int | None = 3857, geodesic: bool = False |
| 461 | +) -> Column: |
| 462 | + """ |
| 463 | + Checks if the areas of values in a geometry or geography column are equal to a specified value. By default, the 2D |
| 464 | + Cartesian area in WGS84 (Pseudo-Mercator) with units of meters squared is used. An SRID can be specified to |
| 465 | + transform the input values and compute areas with specific units of measure. |
| 466 | +
|
| 467 | + Args: |
| 468 | + column: Column to check; can be a string column name or a column expression |
| 469 | + value: Value to use in the condition as number, column name or sql expression |
| 470 | + srid: Optional integer SRID to use for computing the area of the geometry or geography value (default `None`). |
| 471 | + If an SRID is provided, the input value is translated and area is calculated using the units of measure of |
| 472 | + the specified coordinate reference system (e.g. meters squared for `srid=3857`). |
| 473 | + geodesic: Whether to use the 2D geodesic area (default `False`). |
| 474 | +
|
| 475 | + Returns: |
| 476 | + Column object indicating whether the area the geometries in the input column are equal to the provided value |
| 477 | +
|
| 478 | + Note: |
| 479 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 480 | + """ |
| 481 | + return _compare_sql_function_result( |
| 482 | + column, |
| 483 | + value, |
| 484 | + spatial_function="st_area", |
| 485 | + spatial_quantity_label="area", |
| 486 | + spatial_quantity_name="area", |
| 487 | + compare_op=py_operator.ne, |
| 488 | + compare_op_label="not equal to", |
| 489 | + compare_op_name="not_equal_to", |
| 490 | + srid=srid, |
| 491 | + geodesic=geodesic, |
| 492 | + ) |
| 493 | + |
| 494 | + |
| 495 | +@register_rule("row") |
| 496 | +def is_area_not_equal_to( |
| 497 | + column: str | Column, value: int | float | str | Column, srid: int | None = 3857, geodesic: bool = False |
| 498 | +) -> Column: |
| 499 | + """ |
| 500 | + Checks if the areas of values in a geometry column are not equal to a specified value. By default, the 2D |
| 501 | + Cartesian area in WGS84 (Pseudo-Mercator) with units of meters squared is used. An SRID can be specified to |
| 502 | + transform the input values and compute areas with specific units of measure. |
| 503 | +
|
| 504 | + Args: |
| 505 | + column: Column to check; can be a string column name or a column expression |
| 506 | + value: Value to use in the condition as number, column name or sql expression |
| 507 | + srid: Optional integer SRID to use for computing the area of the geometry or geography value (default `None`). |
| 508 | + If an SRID is provided, the input value is translated and area is calculated using the units of measure of |
| 509 | + the specified coordinate reference system (e.g. meters squared for `srid=3857`). |
| 510 | + geodesic: Whether to use the 2D geodesic area (default `False`). |
| 511 | +
|
| 512 | + Returns: |
| 513 | + Column object indicating whether the area the geometries in the input column are not equal to the provided value |
| 514 | +
|
| 515 | + Note: |
| 516 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 517 | + """ |
| 518 | + return _compare_sql_function_result( |
| 519 | + column, |
| 520 | + value, |
| 521 | + spatial_function="st_area", |
| 522 | + spatial_quantity_label="area", |
| 523 | + spatial_quantity_name="area", |
| 524 | + compare_op=py_operator.eq, |
| 525 | + compare_op_label="equal to", |
| 526 | + compare_op_name="equal_to", |
| 527 | + srid=srid, |
| 528 | + geodesic=geodesic, |
| 529 | + ) |
| 530 | + |
| 531 | + |
| 532 | +@register_rule("row") |
| 533 | +def is_area_not_greater_than( |
| 534 | + column: str | Column, value: int | float | str | Column, srid: int | None = 3857, geodesic: bool = False |
| 535 | +) -> Column: |
| 536 | + """ |
| 537 | + Checks if the areas of values in a geometry column are not greater than a specified limit. By default, the 2D |
| 538 | + Cartesian area in WGS84 (Pseudo-Mercator) with units of meters squared is used. An SRID can be specified to |
| 539 | + transform the input values and compute areas with specific units of measure. |
| 540 | +
|
| 541 | + Args: |
| 542 | + column: Column to check; can be a string column name or a column expression |
| 543 | + value: Value to use in the condition as number, column name or sql expression |
| 544 | + srid: Optional integer SRID to use for computing the area of the geometry or geography value (default `None`). |
| 545 | + If an SRID is provided, the input value is translated and area is calculated using the units of measure of |
| 546 | + the specified coordinate reference system (e.g. meters squared for `srid=3857`). |
| 547 | + geodesic: Whether to use the 2D geodesic area (default `False`). |
| 548 | +
|
| 549 | + Returns: |
| 550 | + Column object indicating whether the area the geometries in the input column is greater than the provided value |
| 551 | +
|
| 552 | + Note: |
| 553 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 554 | + """ |
| 555 | + return _compare_sql_function_result( |
| 556 | + column, |
| 557 | + value, |
| 558 | + spatial_function="st_area", |
| 559 | + spatial_quantity_label="area", |
| 560 | + spatial_quantity_name="area", |
| 561 | + compare_op=py_operator.gt, |
| 562 | + compare_op_label="greater than", |
| 563 | + compare_op_name="greater_than", |
| 564 | + srid=srid, |
| 565 | + geodesic=geodesic, |
| 566 | + ) |
| 567 | + |
| 568 | + |
| 569 | +@register_rule("row") |
| 570 | +def is_area_not_less_than( |
| 571 | + column: str | Column, value: int | float | str | Column, srid: int | None = 3857, geodesic: bool = False |
| 572 | +) -> Column: |
| 573 | + """ |
| 574 | + Checks if the areas of values in a geometry column are not less than a specified limit. By default, the 2D |
| 575 | + Cartesian area in WGS84 (Pseudo-Mercator) with units of meters squared is used. An SRID can be specified to |
| 576 | + transform the input values and compute areas with specific units of measure. |
| 577 | +
|
| 578 | + Args: |
| 579 | + column: Column to check; can be a string column name or a column expression |
| 580 | + value: Value to use in the condition as number, column name or sql expression |
| 581 | + srid: Optional integer SRID to use for computing the area of the geometry or geography value (default `None`). |
| 582 | + If an SRID is provided, the input value is translated and area is calculated using the units of measure of |
| 583 | + the specified coordinate reference system (e.g. meters squared for `srid=3857`). |
| 584 | + geodesic: Whether to use the 2D geodesic area (default `False`). |
| 585 | +
|
| 586 | + Returns: |
| 587 | + Column object indicating whether the area the geometries in the input column is less than the provided value |
| 588 | +
|
| 589 | + Note: |
| 590 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 591 | + """ |
| 592 | + return _compare_sql_function_result( |
| 593 | + column, |
| 594 | + value, |
| 595 | + spatial_function="st_area", |
| 596 | + spatial_quantity_label="area", |
| 597 | + spatial_quantity_name="area", |
| 598 | + compare_op=py_operator.lt, |
| 599 | + compare_op_label="less than", |
| 600 | + compare_op_name="less_than", |
| 601 | + srid=srid, |
| 602 | + geodesic=geodesic, |
| 603 | + ) |
| 604 | + |
| 605 | + |
| 606 | +@register_rule("row") |
| 607 | +def is_num_points_equal_to(column: str | Column, value: int | float | str | Column) -> Column: |
| 608 | + """ |
| 609 | + Checks if the number of coordinate pairs in values of a geometry column is equal to a specified value. |
| 610 | +
|
| 611 | + Args: |
| 612 | + column: Column to check; can be a string column name or a column expression |
| 613 | + value: Value to use in the condition as number, column name or sql expression |
| 614 | +
|
| 615 | + Returns: |
| 616 | + Column object indicating whether the number of coordinate pairs in the geometries of the input column is |
| 617 | + equal to the provided value |
| 618 | +
|
| 619 | + Note: |
| 620 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 621 | + """ |
| 622 | + return _compare_sql_function_result( |
| 623 | + column, |
| 624 | + value, |
| 625 | + spatial_function="st_npoints", |
| 626 | + spatial_quantity_label="number of coordinates", |
| 627 | + spatial_quantity_name="num_points", |
| 628 | + compare_op=py_operator.ne, |
| 629 | + compare_op_label="not equal to", |
| 630 | + compare_op_name="not_equal_to", |
| 631 | + ) |
| 632 | + |
| 633 | + |
| 634 | +@register_rule("row") |
| 635 | +def is_num_points_not_equal_to(column: str | Column, value: int | float | str | Column) -> Column: |
| 636 | + """ |
| 637 | + Checks if the number of coordinate pairs in values of a geometry column is not equal to a specified value. |
| 638 | +
|
| 639 | + Args: |
| 640 | + column: Column to check; can be a string column name or a column expression |
| 641 | + value: Value to use in the condition as number, column name or sql expression |
| 642 | +
|
| 643 | + Returns: |
| 644 | + Column object indicating whether the number of coordinate pairs in the geometries of the input column is not |
| 645 | + equal to the provided value |
| 646 | +
|
| 647 | + Note: |
| 648 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 649 | + """ |
| 650 | + return _compare_sql_function_result( |
| 651 | + column, |
| 652 | + value, |
| 653 | + spatial_function="st_npoints", |
| 654 | + spatial_quantity_label="number of coordinates", |
| 655 | + spatial_quantity_name="num_points", |
| 656 | + compare_op=py_operator.eq, |
| 657 | + compare_op_label="equal to", |
| 658 | + compare_op_name="equal_to", |
| 659 | + ) |
| 660 | + |
| 661 | + |
| 662 | +@register_rule("row") |
| 663 | +def is_num_points_not_greater_than(column: str | Column, value: int | float | str | Column) -> Column: |
| 664 | + """ |
| 665 | + Checks if the number of coordinate pairs in the values of a geometry column is not greater than a specified limit. |
| 666 | +
|
| 667 | + Args: |
| 668 | + column: Column to check; can be a string column name or a column expression |
| 669 | + value: Value to use in the condition as number, column name or sql expression |
| 670 | +
|
| 671 | + Returns: |
| 672 | + Column object indicating whether the number of coordinate pairs in the geometries of the input column is |
| 673 | + greater than the provided value |
| 674 | +
|
| 675 | + Note: |
| 676 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 677 | + """ |
| 678 | + return _compare_sql_function_result( |
| 679 | + column, |
| 680 | + value, |
| 681 | + spatial_function="st_npoints", |
| 682 | + spatial_quantity_label="number of coordinates", |
| 683 | + spatial_quantity_name="num_points", |
| 684 | + compare_op=py_operator.gt, |
| 685 | + compare_op_label="greater than", |
| 686 | + compare_op_name="greater_than", |
| 687 | + ) |
| 688 | + |
| 689 | + |
| 690 | +@register_rule("row") |
| 691 | +def is_num_points_not_less_than(column: str | Column, value: int | float | str | Column) -> Column: |
| 692 | + """ |
| 693 | + Checks if the number of coordinate pairs in values of a geometry column is not less than a specified limit. |
| 694 | +
|
| 695 | + Args: |
| 696 | + column: Column to check; can be a string column name or a column expression |
| 697 | + value: Value to use in the condition as number, column name or sql expression |
| 698 | +
|
| 699 | + Returns: |
| 700 | + Column object indicating whether the number of coordinate pairs in the geometries of the input column is |
| 701 | + less than the provided value |
| 702 | +
|
| 703 | + Note: |
| 704 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 705 | + """ |
| 706 | + return _compare_sql_function_result( |
| 707 | + column, |
| 708 | + value, |
| 709 | + spatial_function="st_npoints", |
| 710 | + spatial_quantity_label="number of coordinates", |
| 711 | + spatial_quantity_name="num_points", |
| 712 | + compare_op=py_operator.lt, |
| 713 | + compare_op_label="less than", |
| 714 | + compare_op_name="less_than", |
| 715 | + ) |
| 716 | + |
| 717 | + |
| 718 | +def _compare_sql_function_result( |
| 719 | + column: str | Column, |
| 720 | + value: int | float | str | Column, |
| 721 | + spatial_function: str, |
| 722 | + spatial_quantity_label: str, |
| 723 | + spatial_quantity_name: str, |
| 724 | + compare_op: Callable[[Column, Column], Column], |
| 725 | + compare_op_label: str, |
| 726 | + compare_op_name: str, |
| 727 | + srid: int | None = None, |
| 728 | + geodesic: bool = False, |
| 729 | +) -> Column: |
| 730 | + """ |
| 731 | + Compares the results from applying a spatial SQL function (e.g. `st_area`) on a geometry column against a limit |
| 732 | + using the specified comparison operator. |
| 733 | +
|
| 734 | + Args: |
| 735 | + column: Column to check; can be a string column name or a column expression |
| 736 | + value: Value to use in the condition as number, column name or sql expression |
| 737 | + spatial_function: Spatial SQL function as a string (e.g. `st_npoints`) |
| 738 | + spatial_quantity_label: Spatial quantity label (e.g. `number of coordinates` ) |
| 739 | + spatial_quantity_name: Spatial quantity identifier (e.g. `num_points`) |
| 740 | + compare_op: Comparison operator (e.g., `operator.gt`, `operator.lt`). |
| 741 | + compare_op_label: Human-readable label for the comparison (e.g., 'greater than'). |
| 742 | + compare_op_name: Name identifier for the comparison (e.g., 'greater_than'). |
| 743 | + srid: Optional integer SRID for computing measurements on the converted geometry or geography value (default `None`). |
| 744 | + geodesic: Whether to convert the input column to a geography type for computing geodesic distances. |
| 745 | +
|
| 746 | + Returns: |
| 747 | + Column object indicating whether the area the geometries in the input column is less than the provided limit |
| 748 | +
|
| 749 | + Note: |
| 750 | + This function requires Databricks serverless compute or runtime 17.1 or above. |
| 751 | + """ |
| 752 | + col_str_norm, col_expr_str, col_expr = _get_normalized_column_and_expr(column) |
| 753 | + value_expr = _get_limit_expr(value) |
| 754 | + # NOTE: This function is currently only available in Databricks runtime 17.1 or above or in |
| 755 | + # Databricks SQL, due to the use of the `try_to_geometry` and `st_area` functions. |
| 756 | + if geodesic: |
| 757 | + spatial_conversion_expr = f"try_to_geography({col_str_norm})" |
| 758 | + spatial_data_type = "geography" |
| 759 | + elif srid: |
| 760 | + spatial_conversion_expr = f"st_transform(st_setsrid(try_to_geometry({col_str_norm}), {DEFAULT_SRID}), {srid})" |
| 761 | + spatial_data_type = "geometry" |
| 762 | + else: |
| 763 | + spatial_conversion_expr = f"try_to_geometry({col_str_norm})" |
| 764 | + spatial_data_type = "geometry" |
| 765 | + |
| 766 | + is_valid_cond = F.expr(f"{spatial_conversion_expr} IS NULL") |
| 767 | + is_valid_message = F.concat_ws( |
| 768 | + "", |
| 769 | + F.lit("value `"), |
| 770 | + col_expr.cast("string"), |
| 771 | + F.lit(f"` in column `{col_expr_str}` is not a valid {spatial_data_type}"), |
| 772 | + ) |
| 773 | + compare_cond = compare_op(F.expr(f"{spatial_function}({spatial_conversion_expr})"), value_expr) |
| 774 | + compare_message = F.concat_ws( |
| 775 | + "", |
| 776 | + F.lit("value `"), |
| 777 | + col_expr.cast("string"), |
| 778 | + F.lit(f"` in column `{col_expr_str}` has {spatial_quantity_label} {compare_op_label} value: "), |
| 779 | + value_expr.cast("string"), |
| 780 | + ) |
| 781 | + condition = F.when(col_expr.isNull(), F.lit(None)).otherwise(is_valid_cond | compare_cond) |
| 782 | + |
| 783 | + return make_condition( |
| 784 | + condition, |
| 785 | + F.when(is_valid_cond, is_valid_message).otherwise(compare_message), |
| 786 | + f"{col_str_norm}_{spatial_quantity_name}_{compare_op_name}_limit", |
| 787 | + ) |
0 commit comments