1+ import ast
12import os
23import textwrap
34
45import openai
56from langchain .chat_models import ChatOpenAI
67from langchain .prompts import PromptTemplate
7- from redbaron import RedBaron
88
9- from gpt4docstrings import utils
10- from gpt4docstrings .docstrings_generators .utils .decorators import retry
9+ from gpt4docstrings .docstrings_generators .docstring import Docstring
1110from gpt4docstrings .docstrings_generators .utils .parsers import DocstringParser
1211from gpt4docstrings .docstrings_generators .utils .prompts import CLASS_PROMPTS
1312from gpt4docstrings .docstrings_generators .utils .prompts import FUNCTION_PROMPTS
14- from gpt4docstrings .exceptions import ASTError
13+ from gpt4docstrings .visit import GPT4DocstringsNode
1514
1615
1716class ChatGPTDocstringGenerator :
@@ -32,12 +31,13 @@ def __init__(
3231 self .model_name = model_name
3332 self .docstring_style = docstring_style
3433
35- self .model = ChatOpenAI (model_name = model_name , temperature = 1.0 )
34+ self .model = ChatOpenAI (
35+ model_name = model_name , temperature = 1.0 , openai_api_key = self .api_key
36+ )
3637 self .function_prompt_template = FUNCTION_PROMPTS .get (docstring_style )
3738 self .class_prompt_template = CLASS_PROMPTS .get (docstring_style )
3839
39- @retry (max_retries = 5 , delay = 5 )
40- def _get_completion (self , prompt : str ) -> str :
40+ async def _get_completion (self , prompt : str ) -> str :
4141 """
4242 Generates a completion using the ChatGPT model.
4343
@@ -47,86 +47,42 @@ def _get_completion(self, prompt: str) -> str:
4747 Returns:
4848 str: The generated completion.
4949 """
50- return self .model .predict (prompt ).strip ()
50+ return await self .model .apredict (prompt )
51+
52+ def _get_template (self , node : GPT4DocstringsNode ):
53+ """Returns a function template or a class template depending on the node type"""
54+ if node .node_type in ["FunctionDef" , "AsyncFunctionDef" ]:
55+ return self .function_prompt_template
56+ else :
57+ return self .class_prompt_template
5158
52- def generate_function_docstring (self , source : str ) -> dict :
59+ async def generate_docstring (self , node : GPT4DocstringsNode ) -> Docstring :
5360 """
5461 Generates a docstring for a function.
5562
5663 Args:
57- source (str ): The source code of the function.
64+ node (GPT4DocstringsNode ): A GPT4DocstringsNode node
5865
5966 Returns:
60- dict: A dictionary containing the generated docstring.
61-
62- Raises:
63- ASTError: Raises an ASTError when there are errors interacting with an AST node
67+ Docstring: A Docstring object
6468 """
65- source = source .strip ()
69+ source = node . source .strip ()
6670 stripped_source = textwrap .dedent (source )
67- prompt = PromptTemplate (
68- template = self .function_prompt_template ,
69- input_variables = ["code" ],
70- )
71- _input = prompt .format_prompt (code = stripped_source )
72- fn_src = DocstringParser ().parse (self ._get_completion (_input .to_string ()))
73-
74- try :
75- fn_node = RedBaron (fn_src ).find_all ("def" )[0 ]
76- return {
77- "docstring" : utils .add_indentation_to_docstring (
78- '"""' + textwrap .dedent (fn_node [0 ].to_python ()) + '"""' ,
79- fn_node [0 ].indentation ,
80- )
81- }
82- except ValueError as e :
83- raise ASTError (
84- f"Some error has occurred when trying to parse the current AST node: { e } "
85- ) from e
86-
87- def generate_class_docstring (self , source : str ) -> dict :
88- """
89- Generates docstrings for a class.
71+ prompt_template = self ._get_template (node )
72+ parent_offset = node .col_offset
9073
91- Args:
92- source (str): The source code of the class.
93-
94- Returns:
95- dict: A dictionary containing the generated docstrings.
96-
97- Raises:
98- ASTError: Raises an ASTError when there are errors interacting with an AST node
99- """
100- source = source .strip ()
101- stripped_source = textwrap .dedent (source )
10274 prompt = PromptTemplate (
103- template = self . class_prompt_template ,
75+ template = prompt_template ,
10476 input_variables = ["code" ],
10577 )
10678 _input = prompt .format_prompt (code = stripped_source )
107- class_src = DocstringParser ().parse (self ._get_completion (_input .to_string ()))
108-
109- # TODO: Add here access to class node explicitly.
110- try :
111- class_node = RedBaron (class_src ).find_all ("class" )[0 ]
112- method_nodes = [f for f in class_node .find_all ("def" )]
113-
114- docstrings = {}
115- for method_node in method_nodes :
116- docstrings [method_node .name ] = utils .add_indentation_to_docstring (
117- '"""' + textwrap .dedent (method_node [0 ].to_python ()) + '"""' ,
118- method_node [0 ].indentation ,
79+ src = DocstringParser ().parse (await self ._get_completion (_input .to_string ()))
80+
81+ tree = ast .parse (src )
82+ for n in ast .walk (tree ):
83+ if isinstance (n , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
84+ return Docstring (
85+ text = ast .get_docstring (n ),
86+ col_offset = n .body [- 1 ].col_offset + parent_offset ,
87+ lineno = node .lineno ,
11988 )
120-
121- docstrings ["docstring" ] = class_node .value [0 ]
122- docstrings ["docstring" ] = utils .add_indentation_to_docstring (
123- '"""' + textwrap .dedent (class_node [0 ].to_python ()) + '"""' ,
124- class_node [0 ].indentation ,
125- )
126-
127- return docstrings
128-
129- except ValueError as e :
130- raise ASTError (
131- f"Some error has occurred when trying to parse the current AST node: { e } "
132- ) from e
0 commit comments