Skip to content

Commit 62671ec

Browse files
MarkDaoustmarkmcd
andauthored
Check format with black. (#31)
* Add a GH action to run the tests. * Add a action to check format with black. * format with black --------- Co-authored-by: Mark McDonald <[email protected]>
1 parent 196601a commit 62671ec

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+5647
-5765
lines changed

.github/workflows/test_pr.yaml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,17 @@ jobs:
5151
python --version
5252
pip install -q -e .[dev]
5353
python -m unittest discover --pattern '*test*.py'
54-
54+
format:
55+
name: Check format with black
56+
runs-on: ubuntu-latest
57+
steps:
58+
- uses: actions/checkout@v3
59+
- uses: actions/setup-python@v4
60+
with:
61+
python-version: '3.11'
62+
- name: Check format
63+
run: |
64+
python --version
65+
pip install -q -e .
66+
pip install -q black
67+
black . --check

docs/build_docs.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@
120120
)
121121

122122
_CODE_URL_PREFIX = flags.DEFINE_string(
123-
"code_url_prefix", "https://github.com/google/generative-ai-python/blob/master/google/generativeai",
124-
"where to find the project code")
123+
"code_url_prefix",
124+
"https://github.com/google/generative-ai-python/blob/master/google/generativeai",
125+
"where to find the project code",
126+
)
127+
125128

126129
class MyFilter:
127130
def __init__(self, base_dirs):
@@ -134,9 +137,7 @@ def drop_staticmethods(self, parent, children):
134137
yield name, value
135138

136139
def __call__(self, path, parent, children):
137-
if (
138-
any("generativelanguage" in part for part in path)
139-
or "generativeai" in path):
140+
if any("generativelanguage" in part for part in path) or "generativeai" in path:
140141
children = self.filter_base_dirs(path, parent, children)
141142
children = public_api.explicit_package_contents_filter(
142143
path, parent, children
@@ -158,7 +159,9 @@ def make_default_filters(self):
158159
public_api.add_proto_fields,
159160
public_api.filter_builtin_modules,
160161
public_api.filter_private_symbols,
161-
MyFilter(self._base_dir), # Replaces: public_api.FilterBaseDirs(self._base_dir),
162+
MyFilter(
163+
self._base_dir
164+
), # Replaces: public_api.FilterBaseDirs(self._base_dir),
162165
public_api.FilterPrivateMap(self._private_map),
163166
public_api.filter_doc_controls_skip,
164167
public_api.ignore_typing,
@@ -186,8 +189,10 @@ def gen_api_docs():
186189
pathlib.Path(google.generativeai.__file__).parent,
187190
pathlib.Path(google.ai.generativelanguage.__file__).parent.parent,
188191
),
189-
code_url_prefix=(_CODE_URL_PREFIX.value,
190-
'https://github.com/googleapis/google-cloud-python/tree/main/packages/google-ai-generativelanguage/google/ai'),
192+
code_url_prefix=(
193+
_CODE_URL_PREFIX.value,
194+
"https://github.com/googleapis/google-cloud-python/tree/main/packages/google-ai-generativelanguage/google/ai",
195+
),
191196
search_hints=_SEARCH_HINTS.value,
192197
site_path=_SITE_PATH.value,
193198
callbacks=[],
@@ -216,12 +221,12 @@ def gen_api_docs():
216221
redirects_path.write_text(yaml.dump(redirects))
217222

218223
# clear `oneof` junk from proto pages
219-
for fpath in out_path.rglob('*.md'):
224+
for fpath in out_path.rglob("*.md"):
220225
old_content = fpath.read_text()
221226
new_content = old_content
222-
new_content = re.sub(r'\.\. _oneof:.*?\n', '', new_content)
223-
new_content = re.sub(r'`oneof`_.*?\n', '', new_content)
224-
new_content = re.sub(r'\.\. code-block:: python.*?\n', '', new_content)
227+
new_content = re.sub(r"\.\. _oneof:.*?\n", "", new_content)
228+
new_content = re.sub(r"`oneof`_.*?\n", "", new_content)
229+
new_content = re.sub(r"\.\. code-block:: python.*?\n", "", new_content)
225230
if new_content != old_content:
226231
fpath.write_text(new_content)
227232

google/generativeai/notebook/__init__.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616

1717

