diff --git a/guardrails/schema/rail_schema.py b/guardrails/schema/rail_schema.py
index 922a993d8..da0958ed6 100644
--- a/guardrails/schema/rail_schema.py
+++ b/guardrails/schema/rail_schema.py
@@ -19,6 +19,20 @@
 from guardrails.utils.xml_utils import xml_to_string
 from guardrails.validator_base import OnFailAction, Validator
 
+_VALIDATION_TYPE_MAPPING = {
+    RailTypes.STRING: ValidationType(SimpleTypes.STRING),
+    RailTypes.INTEGER: ValidationType(SimpleTypes.INTEGER),
+    RailTypes.FLOAT: ValidationType(SimpleTypes.NUMBER),
+    RailTypes.BOOL: ValidationType(SimpleTypes.BOOLEAN),
+    RailTypes.DATE: ValidationType(SimpleTypes.STRING),
+    RailTypes.TIME: ValidationType(SimpleTypes.STRING),
+    RailTypes.DATETIME: ValidationType(SimpleTypes.STRING),
+    RailTypes.PERCENTAGE: ValidationType(SimpleTypes.STRING),
+    RailTypes.ENUM: ValidationType(SimpleTypes.STRING),
+    RailTypes.LIST: ValidationType(SimpleTypes.ARRAY),
+    RailTypes.OBJECT: ValidationType(SimpleTypes.OBJECT),
+}
+
 
 ### RAIL to JSON Schema ###
 STRING_TAGS = [
@@ -107,43 +121,65 @@ def parse_element(
 ) -> ModelSchema:
     """Takes an XML element Extracts validators to add to the 'validators' list
     and validator_map Returns a ModelSchema."""
+
     schema_type = element.tag
     if element.tag in STRING_TAGS:
         schema_type = RailTypes.STRING
     elif element.tag == "output":
         schema_type: str = element.attrib.get("type", RailTypes.OBJECT)  # type: ignore
 
-    description = xml_to_string(element.attrib.get("description"))
+    # Fast path: avoid xml_to_string if possible, and use .get directly
+    description_raw = element.attrib.get("description")
+    description = (
+        description_raw
+        if (description_raw is None or isinstance(description_raw, str))
+        else xml_to_string(description_raw)
+    )
 
     # Extract validators from RAIL and assign into ProcessedSchema
     extract_validators(element, processed_schema, json_path)
 
     json_path = json_path.replace(".*", "")
 
+    # Consolidate ModelSchema construction using mapping where possible
     if schema_type == RailTypes.STRING:
-        format = xml_to_string(element.attrib.get("format"))
+        format_raw = element.attrib.get("format")
+        format = (
+            format_raw
+            if (format_raw is None or isinstance(format_raw, str))
+            else xml_to_string(format_raw)
+        )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.STRING),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.STRING],
             description=description,
             format=format,
         )
     elif schema_type == RailTypes.INTEGER:
-        format = xml_to_string(element.attrib.get("format"))
+        format_raw = element.attrib.get("format")
+        format = (
+            format_raw
+            if (format_raw is None or isinstance(format_raw, str))
+            else xml_to_string(format_raw)
+        )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.INTEGER),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.INTEGER],
             description=description,
             format=format,
         )
     elif schema_type == RailTypes.FLOAT:
-        format = xml_to_string(element.attrib.get("format", RailTypes.FLOAT))
+        format_raw = element.attrib.get("format", RailTypes.FLOAT)
+        format = (
+            format_raw if (isinstance(format_raw, str)) else xml_to_string(format_raw)
+        )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.NUMBER),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.FLOAT],
             description=description,
             format=format,
         )
     elif schema_type == RailTypes.BOOL:
         return ModelSchema(
-            type=ValidationType(SimpleTypes.BOOLEAN), description=description
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.BOOL],
+            description=description,
         )
     elif schema_type == RailTypes.DATE:
         format = extract_format(
@@ -152,7 +188,7 @@ def parse_element(
             internal_format_attr="date-format",
         )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.STRING),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.DATE],
             description=description,
             format=format,
         )
