66from copy import copy
77from enum import Enum
88from inspect import getdoc , isclass
9- from typing import TYPE_CHECKING , Any , Callable , List , Optional , Union , get_args , get_origin
9+ from typing import TYPE_CHECKING , Any , Callable , List , Optional , Union , get_args , get_origin , get_type_hints
1010
1111from docstring_parser import parse
1212from pydantic import BaseModel , create_model
@@ -53,35 +53,38 @@ class PydanticDataType(Enum):
5353
5454
5555def map_pydantic_type_to_gbnf (pydantic_type : type [Any ]) -> str :
56- if isclass (pydantic_type ) and issubclass (pydantic_type , str ):
56+ origin_type = get_origin (pydantic_type )
57+ origin_type = pydantic_type if origin_type is None else origin_type
58+
59+ if isclass (origin_type ) and issubclass (origin_type , str ):
5760 return PydanticDataType .STRING .value
58- elif isclass (pydantic_type ) and issubclass (pydantic_type , bool ):
61+ elif isclass (origin_type ) and issubclass (origin_type , bool ):
5962 return PydanticDataType .BOOLEAN .value
60- elif isclass (pydantic_type ) and issubclass (pydantic_type , int ):
63+ elif isclass (origin_type ) and issubclass (origin_type , int ):
6164 return PydanticDataType .INTEGER .value
62- elif isclass (pydantic_type ) and issubclass (pydantic_type , float ):
65+ elif isclass (origin_type ) and issubclass (origin_type , float ):
6366 return PydanticDataType .FLOAT .value
64- elif isclass (pydantic_type ) and issubclass (pydantic_type , Enum ):
67+ elif isclass (origin_type ) and issubclass (origin_type , Enum ):
6568 return PydanticDataType .ENUM .value
6669
67- elif isclass (pydantic_type ) and issubclass (pydantic_type , BaseModel ):
68- return format_model_and_field_name (pydantic_type .__name__ )
69- elif get_origin ( pydantic_type ) is list :
70+ elif isclass (origin_type ) and issubclass (origin_type , BaseModel ):
71+ return format_model_and_field_name (origin_type .__name__ )
72+ elif origin_type is list :
7073 element_type = get_args (pydantic_type )[0 ]
7174 return f"{ map_pydantic_type_to_gbnf (element_type )} -list"
72- elif get_origin ( pydantic_type ) is set :
75+ elif origin_type is set :
7376 element_type = get_args (pydantic_type )[0 ]
7477 return f"{ map_pydantic_type_to_gbnf (element_type )} -set"
75- elif get_origin ( pydantic_type ) is Union :
78+ elif origin_type is Union :
7679 union_types = get_args (pydantic_type )
7780 union_rules = [map_pydantic_type_to_gbnf (ut ) for ut in union_types ]
7881 return f"union-{ '-or-' .join (union_rules )} "
79- elif get_origin ( pydantic_type ) is Optional :
82+ elif origin_type is Optional :
8083 element_type = get_args (pydantic_type )[0 ]
8184 return f"optional-{ map_pydantic_type_to_gbnf (element_type )} "
82- elif isclass (pydantic_type ):
83- return f"{ PydanticDataType .CUSTOM_CLASS .value } -{ format_model_and_field_name (pydantic_type .__name__ )} "
84- elif get_origin ( pydantic_type ) is dict :
85+ elif isclass (origin_type ):
86+ return f"{ PydanticDataType .CUSTOM_CLASS .value } -{ format_model_and_field_name (origin_type .__name__ )} "
87+ elif origin_type is dict :
8588 key_type , value_type = get_args (pydantic_type )
8689 return f"custom-dict-key-type-{ format_model_and_field_name (map_pydantic_type_to_gbnf (key_type ))} -value-type-{ format_model_and_field_name (map_pydantic_type_to_gbnf (value_type ))} "
8790 else :
@@ -118,7 +121,7 @@ def get_members_structure(cls, rule_name):
118121 # Modify this comprehension
119122 members = [
120123 f' "\\ "{ name } \\ "" ":" { map_pydantic_type_to_gbnf (param_type )} '
121- for name , param_type in cls . __annotations__ .items ()
124+ for name , param_type in get_type_hints ( cls ) .items ()
122125 if name != "self"
123126 ]
124127
@@ -297,17 +300,20 @@ def generate_gbnf_rule_for_type(
297300 field_name = format_model_and_field_name (field_name )
298301 gbnf_type = map_pydantic_type_to_gbnf (field_type )
299302
300- if isclass (field_type ) and issubclass (field_type , BaseModel ):
303+ origin_type = get_origin (field_type )
304+ origin_type = field_type if origin_type is None else origin_type
305+
306+ if isclass (origin_type ) and issubclass (origin_type , BaseModel ):
301307 nested_model_name = format_model_and_field_name (field_type .__name__ )
302308 nested_model_rules , _ = generate_gbnf_grammar (field_type , processed_models , created_rules )
303309 rules .extend (nested_model_rules )
304310 gbnf_type , rules = nested_model_name , rules
305- elif isclass (field_type ) and issubclass (field_type , Enum ):
311+ elif isclass (origin_type ) and issubclass (origin_type , Enum ):
306312 enum_values = [f'"\\ "{ e .value } \\ ""' for e in field_type ] # Adding escaped quotes
307313 enum_rule = f"{ model_name } -{ field_name } ::= { ' | ' .join (enum_values )} "
308314 rules .append (enum_rule )
309315 gbnf_type , rules = model_name + "-" + field_name , rules
310- elif get_origin ( field_type ) == list : # Array
316+ elif origin_type is list : # Array
311317 element_type = get_args (field_type )[0 ]
312318 element_rule_name , additional_rules = generate_gbnf_rule_for_type (
313319 model_name , f"{ field_name } -element" , element_type , is_optional , processed_models , created_rules
@@ -317,7 +323,7 @@ def generate_gbnf_rule_for_type(
317323 rules .append (array_rule )
318324 gbnf_type , rules = model_name + "-" + field_name , rules
319325
320- elif get_origin ( field_type ) == set or field_type == set : # Array
326+ elif origin_type is set : # Array
321327 element_type = get_args (field_type )[0 ]
322328 element_rule_name , additional_rules = generate_gbnf_rule_for_type (
323329 model_name , f"{ field_name } -element" , element_type , is_optional , processed_models , created_rules
@@ -371,7 +377,7 @@ def generate_gbnf_rule_for_type(
371377 gbnf_type = f"{ model_name } -{ field_name } -optional"
372378 else :
373379 gbnf_type = f"{ model_name } -{ field_name } -union"
374- elif isclass (field_type ) and issubclass (field_type , str ):
380+ elif isclass (origin_type ) and issubclass (origin_type , str ):
375381 if field_info and hasattr (field_info , "json_schema_extra" ) and field_info .json_schema_extra is not None :
376382 triple_quoted_string = field_info .json_schema_extra .get ("triple_quoted_string" , False )
377383 markdown_string = field_info .json_schema_extra .get ("markdown_code_block" , False )
@@ -387,8 +393,8 @@ def generate_gbnf_rule_for_type(
387393 gbnf_type = PydanticDataType .STRING .value
388394
389395 elif (
390- isclass (field_type )
391- and issubclass (field_type , float )
396+ isclass (origin_type )
397+ and issubclass (origin_type , float )
392398 and field_info
393399 and hasattr (field_info , "json_schema_extra" )
394400 and field_info .json_schema_extra is not None
@@ -413,8 +419,8 @@ def generate_gbnf_rule_for_type(
413419 )
414420
415421 elif (
416- isclass (field_type )
417- and issubclass (field_type , int )
422+ isclass (origin_type )
423+ and issubclass (origin_type , int )
418424 and field_info
419425 and hasattr (field_info , "json_schema_extra" )
420426 and field_info .json_schema_extra is not None
@@ -462,15 +468,15 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
462468 if not issubclass (model , BaseModel ):
463469 # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
464470 if hasattr (model , "__annotations__" ) and model .__annotations__ :
465- model_fields = {name : (typ , ...) for name , typ in model . __annotations__ . items ()} # pyright: ignore[reportGeneralTypeIssues]
471+ model_fields = {name : (typ , ...) for name , typ in get_type_hints ( model ). items ()}
466472 else :
467473 init_signature = inspect .signature (model .__init__ )
468474 parameters = init_signature .parameters
469475 model_fields = {name : (param .annotation , param .default ) for name , param in parameters .items () if
470476 name != "self" }
471477 else :
472478 # For Pydantic models, use model_fields and check for ellipsis (required fields)
473- model_fields = model . __annotations__
479+ model_fields = get_type_hints ( model )
474480
475481 model_rule_parts = []
476482 nested_rules = []
@@ -706,7 +712,7 @@ def generate_markdown_documentation(
706712 else :
707713 documentation += f" Fields:\n " # noqa: F541
708714 if isclass (model ) and issubclass (model , BaseModel ):
709- for name , field_type in model . __annotations__ .items ():
715+ for name , field_type in get_type_hints ( model ) .items ():
710716 # if name == "markdown_code_block":
711717 # continue
712718 if get_origin (field_type ) == list :
@@ -754,14 +760,17 @@ def generate_field_markdown(
754760 field_info = model .model_fields .get (field_name )
755761 field_description = field_info .description if field_info and field_info .description else ""
756762
757- if get_origin (field_type ) == list :
763+ origin_type = get_origin (field_type )
764+ origin_type = field_type if origin_type is None else origin_type
765+
766+ if origin_type == list :
758767 element_type = get_args (field_type )[0 ]
759768 field_text = f"{ indent } { field_name } ({ format_model_and_field_name (field_type .__name__ )} of { format_model_and_field_name (element_type .__name__ )} )"
760769 if field_description != "" :
761770 field_text += ":\n "
762771 else :
763772 field_text += "\n "
764- elif get_origin ( field_type ) == Union :
773+ elif origin_type == Union :
765774 element_types = get_args (field_type )
766775 types = []
767776 for element_type in element_types :
@@ -792,9 +801,9 @@ def generate_field_markdown(
792801 example_text = f"'{ field_example } '" if isinstance (field_example , str ) else field_example
793802 field_text += f"{ indent } Example: { example_text } \n "
794803
795- if isclass (field_type ) and issubclass (field_type , BaseModel ):
804+ if isclass (origin_type ) and issubclass (origin_type , BaseModel ):
796805 field_text += f"{ indent } Details:\n "
797- for name , type_ in field_type . __annotations__ .items ():
806+ for name , type_ in get_type_hints ( field_type ) .items ():
798807 field_text += generate_field_markdown (name , type_ , field_type , depth + 2 )
799808
800809 return field_text
@@ -855,7 +864,7 @@ def generate_text_documentation(
855864
856865 if isclass (model ) and issubclass (model , BaseModel ):
857866 documentation_fields = ""
858- for name , field_type in model . __annotations__ .items ():
867+ for name , field_type in get_type_hints ( model ) .items ():
859868 # if name == "markdown_code_block":
860869 # continue
861870 if get_origin (field_type ) == list :
@@ -948,7 +957,7 @@ def generate_field_text(
948957
949958 if isclass (field_type ) and issubclass (field_type , BaseModel ):
950959 field_text += f"{ indent } Details:\n "
951- for name , type_ in field_type . __annotations__ .items ():
960+ for name , type_ in get_type_hints ( field_type ) .items ():
952961 field_text += generate_field_text (name , type_ , field_type , depth + 2 )
953962
954963 return field_text
0 commit comments