Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ dependencies {
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
testImplementation("io.mockk:mockk:1.12.8")
testImplementation("org.json:json:20240303")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai

import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.Schema
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.defineFunction
import io.kotest.assertions.throwables.shouldThrowWithMessage
import io.kotest.matchers.shouldBe
import org.json.JSONObject
import org.junit.Test

internal class FunctionCallingTests {

@Test
fun `function calls with valid args should succeed`() = doBlocking {
val functionDeclaration =
defineFunction("name", "description", Schema.str("param1", "description")) { param1 ->
JSONObject(mapOf("result" to "success"))
}
val model = GenerativeModel("model", "key", tools = listOf(Tool(listOf(functionDeclaration))))

val functionCall = FunctionCallPart("name", mapOf("param1" to "valid parameter"))

val result = model.executeFunction(functionCall)

result["result"] shouldBe "success"
}

@Test
fun `function calls with invalid args should fail`() = doBlocking {
val functionDeclaration =
defineFunction("name", "description", Schema.str("param1", "description")) { param1 ->
JSONObject(mapOf("result" to "success"))
}
val model = GenerativeModel("model", "key", tools = listOf(Tool(listOf(functionDeclaration))))

val functionCall = FunctionCallPart("name", mapOf("param1" to null))

shouldThrowWithMessage<RuntimeException>(
"Missing argument for parameter \"param1\" for function \"name\""
) {
model.executeFunction(functionCall)
}
}
}