Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.ai.ksp

import com.google.devtools.ksp.isPublic
import com.google.devtools.ksp.processing.CodeGenerator
import com.google.devtools.ksp.processing.Dependencies
import com.google.devtools.ksp.processing.KSPLogger
Expand All @@ -25,6 +26,7 @@ import com.google.devtools.ksp.symbol.ClassKind
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSAnnotation
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSVisitorVoid
import com.google.devtools.ksp.symbol.Modifier
Expand All @@ -39,31 +41,196 @@ import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.ksp.toClassName
import com.squareup.kotlinpoet.ksp.toTypeName
import com.squareup.kotlinpoet.ksp.writeTo
import java.util.Locale
import javax.annotation.processing.Generated

public class SchemaSymbolProcessor(
public class FirebaseSymbolProcessor(
private val codeGenerator: CodeGenerator,
private val logger: KSPLogger,
) : SymbolProcessor {
// This regex extracts everything in the kdocs until it hits either the end of the kdocs, or
// the first @<tag> like @property or @see, extracting the main body text of the kdoc
private val baseKdocRegex = Regex("""^\s*(.*?)((@\w* .*)|\z)""", RegexOption.DOT_MATCHES_ALL)
// This regex extracts two capture groups from @property tags, the first is the name of the
// property, and the second is the documentation associated with that property
private val propertyKdocRegex =
Regex("""\s*@property (\w*) (.*?)(?=@\w*|\z)""", RegexOption.DOT_MATCHES_ALL)

override fun process(resolver: Resolver): List<KSAnnotated> {
resolver
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Generable")
.filterIsInstance<KSClassDeclaration>()
.map { it to SchemaSymbolProcessorVisitor() }
.forEach { (klass, visitor) -> visitor.visitClassDeclaration(klass, Unit) }

resolver
.getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Tool")
.filterIsInstance<KSFunctionDeclaration>()
.map { it to FunctionSymbolProcessorVisitor(it, resolver) }
.forEach { (klass, visitor) -> visitor.visitFunctionDeclaration(klass, Unit) }

return emptyList()
}

private inner class FunctionSymbolProcessorVisitor(
private val func: KSFunctionDeclaration,
private val resolver: Resolver,
) : KSVisitorVoid() {
override fun visitFunctionDeclaration(function: KSFunctionDeclaration, data: Unit) {
var shouldError = false
val fullFunctionName = function.qualifiedName?.asString()
if (fullFunctionName == null) {
logger.warn("Error extracting function name.")
shouldError = true
}
if (!function.isPublic()) {
logger.warn("$fullFunctionName must be public.")
shouldError = true
}
val containingClass = function.parentDeclaration as? KSClassDeclaration
val containingClassQualifiedName = containingClass?.qualifiedName?.asString()
if (containingClassQualifiedName == null) {
logger.warn("Could not extract name of containing class of $fullFunctionName.")
shouldError = true
}
if (containingClass == null || !containingClass.isCompanionObject) {
logger.warn(
"$fullFunctionName must be within a companion object $containingClassQualifiedName."
)
shouldError = true
}
if (function.parameters.size != 1) {
logger.warn("$fullFunctionName must have exactly one parameter")
shouldError = true
}
val parameter = function.parameters.firstOrNull()?.type?.resolve()?.declaration
if (parameter != null) {
if (parameter.annotations.find { it.shortName.getShortName() == "Generable" } == null) {
logger.warn("$fullFunctionName parameter must be annotated @Generable")
shouldError = true
}
if (parameter.annotations.find { it.shortName.getShortName() == "Serializable" } == null) {
logger.warn("$fullFunctionName parameter must be annotated @Serializable")
shouldError = true
}
}
val output = function.returnType?.resolve()
if (
output != null &&
output.toClassName().canonicalName != "kotlinx.serialization.json.JsonObject"
) {
if (
output.declaration.annotations.find { it.shortName.getShortName() != "Generable" } == null
) {
logger.warn("$fullFunctionName output must be annotated @Generable")
shouldError = true
}
if (
output.declaration.annotations.find { it.shortName.getShortName() != "Serializable" } ==
null
) {
logger.warn("$fullFunctionName output must be annotated @Serializable")
shouldError = true
}
}
if (shouldError) {
logger.error("$fullFunctionName has one or more errors, please resolve them.")
}
val generatedFunctionFile = generateFileSpec(function)
val containingFile = function.containingFile
if (containingFile == null) {
logger.error("$fullFunctionName must be in a file in the build.")
throw RuntimeException()
}
generatedFunctionFile.writeTo(
codeGenerator,
Dependencies(true, containingFile),
)
}

private fun generateFileSpec(functionDeclaration: KSFunctionDeclaration): FileSpec {
val generatedClassName =
functionDeclaration.simpleName.asString().replaceFirstChar {
if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString()
} + "GeneratedFunctionDeclaration"
return FileSpec.builder(functionDeclaration.packageName.asString(), generatedClassName)
.addImport("com.google.firebase.ai.type", "AutoFunctionDeclaration")
.addType(
TypeSpec.classBuilder(generatedClassName)
.addAnnotation(Generated::class)
.addType(
TypeSpec.companionObjectBuilder()
.addProperty(
PropertySpec.builder(
"FUNCTION_DECLARATION",
ClassName("com.google.firebase.ai.type", "AutoFunctionDeclaration")
.parameterizedBy(
functionDeclaration.parameters.first().type.resolve().toClassName(),
functionDeclaration.returnType?.resolve()?.toClassName()
?: ClassName("kotlinx.serialization.json", "JsonObject")
),
KModifier.PUBLIC,
)
.mutable(false)
.initializer(
CodeBlock.builder()
.add(generateCodeBlockForFunctionDeclaration(functionDeclaration))
.build()
)
.build()
)
.build()
)
.build()
)
.build()
}

fun generateCodeBlockForFunctionDeclaration(
functionDeclaration: KSFunctionDeclaration
): CodeBlock {
val builder = CodeBlock.builder()
val returnType = functionDeclaration.returnType
val hasTypedOutput =
!(returnType == null ||
returnType.resolve().toClassName().canonicalName ==
"kotlinx.serialization.json.JsonObject")
val kdocDescription = functionDeclaration.docString?.let { extractBaseKdoc(it) }
val annotationDescription =
getStringFromAnnotation(
functionDeclaration.annotations.find { it.shortName.getShortName() == "Tool" },
"description"
)
val description = annotationDescription ?: kdocDescription ?: ""
val inputSchemaName =
"${
functionDeclaration.parameters.first().type.resolve().toClassName().canonicalName
}GeneratedSchema.SCHEMA"
builder
.addStatement("AutoFunctionDeclaration.create(")
.indent()
.addStatement("functionName = %S,", functionDeclaration.simpleName.getShortName())
.addStatement("description = %S,", description)
.addStatement("inputSchema = $inputSchemaName,")
if (hasTypedOutput) {
val outputSchemaName =
"${
functionDeclaration.returnType!!.resolve().toClassName().canonicalName
}GeneratedSchema.SCHEMA"
builder.addStatement("outputSchema = $outputSchemaName,")
}
builder.addStatement(
"functionReference = " +
functionDeclaration.qualifiedName!!.getQualifier() +
"::${functionDeclaration.qualifiedName!!.getShortName()},"
)
builder.unindent().addStatement(")")
return builder.build()
}
}

