1414
1515logger = logging .getLogger (__name__ )
1616
17+ # Language-agnostic base classes whose language-specific subclasses are not named Py<Base>/TS<Base>.
17+ # Maps the base class name to the stem used by its subclasses (e.g. "SourceFile" -> "PyFile" / "TSFile").
18+ SPECIAL_BASE_CLASSES = {"SourceFile" : "File" }
19+
1720
1821def sanitize_docstring_for_markdown (docstring : str | None ) -> str :
1922 """Sanitize the docstring for MDX"""
@@ -82,6 +85,9 @@ def is_language_base_class(cls_obj: Class):
8285 Returns:
8386 bool: if `cls_obj` is a language base class
8487 """
88+ if cls_obj .name in SPECIAL_BASE_CLASSES :
89+ return True
90+
8591 sub_classes = cls_obj .subclasses (max_depth = 1 )
8692 base_name = cls_obj .name .lower ()
8793 return any (sub_class .name .lower () in [f"py{ base_name } " , f"ts{ base_name } " ] for sub_class in sub_classes )
@@ -184,24 +190,53 @@ def has_documentation(c: Class):
184190 return any ([dec .name == "ts_apidoc" or dec .name == "py_apidoc" or dec .name == "apidoc" for dec in c .decorators ])
185191
186192
187- def safe_get_class (codebase : Codebase , class_name : str ) -> Class | None :
188- symbols = codebase .get_symbols (class_name )
189- if not symbols :
190- return None
191-
192- if len (symbols ) == 1 and isinstance (symbols [0 ], Class ):
193- return symbols [0 ]
194-
195- possible_classes = [s for s in symbols if isinstance (s , Class ) and has_documentation (s )]
196- if not possible_classes :
197- return None
198- if len (possible_classes ) == 1 :
199- return possible_classes [0 ]
200- msg = f"Found { len (possible_classes )} classes with name { class_name } "
201- raise ValueError (msg )
def safe_get_class(codebase: Codebase, class_name: str, language: str | None = None) -> Class | None:
    """Find a class in the codebase, optionally narrowed to a language-specific subclass.

    Args:
        codebase (Codebase): the codebase to search in
        class_name (str): the name of the class to resolve; surrounding quote
            characters (single or double) are ignored
        language (str | None): the language of the class to resolve. When given and the
            resolved class is a language base class, the matching ``Py<Name>`` /
            ``TS<Name>`` direct subclass is returned instead, if exactly one exists.

    Returns:
        Class | None: the class if found, None otherwise

    Raises:
        ValueError: if the direct lookup fails and more than one documented class
            shares the name.
    """
    # Names may arrive still wrapped in quotes; drop both kinds in one pass so
    # mixed quoting such as '"Foo"' is fully unwrapped as well.
    class_name = class_name.strip("\"'")

    try:
        class_obj = codebase.get_class(class_name, optional=True)
    except Exception:
        # NOTE(review): presumably get_class raises when the name is ambiguous — confirm.
        # Fall back to the documented classes among all symbols with that name.
        documented = [s for s in codebase.get_symbols(class_name) if isinstance(s, Class) and has_documentation(s)]
        if not documented:
            return None
        if len(documented) > 1:
            msg = f"Found {len(documented)} classes with name {class_name}"
            raise ValueError(msg)
        class_obj = documented[0]
    else:
        if not class_obj:
            return None

    if not language or not is_language_base_class(class_obj):
        return class_obj

    # Only narrow for languages whose subclass prefix is known; for any other value the
    # base class is returned unchanged rather than an arbitrary sole subclass.
    prefix_by_language = {
        ProgrammingLanguage.PYTHON.value: "Py",
        ProgrammingLanguage.TYPESCRIPT.value: "TS",
    }
    prefix = prefix_by_language.get(language)
    if prefix is None:
        return class_obj

    # Some base classes use a different stem in their subclass names.
    stem = SPECIAL_BASE_CLASSES.get(class_name, class_name)
    expected_name = f"{prefix}{stem}"
    matches = [s for s in class_obj.subclasses(max_depth=1) if s.name == expected_name]
    if len(matches) == 1:
        return matches[0]
    return class_obj
237+
238+
239+ def resolve_type_symbol (codebase : Codebase , symbol_name : str , resolved_types : list [Type ], parent_class : Class , parent_symbol : Symbol , types_cache : dict ):
205240 """Find the symbol in the codebase.
206241
207242 Args:
@@ -217,11 +252,13 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type]
217252 return symbol_name
218253 if symbol_name .lower () == "self" :
219254 return f"<{ create_path (parent_class )} >"
220- if symbol_name in types_cache :
221- return types_cache [symbol_name ]
255+
256+ language = get_langauge (parent_class )
257+ if (symbol_name , language ) in types_cache :
258+ return types_cache [(symbol_name , language )]
222259
223260 trgt_symbol = None
224- cls_obj = safe_get_class (codebase , symbol_name )
261+ cls_obj = safe_get_class (codebase = codebase , class_name = symbol_name , language = language )
225262 if cls_obj :
226263 trgt_symbol = cls_obj
227264
@@ -230,8 +267,8 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type]
230267 for resolved_type in symbol .resolved_types :
231268 if isinstance (resolved_type , FunctionCall ) and len (resolved_type .args ) >= 2 :
232269 bound_arg = resolved_type .args [1 ]
233- bound_name = bound_arg .value
234- if cls_obj := safe_get_class (codebase , bound_name ):
270+ bound_name = bound_arg .value . source
271+ if cls_obj := safe_get_class (codebase , bound_name , language = get_langauge ( parent_class ) ):
235272 trgt_symbol = cls_obj
236273 break
237274
@@ -241,7 +278,7 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type]
241278
242279 if trgt_symbol and isinstance (trgt_symbol , Callable ) and has_documentation (trgt_symbol ):
243280 trgt_path = f"<{ create_path (trgt_symbol )} >"
244- types_cache [symbol_name ] = trgt_path
281+ types_cache [( symbol_name , language ) ] = trgt_path
245282 return trgt_path
246283
247284 return symbol_name
@@ -318,10 +355,12 @@ def process_parts(content):
318355 base_type = part [: part .index ("[" )]
319356 bracket_content = part [part .index ("[" ) :].strip ("[]" )
320357 processed_bracket = process_parts (bracket_content )
321- replacement = find_symbol (codebase = codebase , symbol_name = base_type , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache )
358+ replacement = resolve_type_symbol (
359+ codebase = codebase , symbol_name = base_type , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache
360+ )
322361 processed_part = replacement + "[" + processed_bracket + "]"
323362 else :
324- replacement = find_symbol (codebase = codebase , symbol_name = part , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache )
363+ replacement = resolve_type_symbol (codebase = codebase , symbol_name = part , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache )
325364 processed_part = replacement
326365 processed_parts .append (processed_part )
327366
@@ -340,9 +379,30 @@ def process_parts(content):
340379 base_type = input_str [: input_str .index ("[" )]
341380 bracket_content = input_str [input_str .index ("[" ) :].strip ("[]" )
342381 processed_content = process_parts (bracket_content )
343- replacement = find_symbol (codebase = codebase , symbol_name = base_type , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache )
382+ replacement = resolve_type_symbol (codebase = codebase , symbol_name = base_type , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache )
344383 return replacement + "[" + processed_content + "]"
345384 # Handle simple input
346385 else :
347- replacement = find_symbol (codebase = codebase , symbol_name = input_str , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache )
386+ replacement = resolve_type_symbol (codebase = codebase , symbol_name = input_str , resolved_types = resolved_types , parent_class = parent_class , parent_symbol = parent_symbol , types_cache = types_cache )
348387 return replacement
388+
389+
def extract_class_description(docstring):
    """Return the descriptive part of a class docstring as a single line.

    Everything from the first "Attributes:" marker onwards is discarded; the
    remaining lines are stripped, blank lines are dropped, and the rest is
    joined with single spaces.

    Args:
        docstring (str): The class docstring to parse

    Returns:
        str: The class description with whitespace normalized, or "" when the
            docstring is empty or missing
    """
    if not docstring:
        return ""

    # Keep only the text that precedes the first "Attributes:" marker.
    head, _, _ = docstring.partition("Attributes:")

    # Strip each line and drop the ones that end up empty.
    stripped_lines = (raw.strip() for raw in head.strip().splitlines())
    return " ".join(text for text in stripped_lines if text)
0 commit comments