@@ -71,7 +71,7 @@ def __init__(
         model_name: str = "gemini-pro",
         safety_settings: safety_types.SafetySettingOptions | None = None,
         generation_config: generation_types.GenerationConfigType | None = None,
-        tools: content_types.ToolsType = None,
+        tools: content_types.FunctionLibraryType | None = None,
     ):
         if "/" not in model_name:
             model_name = "models/" + model_name
@@ -80,7 +80,7 @@ def __init__(
             safety_settings, harm_category_set="new"
         )
         self._generation_config = generation_types.to_generation_config_dict(generation_config)
-        self._tools = content_types.to_tools(tools)
+        self._tools = content_types.to_function_library(tools)

         self._client = None
         self._async_client = None
@@ -94,8 +94,9 @@ def __str__(self):
            f"""\
            genai.GenerativeModel(
                model_name='{self.model_name}',
-                generation_config={self._generation_config}.
-                safety_settings={self._safety_settings}
+                generation_config={self._generation_config},
+                safety_settings={self._safety_settings},
+                tools={self._tools},
            )"""
        )

@@ -107,12 +108,16 @@ def _prepare_request(
         contents: content_types.ContentsType,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None,
     ) -> glm.GenerateContentRequest:
         """Creates a `glm.GenerateContentRequest` from raw inputs."""
         if not contents:
             raise TypeError("contents must not be empty")

+        tools_lib = self._get_tools_lib(tools)
+        if tools_lib is not None:
+            tools_lib = tools_lib.to_proto()
+
         contents = content_types.to_contents(contents)

         generation_config = generation_types.to_generation_config_dict(generation_config)
@@ -129,19 +134,26 @@ def _prepare_request(
             contents=contents,
             generation_config=merged_gc,
             safety_settings=merged_ss,
-            tools=self._tools,
-            **kwargs,
+            tools=tools_lib,
         )

+    def _get_tools_lib(
+        self, tools: content_types.FunctionLibraryType
+    ) -> content_types.FunctionLibrary | None:
+        if tools is None:
+            return self._tools
+        else:
+            return content_types.to_function_library(tools)
+
     def generate_content(
         self,
         contents: content_types.ContentsType,
         *,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
         stream: bool = False,
+        tools: content_types.FunctionLibraryType | None = None,
         request_options: dict[str, Any] | None = None,
-        **kwargs,
     ) -> generation_types.GenerateContentResponse:
         """A multipurpose function to generate responses from the model.

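For orientation, a minimal usage sketch of the precedence this hunk introduces: `_get_tools_lib` falls back to the constructor-level library only when a call passes `tools=None`. The `multiply` function and prompts below are illustrative, not part of this PR:

```python
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])  # conventional env var, adjust as needed

def multiply(a: float, b: float) -> float:
    """Returns a * b."""
    return a * b

# Constructor-level tools become the model's default FunctionLibrary.
model = genai.GenerativeModel("gemini-pro", tools=[multiply])

# tools=None here, so _get_tools_lib returns the model-level library.
response = model.generate_content("What is 234551 * 325552?")

# A per-call library shadows the constructor-level one for this request only.
response = model.generate_content("What is 234551 * 325552?", tools=[multiply])
```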
@@ -201,7 +213,7 @@ def generate_content(
             contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            **kwargs,
+            tools=tools,
         )
         if self._client is None:
             self._client = client.get_default_generative_client()
@@ -230,15 +242,15 @@ async def generate_content_async(
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
         stream: bool = False,
+        tools: content_types.FunctionLibraryType | None = None,
         request_options: dict[str, Any] | None = None,
-        **kwargs,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `GenerativeModel.generate_content`."""
         request = self._prepare_request(
             contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            **kwargs,
+            tools=tools,
         )
         if self._async_client is None:
             self._async_client = client.get_default_generative_async_client()
@@ -299,6 +311,7 @@ def start_chat(
         self,
         *,
         history: Iterable[content_types.StrictContentType] | None = None,
+        enable_automatic_function_calling: bool = False,
     ) -> ChatSession:
         """Returns a `genai.ChatSession` attached to this model.

@@ -314,6 +327,7 @@ def start_chat(
         return ChatSession(
             model=self,
             history=history,
+            enable_automatic_function_calling=enable_automatic_function_calling,
        )

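A sketch of the new chat-level flag (reusing the hypothetical `multiply` tool from above). With `enable_automatic_function_calling=True`, `send_message` resolves `function_call` parts itself before returning:

```python
import google.generativeai as genai

def multiply(a: float, b: float) -> float:
    """Returns a * b."""
    return a * b

model = genai.GenerativeModel("gemini-pro", tools=[multiply])
chat = model.start_chat(enable_automatic_function_calling=True)

# The AFC loop calls multiply(), appends the function_response to the
# history, and re-queries until the model stops emitting function calls.
response = chat.send_message("What is 234551 times 325552?")
print(response.text)
```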
@@ -341,11 +355,13 @@ def __init__(
         self,
         model: GenerativeModel,
         history: Iterable[content_types.StrictContentType] | None = None,
+        enable_automatic_function_calling: bool = False,
     ):
         self.model: GenerativeModel = model
         self._history: list[glm.Content] = content_types.to_contents(history)
         self._last_sent: glm.Content | None = None
         self._last_received: generation_types.BaseGenerateContentResponse | None = None
+        self.enable_automatic_function_calling = enable_automatic_function_calling

     def send_message(
         self,
@@ -354,7 +370,7 @@ def send_message(
         generation_config: generation_types.GenerationConfigType = None,
         safety_settings: safety_types.SafetySettingOptions = None,
         stream: bool = False,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None = None,
     ) -> generation_types.GenerateContentResponse:
         """Sends the conversation history with the added message and returns the model's response.

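Per-message `tools` mirror `generate_content`: a library passed to `send_message` shadows the model-level one for that turn only. The `lookup_weather` stub below is hypothetical:

```python
import google.generativeai as genai

def lookup_weather(city: str) -> str:
    """Illustrative stub; a real tool would call a weather API."""
    return f"Sunny in {city}"

model = genai.GenerativeModel("gemini-pro")
chat = model.start_chat()

# tools passed here are wrapped via content_types.to_function_library
# and take precedence over the model-level library for this turn.
response = chat.send_message("What's the weather in Paris?", tools=[lookup_weather])
```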
@@ -387,23 +403,52 @@ def send_message(
            safety_settings: Overrides for the model's safety settings.
            stream: If True, yield response chunks as they are generated.
        """
+        if self.enable_automatic_function_calling and stream:
+            raise NotImplementedError(
+                "The `google.generativeai` SDK does not yet support `stream=True` with "
+                "`enable_automatic_function_calling=True`"
+            )
+
+        tools_lib = self.model._get_tools_lib(tools)
+
         content = content_types.to_content(content)
+
         if not content.role:
             content.role = self._USER_ROLE
+
         history = self.history[:]
         history.append(content)

         generation_config = generation_types.to_generation_config_dict(generation_config)
         if generation_config.get("candidate_count", 1) > 1:
             raise ValueError("Can't chat with `candidate_count > 1`")
+
         response = self.model.generate_content(
             contents=history,
             generation_config=generation_config,
             safety_settings=safety_settings,
             stream=stream,
-            **kwargs,
+            tools=tools_lib,
         )

+        self._check_response(response=response, stream=stream)
+
+        if self.enable_automatic_function_calling and tools_lib is not None:
+            self.history, content, response = self._handle_afc(
+                response=response,
+                history=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools_lib=tools_lib,
+            )
+
+        self._last_sent = content
+        self._last_received = response
+
+        return response
+
+    def _check_response(self, *, response, stream):
         if response.prompt_feedback.block_reason:
             raise generation_types.BlockedPromptException(response.prompt_feedback)

@@ -415,10 +460,49 @@ def send_message(
             ):
                 raise generation_types.StopCandidateException(response.candidates[0])

-        self._last_sent = content
-        self._last_received = response
+    def _get_function_calls(self, response) -> list[glm.FunctionCall]:
+        candidates = response.candidates
+        if len(candidates) != 1:
+            raise ValueError(
+                f"Automatic function calling only works with 1 candidate, got: {len(candidates)}"
+            )
+        parts = candidates[0].content.parts
+        function_calls = [part.function_call for part in parts if part and "function_call" in part]
+        return function_calls
+
+    def _handle_afc(
+        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
+
+        while function_calls := self._get_function_calls(response):
+            if not all(callable(tools_lib[fc]) for fc in function_calls):
+                break
+            history.append(response.candidates[0].content)
+
+            function_response_parts: list[glm.Part] = []
+            for fc in function_calls:
+                fr = tools_lib(fc)
+                assert fr is not None, (
+                    "This should never happen, it should only return None if the declaration "
+                    "is not callable, and that's guarded against above."
+                )
+                function_response_parts.append(fr)

-        return response
+            send = glm.Content(role=self._USER_ROLE, parts=function_response_parts)
+            history.append(send)
+
+            response = self.model.generate_content(
+                contents=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools=tools_lib,
+            )
+
+            self._check_response(response=response, stream=stream)
+
+        *history, content = history
+        return history, content, response

     async def send_message_async(
         self,
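One subtlety worth flagging: `_handle_afc` ends with a star-unpack so the final user-role content (the last `function_response` turn actually sent) is returned separately and becomes `_last_sent`, while everything before it becomes the stored history. A standalone illustration:

```python
# Stand-ins for the glm.Content entries accumulated during the AFC loop.
history = ["user: question", "model: function_call", "user: function_response"]

# Equivalent to: content = history[-1]; history = history[:-1]
*history, content = history

assert history == ["user: question", "model: function_call"]
assert content == "user: function_response"
```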
@@ -427,42 +511,88 @@ async def send_message_async(
         generation_config: generation_types.GenerationConfigType = None,
         safety_settings: safety_types.SafetySettingOptions = None,
         stream: bool = False,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None = None,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `ChatSession.send_message`."""
+        if self.enable_automatic_function_calling and stream:
+            raise NotImplementedError(
+                "The `google.generativeai` SDK does not yet support `stream=True` with "
+                "`enable_automatic_function_calling=True`"
+            )
+
+        tools_lib = self.model._get_tools_lib(tools)
+
         content = content_types.to_content(content)
+
         if not content.role:
             content.role = self._USER_ROLE
+
         history = self.history[:]
         history.append(content)

         generation_config = generation_types.to_generation_config_dict(generation_config)
         if generation_config.get("candidate_count", 1) > 1:
             raise ValueError("Can't chat with `candidate_count > 1`")
+
         response = await self.model.generate_content_async(
             contents=history,
             generation_config=generation_config,
             safety_settings=safety_settings,
             stream=stream,
-            **kwargs,
+            tools=tools_lib,
         )

-        if response.prompt_feedback.block_reason:
-            raise generation_types.BlockedPromptException(response.prompt_feedback)
+        self._check_response(response=response, stream=stream)

-        if not stream:
-            if response.candidates[0].finish_reason not in (
-                glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
-                glm.Candidate.FinishReason.STOP,
-                glm.Candidate.FinishReason.MAX_TOKENS,
-            ):
-                raise generation_types.StopCandidateException(response.candidates[0])
+        if self.enable_automatic_function_calling and tools_lib is not None:
+            self.history, content, response = await self._handle_afc_async(
+                response=response,
+                history=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools_lib=tools_lib,
+            )

         self._last_sent = content
         self._last_received = response

         return response

+    async def _handle_afc_async(
+        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
+
+        while function_calls := self._get_function_calls(response):
+            if not all(callable(tools_lib[fc]) for fc in function_calls):
+                break
+            history.append(response.candidates[0].content)
+
+            function_response_parts: list[glm.Part] = []
+            for fc in function_calls:
+                fr = tools_lib(fc)
+                assert fr is not None, (
+                    "This should never happen, it should only return None if the declaration "
+                    "is not callable, and that's guarded against above."
+                )
+                function_response_parts.append(fr)
+
+            send = glm.Content(role=self._USER_ROLE, parts=function_response_parts)
+            history.append(send)
+
+            response = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools=tools_lib,
+            )
+
+            self._check_response(response=response, stream=stream)
+
+        *history, content = history
+        return history, content, response
+

     def __copy__(self):
         return ChatSession(
             model=self.model,
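Finally, the async path mirrors the sync flow, including the AFC loop (`_handle_afc_async`) and the `stream=True` guard. A minimal sketch, again assuming the hypothetical `multiply` tool and a configured API key:

```python
import asyncio
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])  # conventional env var, adjust as needed

def multiply(a: float, b: float) -> float:
    """Returns a * b."""
    return a * b

async def main() -> None:
    model = genai.GenerativeModel("gemini-pro", tools=[multiply])
    chat = model.start_chat(enable_automatic_function_calling=True)
    # send_message_async runs _handle_afc_async until no function calls remain.
    response = await chat.send_message_async("What is 7 times 6?")
    print(response.text)

asyncio.run(main())
```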