1818
def load_ipython_extension(ipython):
19-
"""Register the Colab Magic extension to support %load_ext."""
20-
# pylint: disable-next=g-import-not-at-top
21-
from google.generativeai.notebook import magics
19+
"""Register the Colab Magic extension to support %load_ext."""
20+
# pylint: disable-next=g-import-not-at-top
21+
from google.generativeai.notebook import magics
2222

23-
ipython.register_magics(magics.Magics)
23+
ipython.register_magics(magics.Magics)
2424

25-
# Since we're in an interactive environment, make the tables prettier.
26-
try:
27-
# pylint: disable-next=g-import-not-at-top
28-
from google import colab
25+
# Since we're in an interactive environment, make the tables prettier.
26+
try:
27+
# pylint: disable-next=g-import-not-at-top
28+
from google import colab
2929

30-
colab.data_table.enable_dataframe_formatter()
31-
except ImportError:
32-
pass
30+
colab.data_table.enable_dataframe_formatter()
31+
except ImportError:
32+
pass

google/generativeai/notebook/argument_parser.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -34,79 +34,79 @@
3434

3535
# pylint: disable-next=g-bad-exception-name
3636
class _ParserBaseException(RuntimeError, metaclass=abc.ABCMeta):
37-
"""Base class for parser exceptions including normal exit."""
37+
"""Base class for parser exceptions including normal exit."""
3838

39-
def __init__(self, msgs: Sequence[str], *args, **kwargs):
40-
super().__init__("".join(msgs), *args, **kwargs)
41-
self._msgs = msgs
42-
self._ipython_env: ipython_env.IPythonEnv | None = None
39+
def __init__(self, msgs: Sequence[str], *args, **kwargs):
40+
super().__init__("".join(msgs), *args, **kwargs)
41+
self._msgs = msgs
42+
self._ipython_env: ipython_env.IPythonEnv | None = None
4343

44-
def set_ipython_env(self, env: ipython_env.IPythonEnv) -> None:
45-
self._ipython_env = env
44+
def set_ipython_env(self, env: ipython_env.IPythonEnv) -> None:
45+
self._ipython_env = env
4646

47-
def _ipython_display_(self):
48-
self.display(self._ipython_env)
47+
def _ipython_display_(self):
48+
self.display(self._ipython_env)
4949

50-
def msgs(self) -> Sequence[str]:
51-
return self._msgs
50+
def msgs(self) -> Sequence[str]:
51+
return self._msgs
5252

53-
@abc.abstractmethod
54-
def display(self, env: ipython_env.IPythonEnv | None) -> None:
55-
"""Display this exception on an IPython console."""
53+
@abc.abstractmethod
54+
def display(self, env: ipython_env.IPythonEnv | None) -> None:
55+
"""Display this exception on an IPython console."""
5656

5757

5858
# ParserNormalExit is not an error: it's a way for ArgumentParser to indicate
5959
# that the user has entered a special request (e.g. "--help") instead of a
6060
# runnable command.
6161
# pylint: disable-next=g-bad-exception-name
6262
class ParserNormalExit(_ParserBaseException):
63-
"""Exception thrown when the parser exits normally.
63+
"""Exception thrown when the parser exits normally.
6464
65-
This is usually thrown when the user requests the help message.
66-
"""
65+
This is usually thrown when the user requests the help message.
66+
"""
6767

68-
def display(self, env: ipython_env.IPythonEnv | None) -> None:
69-
for msg in self._msgs:
70-
print(msg)
68+
def display(self, env: ipython_env.IPythonEnv | None) -> None:
69+
for msg in self._msgs:
70+
print(msg)
7171

7272

7373
class ParserError(_ParserBaseException):
74-
"""Exception thrown when there is an error."""
74+
"""Exception thrown when there is an error."""
7575

76-
def display(self, env: ipython_env.IPythonEnv | None) -> None:
77-
for msg in self._msgs:
78-
print(msg)
79-
if env is not None:
80-
# Highlight to the user that an error has occurred.
81-
env.display_html("<b style='font-family:courier new'>ERROR</b>")
76+
def display(self, env: ipython_env.IPythonEnv | None) -> None:
77+
for msg in self._msgs:
78+
print(msg)
79+
if env is not None:
80+
# Highlight to the user that an error has occurred.
81+
env.display_html("<b style='font-family:courier new'>ERROR</b>")
8282

8383

