Skip to content

Commit 57c1354

Browse files
Merge pull request #43 from MichaelisTrofficus/new-traversing-structure
New traversing structure
2 parents 58e2c7a + 2a05cf4 commit 57c1354

File tree

15 files changed

+602
-275
lines changed

15 files changed

+602
-275
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[flake8]
22
select = B,B9,C,D,DAR,E,F,N,RST,S,W
3-
ignore = E203,E501,RST201,RST203,RST301,W503,D100,D103,D102,D101,D107,D105,D205,D212,D415,D104,N803,S605
3+
ignore = E203,E501,RST201,RST203,RST301,W503,D100,D103,D102,D101,D107,D105,D205,D212,D415,D104,N803,S605,N802,C901,B904
44
max-line-length = 120
55
max-complexity = 10
66
docstring-convention = google

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,4 @@ pyrightconfig.json
328328
chatgpt_experiment.py
329329
chatgpt_battery.py
330330
chatgpt_battery_classes.py
331+
*.diff

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def coverage(session: Session) -> None:
159159
"coverage",
160160
*args,
161161
"--omit=tests/*,src/gpt4docstrings/__main__.py,src/gpt4docstrings/cli.py",
162-
"--fail-under=40",
162+
"--fail-under=30",
163163
)
164164

165165

poetry.lock

Lines changed: 72 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "gpt4docstrings"
3-
version = "0.1.2"
3+
version = "0.1.3"
44
description = "gpt4docstrings"
55
authors = ["Miguel Otero Pedrido <miguel.otero.pedrido.1993@gmail.com>"]
66
license = "MIT"
@@ -25,6 +25,11 @@ pytest-mock = "^3.11.1"
2525
tabulate = "^0.9.0"
2626
colorama = "^0.4.6"
2727
langchain = "^0.0.311"
28+
toml = "^0.10.2"
29+
asyncio = "^3.4.3"
30+
aiofiles = "^23.2.1"
31+
reorder-python-imports = "^3.12.0"
32+
astor = "^0.8.1"
2833

2934
[tool.poetry.dev-dependencies]
3035
Pygments = ">=2.10.0"

src/gpt4docstrings/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import click
44

55
import gpt4docstrings
6+
from gpt4docstrings.config import GPT4DocstringsConfig
67

78

89
@click.option(
@@ -73,5 +74,6 @@ def main(paths, **kwargs):
7374
docstring_style=kwargs["style"],
7475
api_key=kwargs["api_key"],
7576
verbose=kwargs["verbose"],
77+
config=GPT4DocstringsConfig(),
7678
)
7779
docstrings_generator.generate_docstrings()

src/gpt4docstrings/config.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2020 Lynn Root
2+
"""Configuration-related helpers."""
3+
# Adapted from Black https://github.com/psf/black/blob/master/black.py.
4+
# This code is adapted from Lynn Root's interrogate library: https://github.com/econchick/interrogate
5+
import pathlib
6+
7+
import attr
8+
9+
10+
# TODO: idea: break out GPT4DocstringsConfig into two classes: one for
11+
# running the tool, one for reporting the results
12+
@attr.s
13+
class GPT4DocstringsConfig:
14+
"""
15+
Configuration related to interrogating a given codebase.
16+
17+
Args:
18+
update_file (bool): If `True`, the documented file will be updated in place. This may be dangerous
19+
in some cases, so it's better to use patches to confirm the changes to be applied.
20+
ignore_private (bool): Ignore private classes, methods, and
21+
functions starting with two underscores.
22+
ignore_semiprivate (bool): Ignore semiprivate classes, methods,
23+
and functions starting with a single underscore.
24+
ignore_init_method (bool): Ignore ``__init__`` method of
25+
classes.
26+
ignore_nested_functions (bool): Ignore nested functions and
27+
methods.
28+
"""
29+
30+
update_file = attr.ib(default=False)
31+
ignore_private = attr.ib(default=False)
32+
ignore_semiprivate = attr.ib(default=False)
33+
ignore_init_method = attr.ib(default=False)
34+
ignore_nested_classes = attr.ib(default=False)
35+
ignore_nested_functions = attr.ib(default=False)
36+
ignore_property_setters = attr.ib(default=False)
37+
ignore_property_decorators = attr.ib(default=False)
38+
39+
40+
def find_project_root(srcs):
41+
"""
42+
Return a directory containing .git, .hg, or pyproject.toml.
43+
That directory can be one of the directories passed in `srcs` or their
44+
common parent.
45+
If no directory in the tree contains a marker that would specify it's the
46+
project root, the root of the file system is returned.
47+
"""
48+
if not srcs:
49+
return pathlib.Path("/").resolve()
50+
51+
common_base = min(pathlib.Path(src).resolve() for src in srcs)
52+
if common_base.is_dir():
53+
# Append a fake file so `parents` below returns `common_base`, too.
54+
common_base /= "fake-file"
55+
56+
for directory in common_base.parents:
57+
if (directory / ".git").exists():
58+
return directory
59+
60+
if (directory / ".hg").is_dir():
61+
return directory
62+
63+
if (directory / "pyproject.toml").is_file():
64+
return directory
65+
66+
return directory
Lines changed: 31 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1+
import ast
12
import os
23
import textwrap
34

45
import openai
56
from langchain.chat_models import ChatOpenAI
67
from langchain.prompts import PromptTemplate
7-
from redbaron import RedBaron
88