@@ -163,7 +199,7 @@ def parse_element(
             internal_format_attr="time-format",
         )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.STRING),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.TIME],
             description=description,
             format=format,
         )
@@ -174,7 +210,7 @@ def parse_element(
             internal_format_attr="datetime-format",
         )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.STRING),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.DATETIME],
             description=description,
             format=format,
         )
@@ -185,22 +221,28 @@ def parse_element(
             internal_format_attr="",
         )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.STRING),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.PERCENTAGE],
             description=description,
             format=format,
         )
     elif schema_type == RailTypes.ENUM:
-        format = xml_to_string(element.attrib.get("format"))
-        csv = xml_to_string(element.attrib.get("values", "")) or ""
+        format_raw = element.attrib.get("format")
+        format = (
+            format_raw
+            if (format_raw is None or isinstance(format_raw, str))
+            else xml_to_string(format_raw)
+        )
+        csv_raw = element.attrib.get("values", "")
+        csv = csv_raw if (isinstance(csv_raw, str)) else (xml_to_string(csv_raw) or "")
         values = [v.strip() for v in csv.split(",")] if csv else None
         return ModelSchema(
-            type=ValidationType(SimpleTypes.STRING),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.ENUM],
             description=description,
             format=format,
             enum=values,
         )
     elif schema_type == RailTypes.LIST:
         items = None
         children = list(element)
         num_of_children = len(children)
         if num_of_children > 1:
@@ -216,9 +258,12 @@ def parse_element(
             )
             items = child_schema.to_dict()
         return ModelSchema(
-            type=ValidationType(SimpleTypes.ARRAY), items=items, description=description
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.LIST],
+            items=items,
+            description=description,
         )
     elif schema_type == RailTypes.OBJECT:
+        # Build the properties map from the element's named children
         properties = {}
         required: List[str] = []
         for child in element:
@@ -226,6 +271,7 @@ def parse_element(
             child_required = child.get("required", "true") == "true"
             if not name:
                 output_path = json_path.replace("$.", "output.")
+                # Report the location with the user-facing "output." prefix
                 logger.warning(
                     f"{output_path} has a nameless child which is not allowed!"
                 )
@@ -236,7 +282,7 @@ def parse_element(
             properties[name] = child_schema.to_dict()
 
         object_schema = ModelSchema(
-            type=ValidationType(SimpleTypes.OBJECT),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.OBJECT],
             properties=properties,
             description=description,
             required=required,
@@ -264,7 +310,7 @@ def parse_element(
         if not discriminator:
             raise ValueError(" elements must specify a discriminator!")
         discriminator_model = ModelSchema(
-            type=ValidationType(SimpleTypes.STRING), enum=[]
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.STRING], enum=[]
         )
         for choice_case in element:
             case_name = choice_case.get("name")
@@ -309,7 +355,7 @@ def parse_element(
         properties = {}
         properties[discriminator] = discriminator_model.to_dict()
         return ModelSchema(
-            type=ValidationType(SimpleTypes.OBJECT),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.OBJECT],
             properties=properties,
             required=[discriminator],
             allOf=allOf,
@@ -317,9 +363,12 @@ def parse_element(
         )
     else:
         # TODO: What if the user specifies a custom tag _and_ a format?
-        format = xml_to_string(element.attrib.get("format", schema_type))
+        format_raw = element.attrib.get("format", schema_type)
+        format = (
+            format_raw if (isinstance(format_raw, str)) else xml_to_string(format_raw)
+        )
         return ModelSchema(
-            type=ValidationType(SimpleTypes.STRING),
+            type=_VALIDATION_TYPE_MAPPING[RailTypes.STRING],
             description=description,
             format=format,
         )