Skip to content

Commit 41a33c4

Browse files
committed
refactor: simplify collections normalization
1 parent 924f234 commit 41a33c4

File tree

2 files changed

+21
-36
lines changed

2 files changed

+21
-36
lines changed

graphistry/PlotterBase.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,13 +1870,11 @@ def collections(
18701870
encode_collections,
18711871
normalize_collections,
18721872
normalize_collections_url_params,
1873-
normalize_validation_params,
18741873
)
18751874

1876-
validate_mode, warn = normalize_validation_params(validate, warn)
18771875
settings: Dict[str, Any] = {}
18781876
if collections is not None:
1879-
normalized = normalize_collections(collections, validate=validate_mode, warn=warn)
1877+
normalized = normalize_collections(collections, validate=validate, warn=warn)
18801878
settings['collections'] = encode_collections(normalized)
18811879
extras: Dict[str, Any] = {}
18821880
if show_collections is not None:
@@ -1886,7 +1884,7 @@ def collections(
18861884
if collections_global_edge_color is not None:
18871885
extras['collectionsGlobalEdgeColor'] = collections_global_edge_color
18881886
if extras:
1889-
extras = normalize_collections_url_params(extras, validate=validate_mode, warn=warn)
1887+
extras = normalize_collections_url_params(extras, validate=validate, warn=warn)
18901888
settings.update(extras)
18911889

18921890
if len(settings.keys()) > 0:

graphistry/validate/validate_collections.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def normalize_validation_params(
3232
return validate_mode, warn
3333

3434

35-
def encode_collections(collections: List[Dict[str, Any]], encode: bool = True) -> str:
35+
def encode_collections(collections: List[Dict[str, Any]]) -> str:
3636
json_str = json.dumps(collections, separators=(',', ':'), ensure_ascii=True)
37-
return quote(json_str, safe='') if encode else json_str
37+
return quote(json_str, safe='')
3838

3939

4040
def _issue(
@@ -86,33 +86,13 @@ def _parse_collections_input(
8686
return []
8787

8888

89-
def _coerce_str_field(
89+
def _normalize_str_field(
9090
entry: Dict[str, Any],
9191
key: str,
9292
validate_mode: ValidationMode,
9393
warn: bool,
94-
entry_index: int
95-
) -> None:
96-
if key not in entry or entry[key] is None:
97-
return
98-
if isinstance(entry[key], str):
99-
return
100-
_issue(
101-
f'Collection field "{key}" should be a string',
102-
{'index': entry_index, 'value': entry[key], 'type': type(entry[key]).__name__},
103-
validate_mode,
104-
warn
105-
)
106-
if validate_mode == 'autofix':
107-
entry[key] = str(entry[key])
108-
109-
110-
def _normalize_color_field(
111-
entry: Dict[str, Any],
112-
key: str,
113-
validate_mode: ValidationMode,
114-
warn: bool,
115-
entry_index: int
94+
entry_index: int,
95+
autofix_drop: bool
11696
) -> None:
11797
if key not in entry or entry[key] is None:
11898
return
@@ -125,7 +105,10 @@ def _normalize_color_field(
125105
warn
126106
)
127107
if validate_mode == 'autofix':
128-
entry.pop(key, None)
108+
if autofix_drop:
109+
entry.pop(key, None)
110+
else:
111+
entry[key] = str(entry[key])
129112

130113

131114
def _normalize_sets_list(
@@ -356,11 +339,10 @@ def normalize_collections(
356339
continue
357340
return []
358341

359-
for field in ('id', 'name', 'description', 'node_color', 'edge_color'):
360-
if field in ('node_color', 'edge_color'):
361-
_normalize_color_field(normalized_entry, field, validate_mode, warn, idx)
362-
else:
363-
_coerce_str_field(normalized_entry, field, validate_mode, warn, idx)
342+
for field in ('id', 'name', 'description'):
343+
_normalize_str_field(normalized_entry, field, validate_mode, warn, idx, autofix_drop=False)
344+
for field in ('node_color', 'edge_color'):
345+
_normalize_str_field(normalized_entry, field, validate_mode, warn, idx, autofix_drop=True)
364346

365347
expr = normalized_entry.get('expr')
366348
if collection_type == 'intersection':
@@ -372,6 +354,11 @@ def normalize_collections(
372354
continue
373355
return []
374356
normalized_entry['expr'] = normalized_expr
357+
normalized_entry = {
358+
key: normalized_entry[key]
359+
for key in _ALLOWED_COLLECTION_FIELDS_ORDER
360+
if key in normalized_entry
361+
}
375362
normalized.append(normalized_entry)
376363

377364
return normalized
@@ -388,7 +375,7 @@ def normalize_collections_url_params(
388375
if 'collections' in updated:
389376
normalized = normalize_collections(updated['collections'], validate_mode, warn)
390377
if len(normalized) > 0:
391-
updated['collections'] = encode_collections(normalized, encode=True)
378+
updated['collections'] = encode_collections(normalized)
392379
else:
393380
if validate_mode in ('strict', 'strict-fast'):
394381
return updated

0 commit comments

Comments
 (0)