9-
from gpt4docstrings import utils
10-
from gpt4docstrings.docstrings_generators.utils.decorators import retry
9+
from gpt4docstrings.docstrings_generators.docstring import Docstring
1110
from gpt4docstrings.docstrings_generators.utils.parsers import DocstringParser
1211
from gpt4docstrings.docstrings_generators.utils.prompts import CLASS_PROMPTS
1312
from gpt4docstrings.docstrings_generators.utils.prompts import FUNCTION_PROMPTS
14-
from gpt4docstrings.exceptions import ASTError
13+
from gpt4docstrings.visit import GPT4DocstringsNode
1514

1615

1716
class ChatGPTDocstringGenerator:
@@ -32,12 +31,13 @@ def __init__(
3231
self.model_name = model_name
3332
self.docstring_style = docstring_style
3433

35-
self.model = ChatOpenAI(model_name=model_name, temperature=1.0)
34+
self.model = ChatOpenAI(
35+
model_name=model_name, temperature=1.0, openai_api_key=self.api_key
36+
)
3637
self.function_prompt_template = FUNCTION_PROMPTS.get(docstring_style)
3738
self.class_prompt_template = CLASS_PROMPTS.get(docstring_style)
3839

39-
@retry(max_retries=5, delay=5)
40-
def _get_completion(self, prompt: str) -> str:
40+
async def _get_completion(self, prompt: str) -> str:
4141
"""
4242
Generates a completion using the ChatGPT model.
4343
@@ -47,86 +47,42 @@ def _get_completion(self, prompt: str) -> str:
4747
Returns:
4848
str: The generated completion.
4949
"""
50-
return self.model.predict(prompt).strip()
50+
return await self.model.apredict(prompt)
51+
52+
def _get_template(self, node: GPT4DocstringsNode):
53+
"""Returns a function template or a class template depending on the node type"""
54+
if node.node_type in ["FunctionDef", "AsyncFunctionDef"]:
55+
return self.function_prompt_template
56+
else:
57+
return self.class_prompt_template
5158

52-
def generate_function_docstring(self, source: str) -> dict:
59+
async def generate_docstring(self, node: GPT4DocstringsNode) -> Docstring:
5360
"""
5461
Generates a docstring for a function or class node.
5562
5663
Args:
57-
source (str): The source code of the function.
64+
node (GPT4DocstringsNode): A GPT4DocstringsNode node
5865
5966
Returns:
60-
dict: A dictionary containing the generated docstring.
61-
62-
Raises:
63-
ASTError: Raises an ASTError when there are errors interacting with an AST node
67+
Docstring: A Docstring object
6468
"""
65-
source = source.strip()
69+
source = node.source.strip()
6670
stripped_source = textwrap.dedent(source)
67-
prompt = PromptTemplate(
68-
template=self.function_prompt_template,
69-
input_variables=["code"],
70-
)
71-
_input = prompt.format_prompt(code=stripped_source)
72-
fn_src = DocstringParser().parse(self._get_completion(_input.to_string()))
73-
74-
try:
75-
fn_node = RedBaron(fn_src).find_all("def")[0]
76-
return {
77-
"docstring": utils.add_indentation_to_docstring(
78-
'"""' + textwrap.dedent(fn_node[0].to_python()) + '"""',
79-
fn_node[0].indentation,
80-
)
81-
}
82-
except ValueError as e:
83-
raise ASTError(
84-
f"Some error has occurred when trying to parse the current AST node: {e}"
85-
) from e
86-
87-
def generate_class_docstring(self, source: str) -> dict:
88-
"""
89-
Generates docstrings for a class.
71+
prompt_template = self._get_template(node)
72+
parent_offset = node.col_offset
9073

91-
Args:
92-
source (str): The source code of the class.
93-
94-
Returns:
95-
dict: A dictionary containing the generated docstrings.
96-
97-
Raises:
98-
ASTError: Raises an ASTError when there are errors interacting with an AST node
99-
"""
100-
source = source.strip()
101-
stripped_source = textwrap.dedent(source)
10274
prompt = PromptTemplate(
103-
template=self.class_prompt_template,
75+
template=prompt_template,
10476
input_variables=["code"],
10577
)
10678
_input = prompt.format_prompt(code=stripped_source)
107-
class_src = DocstringParser().parse(self._get_completion(_input.to_string()))
108-
109-
# TODO: Add here access to class node explicitly.
110-
try:
111-
class_node = RedBaron(class_src).find_all("class")[0]
112-
method_nodes = [f for f in class_node.find_all("def")]
113-
114-
docstrings = {}
115-
for method_node in method_nodes:
116-
docstrings[method_node.name] = utils.add_indentation_to_docstring(
117-
'"""' + textwrap.dedent(method_node[0].to_python()) + '"""',
118-
method_node[0].indentation,
79+
src = DocstringParser().parse(await self._get_completion(_input.to_string()))
80+
81+
tree = ast.parse(src)
82+
for n in ast.walk(tree):
83+
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
84+
return Docstring(
85+
text=ast.get_docstring(n),
86+
col_offset=n.body[-1].col_offset + parent_offset,
87+
lineno=node.lineno,
11988
)
120-
121-
docstrings["docstring"] = class_node.value[0]
122-
docstrings["docstring"] = utils.add_indentation_to_docstring(
123-
'"""' + textwrap.dedent(class_node[0].to_python()) + '"""',
124-
class_node[0].indentation,
125-
)
126-
127-
return docstrings
128-
129-
except ValueError as e:
130-
raise ASTError(
131-
f"Some error has occurred when trying to parse the current AST node: {e}"
132-
) from e

0 commit comments

Comments
 (0)