private inner class SchemaSymbolProcessorVisitor() : KSVisitorVoid() {
private val numberTypes = setOf("kotlin.Int", "kotlin.Long", "kotlin.Double", "kotlin.Float")
// This regex extracts everything in the kdocs until it hits either the end of the kdocs, or
// the first @<tag> like @property or @see, extracting the main body text of the kdoc
private val baseKdocRegex = Regex("""^\s*(.*?)((@\w* .*)|\z)""", RegexOption.DOT_MATCHES_ALL)
// This regex extracts two capture groups from @property tags, the first is the name of the
// property, and the second is the documentation associated with that property
private val propertyKdocRegex =
Regex("""\s*@property (\w*) (.*?)(?=@\w*|\z)""", RegexOption.DOT_MATCHES_ALL)

override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Unit) {
val isDataClass = classDeclaration.modifiers.contains(Modifier.DATA)
Expand All @@ -76,13 +243,10 @@ public class SchemaSymbolProcessor(
throw RuntimeException()
}
val generatedSchemaFile = generateFileSpec(classDeclaration)

classDeclaration.containingFile?.let {
generatedSchemaFile.writeTo(
codeGenerator,
Dependencies(true, containingFile),
)
}
generatedSchemaFile.writeTo(
codeGenerator,
Dependencies(true, containingFile),
)
}

