Skip to content

Commit 6e23afc

Browse files
feat(vertexai): Add repetition penalties to GenerationConfig (#17234)
* feat: Add presencePenalty and frequencyPenalty to GenerationConfig This commit adds two new instance variables, and , to the class. These variables control the likelihood of repeating words or phrases in the generated text. The implementation is based on the Swift SDK and includes corresponding documentation. Note: Unable to run tests due to unavailable test execution environment. * Use Dartdoc formatting * Add test for presencePenalty and frequencyPenalty * Move new fields into `BaseGenerationConfig` * Add `presencePenalty` and `frequencyPenalty` to `LiveGenerationConfig` * Reword repetition penalties docs * Link to Firebase docs instead of Cloud docs for repetition penalties Both `presencePenalty` and `frequencyPenalty` have now been added to: https://firebase.google.com/docs/vertex-ai/model-parameters?platform=flutter#configure-model-parameters-gemini --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
1 parent a7a842e commit 6e23afc

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,8 @@ abstract class BaseGenerationConfig {
659659
this.temperature,
660660
this.topP,
661661
this.topK,
662+
this.presencePenalty,
663+
this.frequencyPenalty,
662664
});
663665

664666
/// Number of generated responses to return.
@@ -700,6 +702,41 @@ abstract class BaseGenerationConfig {
700702
/// Note: The default value varies by model.
701703
final int? topK;
702704

705+
/// The penalty for repeating the same words or phrases already generated in
706+
/// the text.
707+
///
708+
/// Controls the likelihood of repetition. Higher penalty values result in
709+
/// more diverse output.
710+
///
711+
/// **Note:** While both [presencePenalty] and [frequencyPenalty] discourage
712+
/// repetition, [presencePenalty] applies the same penalty regardless of how
713+
/// many times the word/phrase has already appeared, whereas
714+
/// [frequencyPenalty] increases the penalty for *each* repetition of a
715+
/// word/phrase.
716+
///
717+
/// **Important:** The range of supported [presencePenalty] values depends on
718+
/// the model; see the
719+
/// [documentation](https://firebase.google.com/docs/vertex-ai/model-parameters?platform=flutter#configure-model-parameters-gemini)
720+
/// for more details.
721+
final double? presencePenalty;
722+
723+
/// The penalty for repeating words or phrases, with the penalty increasing
724+
/// for each repetition.
725+
///
726+
/// Controls the likelihood of repetition. Higher values increase the penalty
727+
/// of repetition, resulting in more diverse output.
728+
///
729+
/// **Note:** While both [frequencyPenalty] and [presencePenalty] discourage
730+
/// repetition, [frequencyPenalty] increases the penalty for *each* repetition
731+
/// of a word/phrase, whereas [presencePenalty] applies the same penalty
732+
/// regardless of how many times the word/phrase has already appeared.
733+
///
734+
/// **Important:** The range of supported [frequencyPenalty] values depends on
735+
/// the model; see the
736+
/// [documentation](https://firebase.google.com/docs/vertex-ai/model-parameters?platform=flutter#configure-model-parameters-gemini)
737+
/// for more details.
738+
final double? frequencyPenalty;
739+
703740
// ignore: public_member_api_docs
704741
Map<String, Object?> toJson() => {
705742
if (candidateCount case final candidateCount?)
@@ -709,6 +746,10 @@ abstract class BaseGenerationConfig {
709746
if (temperature case final temperature?) 'temperature': temperature,
710747
if (topP case final topP?) 'topP': topP,
711748
if (topK case final topK?) 'topK': topK,
749+
if (presencePenalty case final presencePenalty?)
750+
'presencePenalty': presencePenalty,
751+
if (frequencyPenalty case final frequencyPenalty?)
752+
'frequencyPenalty': frequencyPenalty,
712753
};
713754
}
714755

@@ -722,6 +763,8 @@ final class GenerationConfig extends BaseGenerationConfig {
722763
super.temperature,
723764
super.topP,
724765
super.topK,
766+
super.presencePenalty,
767+
super.frequencyPenalty,
725768
this.responseMimeType,
726769
this.responseSchema,
727770
});

packages/firebase_vertexai/firebase_vertexai/lib/src/live_api.dart

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
8989
super.temperature,
9090
super.topP,
9191
super.topK,
92+
super.presencePenalty,
93+
super.frequencyPenalty,
9294
});
9395

9496
/// The speech configuration.

packages/firebase_vertexai/firebase_vertexai/test/model_test.dart

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,23 @@ void main() {
176176
);
177177
});
178178

179+
test('can override GenerationConfig repetition penalties', () async {
180+
final (client, model) = createModel();
181+
const prompt = 'Some prompt';
182+
await client.checkRequest(
183+
() => model.generateContent([Content.text(prompt)],
184+
generationConfig: GenerationConfig(
185+
presencePenalty: 0.5, frequencyPenalty: 0.2)),
186+
verifyRequest: (_, request) {
187+
expect(request['generationConfig'], {
188+
'presencePenalty': 0.5,
189+
'frequencyPenalty': 0.2,
190+
});
191+
},
192+
response: arbitraryGenerateContentResponse,
193+
);
194+
});
195+
179196
test('can pass system instructions', () async {
180197
const instructions = 'Do a good job';
181198
final (client, model) = createModel(

0 commit comments

Comments
 (0)