@@ -49,10 +49,10 @@ def digest(self) -> Optional[str]:
4949 return _compute_digest (self .binary_path ).hex ()
5050
5151 def execute_msgs (self ) -> dict :
52- return _extract_msgs (self .execute_schema )
52+ return _extract_msgs (self .execute_schema , self . execute_schema )
5353
5454 def query_msgs (self ) -> dict :
55- return _extract_msgs (self .query_schema )
55+ return _extract_msgs (self .query_schema , self . query_schema )
5656
5757 def init_args (self ) -> dict :
5858 return _get_msg_args (self .instantiate_schema )
@@ -65,7 +65,7 @@ def to_variable_name(cls, contract_name: str) -> str:
6565 return contract_name .replace ("-" , "_" )
6666
6767
68- def _extract_msgs (schema : dict ) -> dict :
68+ def _extract_msgs (schema : dict , root_schema : dict ) -> dict :
6969 msgs = {}
7070 if 'oneOf' in schema :
7171 schemas = schema ['oneOf' ]
@@ -74,9 +74,26 @@ def _extract_msgs(schema: dict) -> dict:
7474 else :
7575 return msgs
7676 for msg_schema in schemas :
77- msg = msg_schema ['required' ][0 ]
78- args = _get_msg_args (msg_schema ['properties' ][msg ])
79- msgs [msg ] = args
77+ if "$ref" in msg_schema :
78+ # Resolve the reference
79+ ref_path = msg_schema ["$ref" ].split ("/" )
80+ assert ref_path [0 ] == "#"
81+ ref_schema = root_schema
82+ for key in ref_path [1 :]:
83+ ref_schema = ref_schema [key ]
84+
85+ # Recursively extract messages
86+ nested_msgs = _extract_msgs (ref_schema , root_schema )
87+
88+ # Ensure no overlapping keys
89+ assert len (msgs .items () & nested_msgs .items ()) == 0 , "Nested messages overlap"
90+
91+ msgs = msgs | nested_msgs
92+ else :
93+ # Direct schema definition
94+ msg = msg_schema ['required' ][0 ]
95+ args = _get_msg_args (msg_schema ['properties' ][msg ])
96+ msgs [msg ] = args
8097 return msgs
8198
8299
0 commit comments