8484
class ArgumentParser(argparse.ArgumentParser):
85-
"""Customized ArgumentParser for LLM Magics.
86-
87-
This class overrides the parent argparse.ArgumentParser's error-handling
88-
methods to avoid side-effects like printing to stderr. The messages are
89-
accumulated and passed into the raised exceptions for the caller to
90-
handle them.
91-
"""
92-
93-
def __init__(self, *args, **kwargs):
94-
super().__init__(*args, **kwargs)
95-
self._messages: list[str] = []
96-
97-
def _print_message(self, message, file=None):
98-
"""Override ArgumentParser's _print_message() method."""
99-
del file
100-
self._messages.append(message)
101-
102-
def exit(self, status=0, message=None):
103-
"""Override ArgumentParser's exit() method."""
104-
if message:
105-
self._print_message(message)
106-
107-
msgs = self._messages
108-
self._messages = []
109-
if status == 0:
110-
raise ParserNormalExit(msgs=msgs)
111-
else:
112-
raise ParserError(msgs=msgs)
85+
"""Customized ArgumentParser for LLM Magics.
86+
87+
This class overrides the parent argparse.ArgumentParser's error-handling
88+
methods to avoid side-effects like printing to stderr. The messages are
89+
accumulated and passed into the raised exceptions for the caller to
90+
handle them.
91+
"""
92+
93+
def __init__(self, *args, **kwargs):
94+
super().__init__(*args, **kwargs)
95+
self._messages: list[str] = []
96+
97+
def _print_message(self, message, file=None):
98+
"""Override ArgumentParser's _print_message() method."""
99+
del file
100+
self._messages.append(message)
101+
102+
def exit(self, status=0, message=None):
103+
"""Override ArgumentParser's exit() method."""
104+
if message:
105+
self._print_message(message)
106+
107+
msgs = self._messages
108+
self._messages = []
109+
if status == 0:
110+
raise ParserNormalExit(msgs=msgs)
111+
else:
112+
raise ParserError(msgs=msgs)

google/generativeai/notebook/argument_parser_test.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,35 @@
2121

2222

2323
class ArgumentParserTest(absltest.TestCase):
24+
def test_help(self):
25+
"""Verify that help messages raise ParserNormalExit."""
26+
parser = parser_lib.ArgumentParser()
27+
with self.assertRaisesRegex(
28+
parser_lib.ParserNormalExit, "show this help message and exit"
29+
):
30+
parser.parse_args(["-h"])
2431

25-
def test_help(self):
26-
"""Verify that help messages raise ParserNormalExit."""
27-
parser = parser_lib.ArgumentParser()
28-
with self.assertRaisesRegex(
29-
parser_lib.ParserNormalExit, "show this help message and exit"
30-
):
31-
parser.parse_args(["-h"])
32+
def test_parse_arg_errors(self):
33+
def new_parser() -> argparse.ArgumentParser:
34+
parser = parser_lib.ArgumentParser()
35+
parser.add_argument("--value", type=int, required=True)
36+
return parser
3237

33-
def test_parse_arg_errors(self):
34-
def new_parser() -> argparse.ArgumentParser:
35-
parser = parser_lib.ArgumentParser()
36-
parser.add_argument("--value", type=int, required=True)
37-
return parser
38+
# Normal case: no error.
39+
results = new_parser().parse_args(["--value", "42"])
40+
self.assertEqual(42, results.value)
3841

39-
# Normal case: no error.
40-
results = new_parser().parse_args(["--value", "42"])
41-
self.assertEqual(42, results.value)
42+
with self.assertRaisesRegex(parser_lib.ParserError, "invalid int value"):
43+
new_parser().parse_args(["--value", "forty-two"])
4244

43-
with self.assertRaisesRegex(parser_lib.ParserError, "invalid int value"):
44-
new_parser().parse_args(["--value", "forty-two"])
45+
with self.assertRaisesRegex(
46+
parser_lib.ParserError, "the following arguments are required"
47+
):
48+
new_parser().parse_args([])
4549

46-
with self.assertRaisesRegex(
47-
parser_lib.ParserError, "the following arguments are required"
48-
):
49-
new_parser().parse_args([])
50-
51-
with self.assertRaisesRegex(
52-
parser_lib.ParserError, "expected one argument"
53-
):
54-
new_parser().parse_args(["--value"])
50+
with self.assertRaisesRegex(parser_lib.ParserError, "expected one argument"):
51+
new_parser().parse_args(["--value"])
5552

5653

5754
if __name__ == "__main__":
58-
absltest.main()
55+
absltest.main()

0 commit comments

Comments
 (0)