|
| 1 | +from collections.abc import Mapping |
| 2 | + |
1 | 3 | from modflow_devtools.dfn.schema.v2 import FieldV2 |
2 | 4 |
|
3 | 5 |
|
@@ -31,3 +33,86 @@ def record_child_type(field: FieldV2) -> str: |
31 | 33 | def keystring_children(field: FieldV2) -> dict: |
32 | 34 | """Get the children of a keystring field for union generation.""" |
33 | 35 | return {} if field.type != "union" else field.children |
| 36 | + |
| 37 | + |
| 38 | +def is_period_list_field(field: FieldV2) -> bool: |
| 39 | + """Check if a field is part of a period block list/recarray.""" |
| 40 | + if not field.shape or not field.block: |
| 41 | + return False |
| 42 | + return ( |
| 43 | + "period" in field.block |
| 44 | + and field.type in ["string", "integer", "double"] |
| 45 | + and field.shape is not None |
| 46 | + ) |
| 47 | + |
| 48 | + |
| 49 | +def group_period_fields(block_fields: Mapping[str, FieldV2]) -> dict[str, list[str]]: |
| 50 | + """ |
| 51 | + Group period block fields that should be combined into a single list. |
| 52 | +
|
| 53 | + Returns a dict mapping the first field name to a list of all field names |
| 54 | + in the group. Fields are grouped if they share similar shapes (same base |
| 55 | + dimensions like nper, nnodes). |
| 56 | + """ |
| 57 | + period_fields = { |
| 58 | + name: field for name, field in block_fields.items() if is_period_list_field(field) |
| 59 | + } |
| 60 | + |
| 61 | + if not period_fields: |
| 62 | + return {} |
| 63 | + |
| 64 | + # All period fields in the same block should be combined into one recarray |
| 65 | + # Return a single group with all field names |
| 66 | + field_names = list(period_fields.keys()) |
| 67 | + if field_names: |
| 68 | + return {field_names[0]: field_names} |
| 69 | + return {} |
| 70 | + |
| 71 | + |
| 72 | +def get_recarray_name(block_name: str) -> str: |
| 73 | + """Get the name for a recarray representing period data in a block.""" |
| 74 | + # Use similar naming to V1: stress_period_data, perioddata, etc. |
| 75 | + if block_name == "period": |
| 76 | + return "stress_period_data" |
| 77 | + return f"{block_name}data" |
| 78 | + |
| 79 | + |
| 80 | +def get_recarray_columns(field_names: list[str], block_fields: Mapping[str, FieldV2]) -> list[str]: |
| 81 | + """ |
| 82 | + Get column names for a recarray, similar to V1 format. |
| 83 | +
|
| 84 | + Returns column names in format like: ['cellid', 'q', 'aux', 'boundname'] |
| 85 | + """ |
| 86 | + columns = [] |
| 87 | + |
| 88 | + # Check if any field has spatial dimensions (indicates cellid is needed) |
| 89 | + has_spatial = False |
| 90 | + for name in field_names: |
| 91 | + field = block_fields[name] |
| 92 | + if field.shape and any( |
| 93 | + dim in field.shape for dim in ["nnodes", "ncells", "nlay", "nrow", "ncol"] |
| 94 | + ): |
| 95 | + has_spatial = True |
| 96 | + break |
| 97 | + |
| 98 | + if has_spatial: |
| 99 | + columns.append("cellid") |
| 100 | + |
| 101 | + # Add the field names as columns |
| 102 | + columns.extend(field_names) |
| 103 | + |
| 104 | + return columns |
| 105 | + |
| 106 | + |
| 107 | +def get_all_grouped_field_names(blocks: Mapping[str, Mapping[str, FieldV2]]) -> set[str]: |
| 108 | + """ |
| 109 | + Get all field names that are grouped into recarrays across all blocks. |
| 110 | +
|
| 111 | + Returns a set of field names that should not have individual rules generated. |
| 112 | + """ |
| 113 | + grouped = set() |
| 114 | + for block_fields in blocks.values(): |
| 115 | + period_groups = group_period_fields(block_fields) |
| 116 | + for field_list in period_groups.values(): |
| 117 | + grouped.update(field_list) |
| 118 | + return grouped |
0 commit comments