fun generateFileSpec(classDeclaration: KSClassDeclaration): FileSpec {
Expand Down Expand Up @@ -293,73 +457,72 @@ public class SchemaSymbolProcessor(
builder.addStatement("nullable = %L)", className.isNullable).unindent()
return builder.build()
}
}

private fun getDescriptionFromAnnotations(
guideAnnotation: KSAnnotation?,
generableClassAnnotation: KSAnnotation?,
description: String?,
baseKdoc: String?,
): String? {
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")

val guideClassDescription = getStringFromAnnotation(generableClassAnnotation, "description")
private fun getDescriptionFromAnnotations(
guideAnnotation: KSAnnotation?,
generableClassAnnotation: KSAnnotation?,
description: String?,
baseKdoc: String?,
): String? {
val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description")

return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
}
val guideClassDescription = getStringFromAnnotation(generableClassAnnotation, "description")

private fun getDoubleFromAnnotation(
guideAnnotation: KSAnnotation?,
doubleName: String,
): Double? {
val guidePropertyDoubleValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
?.value as? Double
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
return null
}
return guidePropertyDoubleValue
return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc
}
private fun getDoubleFromAnnotation(
guideAnnotation: KSAnnotation?,
doubleName: String,
): Double? {
val guidePropertyDoubleValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true }
?.value as? Double
if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) {
return null
}
return guidePropertyDoubleValue
}

private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
val guidePropertyIntValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
?.value as? Int
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
return null
}
return guidePropertyIntValue
private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? {
val guidePropertyIntValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(intName) == true }
?.value as? Int
if (guidePropertyIntValue == null || guidePropertyIntValue == -1) {
return null
}
return guidePropertyIntValue
}

private fun getStringFromAnnotation(
guideAnnotation: KSAnnotation?,
stringName: String,
): String? {
val guidePropertyStringValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
?.value as? String
if (guidePropertyStringValue.isNullOrEmpty()) {
return null
}
return guidePropertyStringValue
private fun getStringFromAnnotation(
guideAnnotation: KSAnnotation?,
stringName: String,
): String? {
val guidePropertyStringValue =
guideAnnotation
?.arguments
?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true }
?.value as? String
if (guidePropertyStringValue.isNullOrEmpty()) {
return null
}
return guidePropertyStringValue
}

private fun extractBaseKdoc(kdoc: String): String? {
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
if (it.isNullOrEmpty()) null else it
}
private fun extractBaseKdoc(kdoc: String): String? {
return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let {
if (it.isNullOrEmpty()) null else it
}
}

private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
return propertyKdocRegex
.findAll(kdoc)
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
.toMap()
}
private fun extractPropertyKdocs(kdoc: String): Map<String, String> {
return propertyKdocRegex
.findAll(kdoc)
.map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() }
.toMap()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import com.google.devtools.ksp.processing.SymbolProcessor
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
import com.google.devtools.ksp.processing.SymbolProcessorProvider

public class SchemaSymbolProcessorProvider : SymbolProcessorProvider {
public class FirebaseSymbolProcessorProvider : SymbolProcessorProvider {
override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor {
return SchemaSymbolProcessor(environment.codeGenerator, environment.logger)
return FirebaseSymbolProcessor(environment.codeGenerator, environment.logger)
}
}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
com.google.firebase.ai.ksp.SchemaSymbolProcessorProvider
com.google.firebase.ai.ksp.FirebaseSymbolProcessorProvider
Loading