Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions guardrails/actions/filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any


class Filter:
Expand All @@ -9,30 +9,19 @@ def apply_filters(value: Any) -> Any:
"""Recursively filter out any values that are instances of Filter."""
if isinstance(value, Filter):
pass
elif isinstance(value, List):
# Cleaner syntax but requires two iterations
# filtered_list = list(filter(None, map(apply_filters, value)))
elif isinstance(value, list):
filtered_list = []
for item in value:
filtered_item = apply_filters(item)
if filtered_item is not None:
filtered_list.append(filtered_item)

return filtered_list
elif isinstance(value, Dict):
# Cleaner syntax but requires two iterations
# filtered_dict = {
# k: apply_filters(v)
# for k, v in value.items()
# if apply_filters(v)
# }
elif isinstance(value, dict):
filtered_dict = {}
for k, v in value.items():
# Should we omit the key or just the value?
filtered_value = apply_filters(v)
if filtered_value is not None:
filtered_dict[k] = filtered_value

return filtered_dict
else:
return value
14 changes: 7 additions & 7 deletions guardrails/actions/refrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ def apply_refrain(value: Any, output_type: OutputTypes) -> Any:

If found, return an empty value of the appropriate type.
"""
refrain_value = {}
if output_type == OutputTypes.STRING:
refrain_value = ""
elif output_type == OutputTypes.LIST:
refrain_value = []

if check_for_refrain(value):
# If the data contains a `Refain` value, we return an empty
# value.
if output_type == OutputTypes.STRING:
refrain_value = ""
elif output_type == OutputTypes.LIST:
refrain_value = []
else:
refrain_value = {}
logger.debug("Refrain detected.")
value = refrain_value
return refrain_value

return value