
Commit d5326cb

reformat with new black config (in pyproject.toml) (#78)

1 parent 2fcd649

Note: this is a large commit, so diffs for some of the changed files are hidden by default; only six of the 41 files are shown below.

41 files changed: +301 −571 lines
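The new black configuration itself lives in pyproject.toml, which is among the 41 changed files but whose hunk is hidden above. As a rough, unconfirmed sketch (the reformatted lines below run to roughly 100 characters, which suggests the config raises black's default line-length of 88 to 100; the actual values are whatever PR #78 added):

    # pyproject.toml -- hypothetical sketch, not the verbatim contents of this commit
    [tool.black]
    line-length = 100  # assumed from the ~100-character lines in the diffs below

Running `black .` from the repository root picks up the `[tool.black]` table automatically, which is presumably how the reformat below was produced.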

docs/build_docs.py

Lines changed: 4 additions & 12 deletions

@@ -115,9 +115,7 @@
     "search_hints", True, "Include metadata search hints in the generated files"
 )
 
-_SITE_PATH = flags.DEFINE_string(
-    "site_path", "/api/python", "Path prefix in the _toc.yaml"
-)
+_SITE_PATH = flags.DEFINE_string("site_path", "/api/python", "Path prefix in the _toc.yaml")
 
 _CODE_URL_PREFIX = flags.DEFINE_string(
     "code_url_prefix",
@@ -139,9 +137,7 @@ def drop_staticmethods(self, parent, children):
     def __call__(self, path, parent, children):
         if any("generativelanguage" in part for part in path) or "generativeai" in path:
             children = self.filter_base_dirs(path, parent, children)
-            children = public_api.explicit_package_contents_filter(
-                path, parent, children
-            )
+            children = public_api.explicit_package_contents_filter(path, parent, children)
 
         if any("generativelanguage" in part for part in path):
             if "ServiceClient" in path[-1] or "ServiceAsyncClient" in path[-1]:
@@ -159,9 +155,7 @@ def make_default_filters(self):
             public_api.add_proto_fields,
             public_api.filter_builtin_modules,
             public_api.filter_private_symbols,
-            MyFilter(
-                self._base_dir
-            ),  # Replaces: public_api.FilterBaseDirs(self._base_dir),
+            MyFilter(self._base_dir),  # Replaces: public_api.FilterBaseDirs(self._base_dir),
             public_api.FilterPrivateMap(self._private_map),
             public_api.filter_doc_controls_skip,
             public_api.ignore_typing,
@@ -229,9 +223,7 @@ def gen_api_docs():
     new_content = re.sub(r".*?`oneof`_ ``_.*?\n", "", new_content, re.MULTILINE)
     new_content = re.sub(r"\.\. code-block:: python.*?\n", "", new_content)
 
-    new_content = re.sub(
-        r"generativelanguage_\w+.types", "generativelanguage", new_content
-    )
+    new_content = re.sub(r"generativelanguage_\w+.types", "generativelanguage", new_content)
 
     if new_content != old_content:
         fpath.write_text(new_content)

google/generativeai/client.py

Lines changed: 3 additions & 9 deletions

@@ -78,9 +78,7 @@ def configure(
 
     if had_api_key_value:
         if api_key is not None:
-            raise ValueError(
-                "You can't set both `api_key` and `client_options['api_key']`."
-            )
+            raise ValueError("You can't set both `api_key` and `client_options['api_key']`.")
     else:
         if api_key is None:
             # If no key is provided explicitly, attempt to load one from the
@@ -107,9 +105,7 @@ def configure(
     }
 
     new_default_client_config = {
-        key: value
-        for key, value in new_default_client_config.items()
-        if value is not None
+        key: value for key, value in new_default_client_config.items() if value is not None
     }
 
     default_client_config = new_default_client_config
@@ -147,9 +143,7 @@ def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
     # Attempt to configure using defaults.
     if not default_client_config:
         configure()
-    default_discuss_async_client = glm.DiscussServiceAsyncClient(
-        **default_client_config
-    )
+    default_discuss_async_client = glm.DiscussServiceAsyncClient(**default_client_config)
 
     return default_discuss_async_client

google/generativeai/discuss.py

Lines changed: 12 additions & 20 deletions

@@ -39,7 +39,9 @@ def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
     return glm.Message(content)
 
 
-def _make_messages(messages: discuss_types.MessagesOptions) -> List[glm.Message]:
+def _make_messages(
+    messages: discuss_types.MessagesOptions,
+) -> List[glm.Message]:
     """
     Creates a list of `glm.Message` objects from the provided messages.
 
@@ -146,7 +148,9 @@ def _make_examples_from_flat(
     return result
 
 
-def _make_examples(examples: discuss_types.ExamplesOptions) -> List[glm.Example]:
+def _make_examples(
+    examples: discuss_types.ExamplesOptions,
+) -> List[glm.Example]:
     """
     Creates a list of `glm.Example` objects from the provided examples.
 
@@ -223,9 +227,7 @@ def _make_message_prompt_dict(
             messages=messages,
         )
     else:
-        flat_prompt = (
-            (context is not None) or (examples is not None) or (messages is not None)
-        )
+        flat_prompt = (context is not None) or (examples is not None) or (messages is not None)
        if flat_prompt:
             raise ValueError(
                 "You can't set `prompt`, and its fields `(context, examples, messages)`"
@@ -446,9 +448,7 @@ async def chat_async(
 @set_doc(discuss_types.ChatResponse.__doc__)
 @dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
 class ChatResponse(discuss_types.ChatResponse):
-    _client: glm.DiscussServiceClient | None = dataclasses.field(
-        default=lambda: None, repr=False
-    )
+    _client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False)
 
     def __init__(self, **kwargs):
         for key, value in kwargs.items():
@@ -469,13 +469,9 @@ def last(self, message: discuss_types.MessageOptions):
         self.messages[-1] = message
 
     @set_doc(discuss_types.ChatResponse.reply.__doc__)
-    def reply(
-        self, message: discuss_types.MessageOptions
-    ) -> discuss_types.ChatResponse:
+    def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse:
         if isinstance(self._client, glm.DiscussServiceAsyncClient):
-            raise TypeError(
-                f"reply can't be called on an async client, use reply_async instead."
-            )
+            raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
         if self.last is None:
             raise ValueError(
                 "The last response from the model did not return any candidates.\n"
@@ -532,9 +528,7 @@ def _build_chat_response(
     request.setdefault("temperature", None)
     request.setdefault("candidate_count", None)
 
-    return ChatResponse(
-        _client=client, **response, **request
-    )  # pytype: disable=missing-parameter
+    return ChatResponse(_client=client, **response, **request)  # pytype: disable=missing-parameter
 
 
 def _generate_response(
@@ -571,9 +565,7 @@ def count_message_tokens(
     client: glm.DiscussServiceAsyncClient | None = None,
 ):
     model = model_types.make_model_name(model)
-    prompt = _make_message_prompt(
-        prompt, context=context, examples=examples, messages=messages
-    )
+    prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)
 
     if client is None:
         client = get_default_discuss_client()

google/generativeai/models.py

Lines changed: 14 additions & 18 deletions

@@ -53,9 +53,7 @@ def get_model(
         raise ValueError("Model names must start with `models/` or `tunedModels/`")
 
 
-def get_base_model(
-    name: model_types.BaseModelNameOptions, *, client=None
-) -> model_types.Model:
+def get_base_model(name: model_types.BaseModelNameOptions, *, client=None) -> model_types.Model:
     """Get the `types.Model` for the given base model name.
 
     ```
@@ -133,9 +131,7 @@ def _list_tuned_models_next_page(page_size, page_token, client):
     )
     result = result._response
     result = type(result).to_dict(result)
-    result["models"] = [
-        model_types.decode_tuned_model(mod) for mod in result.pop("tuned_models")
-    ]
+    result["models"] = [model_types.decode_tuned_model(mod) for mod in result.pop("tuned_models")]
     result["page_size"] = page_size
     result["page_token"] = result.pop("next_page_token")
     result["client"] = client
@@ -154,21 +150,19 @@ def _list_models_iter_pages(
     page_token = None
     while True:
         if select == "base":
-            result = _list_base_models_next_page(
-                page_size, page_token=page_token, client=client
-            )
+            result = _list_base_models_next_page(page_size, page_token=page_token, client=client)
         elif select == "tuned":
-            result = _list_tuned_models_next_page(
-                page_size, page_token=page_token, client=client
-            )
+            result = _list_tuned_models_next_page(page_size, page_token=page_token, client=client)
         yield from result["models"]
         page_token = result["page_token"]
         if page_token == "":
             break
 
 
 def list_models(
-    *, page_size: int | None = None, client: glm.ModelServiceClient | None = None
+    *,
+    page_size: int | None = None,
+    client: glm.ModelServiceClient | None = None,
 ) -> model_types.ModelsIterable:
     """Lists available models.
 
@@ -190,7 +184,9 @@ def list_models(
 
 
 def list_tuned_models(
-    *, page_size: int | None = None, client: glm.ModelServiceClient | None = None
+    *,
+    page_size: int | None = None,
+    client: glm.ModelServiceClient | None = None,
 ) -> model_types.TunedModelsIterable:
     """Lists available models.
 
@@ -294,7 +290,9 @@ def create_tuned_model(
     training_data = model_types.encode_tuning_data(training_data)
 
     hyperparameters = glm.Hyperparameters(
-        epoch_count=epoch_count, batch_size=batch_size, learning_rate=learning_rate
+        epoch_count=epoch_count,
+        batch_size=batch_size,
+        learning_rate=learning_rate,
     )
     tuning_task = glm.TuningTask(
         training_data=training_data,
@@ -310,9 +308,7 @@ def create_tuned_model(
         top_k=top_k,
         tuning_task=tuning_task,
     )
-    operation = client.create_tuned_model(
-        dict(tuned_model_id=id, tuned_model=tuned_model)
-    )
+    operation = client.create_tuned_model(dict(tuned_model_id=id, tuned_model=tuned_model))
 
     return operations.CreateTunedModelOperation.from_core_operation(operation)

google/generativeai/notebook/argument_parser_test.py

Lines changed: 2 additions & 6 deletions

@@ -24,9 +24,7 @@ class ArgumentParserTest(absltest.TestCase):
     def test_help(self):
         """Verify that help messages raise ParserNormalExit."""
         parser = parser_lib.ArgumentParser()
-        with self.assertRaisesRegex(
-            parser_lib.ParserNormalExit, "show this help message and exit"
-        ):
+        with self.assertRaisesRegex(parser_lib.ParserNormalExit, "show this help message and exit"):
             parser.parse_args(["-h"])
 
     def test_parse_arg_errors(self):
@@ -42,9 +40,7 @@ def new_parser() -> argparse.ArgumentParser:
         with self.assertRaisesRegex(parser_lib.ParserError, "invalid int value"):
             new_parser().parse_args(["--value", "forty-two"])
 
-        with self.assertRaisesRegex(
-            parser_lib.ParserError, "the following arguments are required"
-        ):
+        with self.assertRaisesRegex(parser_lib.ParserError, "the following arguments are required"):
             new_parser().parse_args([])
 
         with self.assertRaisesRegex(parser_lib.ParserError, "expected one argument"):

google/generativeai/notebook/cmd_line_parser.py

Lines changed: 16 additions & 41 deletions

@@ -67,9 +67,7 @@ def _resolve_compare_fn_var(
     """Resolves a value passed into --compare_fn."""
     fn = py_utils.get_py_var(name)
     if not isinstance(fn, Callable):
-        raise ValueError(
-            'Variable "{}" does not contain a Callable object'.format(name)
-        )
+        raise ValueError('Variable "{}" does not contain a Callable object'.format(name))
 
     return name, fn
 
@@ -80,19 +78,11 @@ def _resolve_ground_truth_var(name: str) -> Sequence[str]:
 
     # "str" and "bytes" are also Sequences but we want an actual Sequence of
     # strings, like a list.
-    if (
-        not isinstance(value, Sequence)
-        or isinstance(value, str)
-        or isinstance(value, bytes)
-    ):
-        raise ValueError(
-            'Variable "{}" does not contain a Sequence of strings'.format(name)
-        )
+    if not isinstance(value, Sequence) or isinstance(value, str) or isinstance(value, bytes):
+        raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
     for x in value:
         if not isinstance(x, str):
-            raise ValueError(
-                'Variable "{}" does not contain a Sequence of strings'.format(name)
-            )
+            raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
     return value
 
 
@@ -128,9 +118,7 @@ def _add_model_flags(
 
     def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
         if x < 0:
-            raise ValueError(
-                "Value should be greater than or equal to zero, got {}".format(x)
-            )
+            raise ValueError("Value should be greater than or equal to zero, got {}".format(x))
         return x
 
     flag_def.SingleValueFlagDef(
@@ -154,8 +142,7 @@ def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
         short_name="m",
         default_value=None,
         help_msg=(
-            "The name of the model to use. If not provided, a default model will"
-            " be used."
+            "The name of the model to use. If not provided, a default model will" " be used."
         ),
     ).add_argument_to_parser(parser)
 
@@ -315,9 +302,7 @@ def _compile_save_name_fn(var_name: str) -> str:
         return var_name
 
     save_name_help = "The name of a Python variable to save the compiled function to."
-    parser.add_argument(
-        "compile_save_name", help=save_name_help, type=_compile_save_name_fn
-    )
+    parser.add_argument("compile_save_name", help=save_name_help, type=_compile_save_name_fn)
     _add_model_flags(parser)
 
 
@@ -346,22 +331,16 @@
         if not isinstance(fn, llm_function.LLMFunction):
             raise argparse.ArgumentError(
                 None,
-                '{} is not a function created with the "compile" command'.format(
-                    var_name
-                ),
+                '{} is not a function created with the "compile" command'.format(var_name),
             )
         return var_name, fn
 
     name_help = (
         "The name of a Python variable containing a function previously created"
         ' with the "compile" command.'
     )
-    parser.add_argument(
-        "lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn
-    )
-    parser.add_argument(
-        "rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn
-    )
+    parser.add_argument("lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)
+    parser.add_argument("rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)
 
     _add_input_flags(parser, placeholders)
     _add_output_flags(parser)
@@ -409,9 +388,7 @@ def _create_parser(
         subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value),
         placeholders,
     )
-    _create_compile_parser(
-        subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value)
-    )
+    _create_compile_parser(subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value))
     _create_compare_parser(
         subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value),
         placeholders,
@@ -471,9 +448,7 @@ def _split_post_processing_tokens(
             if start_idx is None:
                 start_idx = token_num
             if token == CmdLineParser.PIPE_OP:
-                split_tokens.append(
-                    tokens[start_idx:token_num] if start_idx is not None else []
-                )
+                split_tokens.append(tokens[start_idx:token_num] if start_idx is not None else [])
                 start_idx = None
 
     # Add the remaining tokens after the last PIPE_OP.
@@ -518,7 +493,9 @@ def _get_model_args(
     candidate_count = parsed_results.pop("candidate_count", None)
 
     model_args = model_lib.ModelArguments(
-        model=model, temperature=temperature, candidate_count=candidate_count
+        model=model,
+        temperature=temperature,
+        candidate_count=candidate_count,
     )
     return parsed_results, model_args
 
@@ -556,9 +533,7 @@ def parse_line(
         _, rhs_fn = parsed_args.rhs_name_and_fn
         parsed_args = self._get_parsed_args_from_cmd_line_tokens(
             tokens=tokens,
-            placeholders=frozenset(lhs_fn.get_placeholders()).union(
-                rhs_fn.get_placeholders()
-            ),
+            placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()),
         )
 
         _validate_parsed_args(parsed_args)
