Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,46 @@ class PsiClassTypeUtil {
type = type.parameters.firstOrNull()
count++
}
return type as? PsiClassType
val convertOptional = type?.let { convertOptionalType(it, project) }
return convertOptional as? PsiClassType
}
return null
}

/**
* Check if daoParamType is an instance of PsiClassType representing Optional or its primitive variants
*/
fun convertOptionalType(
daoParamType: PsiType,
project: Project,
): PsiType {
if (daoParamType is PsiClassType) {
val resolved = daoParamType.resolve()
val optionalTypeMap =
mapOf(
"java.util.OptionalInt" to "java.lang.Integer",
"java.util.OptionalDouble" to "java.lang.Double",
"java.util.OptionalLong" to "java.lang.Long",
)
if (resolved != null) {
when (resolved.qualifiedName) {
"java.util.Optional" -> return daoParamType.parameters.firstOrNull()
?: daoParamType

else ->
optionalTypeMap[resolved.qualifiedName]?.let { optionalType ->
val newType =
PsiType.getTypeByName(
optionalType,
project,
GlobalSearchScope.allScope(project),
)
return newType
}
}
}
}
return daoParamType
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class ForDirectiveUtil {

val matchParam = daoMethod.findParameter(cleanString(topElementText))
val daoParamType = matchParam?.type ?: return null
fieldAccessTopParentClass = PsiParentClass(daoParamType)
fieldAccessTopParentClass = PsiParentClass(PsiClassTypeUtil.convertOptionalType(daoParamType, project))
}
fieldAccessTopParentClass?.let {
getFieldAccessLastPropertyClassType(
Expand Down Expand Up @@ -329,14 +329,15 @@ class ForDirectiveUtil {
): ValidationResult? {
var parent =
if (isBatchAnnotation) {
val parentType = topParent.type
val parentType = PsiClassTypeUtil.convertOptionalType(topParent.type, project)
val nextClassType = parentType as? PsiClassType ?: return null
val nestType = nextClassType.parameters.firstOrNull() ?: return null
PsiParentClass(nestType)
PsiParentClass(PsiClassTypeUtil.convertOptionalType(nestType, project))
} else {
topParent
val convertOptional = PsiClassTypeUtil.convertOptionalType(topParent.type, project)
PsiParentClass(convertOptional)
}
val parentType = parent.type
val parentType = PsiClassTypeUtil.convertOptionalType(parent.type, project)
val classType = parentType as? PsiClassType ?: return null

var competeResult: ValidationCompleteResult? = null
Expand All @@ -353,8 +354,8 @@ class ForDirectiveUtil {
// When a List type element is used as the parent,
// the original declared type is retained and the referenced type is obtained by nesting.
var parentListBaseType: PsiType? =
if (PsiClassTypeUtil.Companion.isIterableType(classType, project)) {
parentType
if (PsiClassTypeUtil.isIterableType(classType, project)) {
PsiClassTypeUtil.convertOptionalType(parentType, project)
} else {
null
}
Expand All @@ -375,24 +376,25 @@ class ForDirectiveUtil {
parent
.findField(searchElm)
?.let { match ->
val convertOptional = PsiClassTypeUtil.convertOptionalType(match.type, project)
val type =
parentListBaseType?.let {
PsiClassTypeUtil.Companion.getParameterType(
PsiClassTypeUtil.getParameterType(
project,
match.type,
convertOptional,
it,
nestIndex,
)
}
?: match.type
?: convertOptional
val classType = type as? PsiClassType
if (classType != null &&
PsiClassTypeUtil.Companion.isIterableType(
PsiClassTypeUtil.isIterableType(
classType,
element.project,
)
) {
parentListBaseType = type
parentListBaseType = PsiClassTypeUtil.convertOptionalType(type, project)
nestIndex = 0
}
findFieldMethod?.invoke(type)
Expand All @@ -402,19 +404,20 @@ class ForDirectiveUtil {
.findMethod(searchElm)
?.let { match ->
val returnType = match.returnType ?: return null
val convertOptionalType = PsiClassTypeUtil.convertOptionalType(returnType, project)
val methodReturnType =
parentListBaseType?.let {
PsiClassTypeUtil.Companion.getParameterType(
PsiClassTypeUtil.getParameterType(
project,
returnType,
convertOptionalType,
it,
nestIndex,
)
}
?: returnType
?: convertOptionalType
val classType = methodReturnType as? PsiClassType
if (classType != null &&
PsiClassTypeUtil.Companion.isIterableType(
PsiClassTypeUtil.isIterableType(
classType,
element.project,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import com.intellij.util.ProcessingContext
import org.domaframework.doma.intellij.common.dao.findDaoMethod
import org.domaframework.doma.intellij.common.psi.PsiParentClass
import org.domaframework.doma.intellij.common.psi.PsiPatternUtil
import org.domaframework.doma.intellij.common.sql.PsiClassTypeUtil
import org.domaframework.doma.intellij.common.sql.cleanString
import org.domaframework.doma.intellij.common.sql.directive.DirectiveCompletion
import org.domaframework.doma.intellij.common.sql.validator.result.ValidationCompleteResult
Expand Down Expand Up @@ -385,7 +386,7 @@ class SqlParameterCompletionProvider : CompletionProvider<CompletionParameters>(
return null
}
val immediate = findParam.getIterableClazz(daoMethod.getDomaAnnotationType())
return immediate.type
return PsiClassTypeUtil.convertOptionalType(immediate.type, originalFile.project)
}

private fun getRefClazz(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import org.domaframework.doma.intellij.psi.SqlElFieldAccessExpr

class DocumentDaoParameterGenerator(
val originalElement: PsiElement,
val project: Project,
override val project: Project,
val result: MutableList<String?>,
) : DocumentGenerator() {
) : DocumentGenerator(project) {
override fun generateDocument() {
var topParentType: PsiParentClass? = null
val selfSkip = isSelfSkip(originalElement)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
*/
package org.domaframework.doma.intellij.document.generator

import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import org.domaframework.doma.intellij.common.psi.PsiParentClass
import org.domaframework.doma.intellij.common.sql.PsiClassTypeUtil
import org.domaframework.doma.intellij.common.sql.foritem.ForItem
import org.domaframework.doma.intellij.extension.psi.getForItem

abstract class DocumentGenerator {
abstract class DocumentGenerator(
open val project: Project,
) {
abstract fun generateDocument()

protected fun isSelfSkip(targetElement: PsiElement): Boolean {
Expand All @@ -30,8 +34,10 @@ abstract class DocumentGenerator {
}

protected fun generateTypeLink(parentClass: PsiParentClass?): String {
if (parentClass?.type != null) {
return generateTypeLinkFromCanonicalText(parentClass.type.canonicalText)
val parentClassType = parentClass?.type
if (parentClassType != null) {
val convertOptionalType = PsiClassTypeUtil.convertOptionalType(parentClassType, project)
return generateTypeLinkFromCanonicalText(convertOptionalType.canonicalText)
}
return ""
}
Expand All @@ -40,22 +46,34 @@ abstract class DocumentGenerator {
val regex = Regex("([a-zA-Z0-9_]+\\.)*([a-zA-Z0-9_]+)")
val result = StringBuilder()
var lastIndex = 0
val optionalPackage = "java.util.Optional"
val optionalTypeMap =
listOf(
optionalPackage,
"${optionalPackage}Int",
"${optionalPackage}Double",
"${optionalPackage}Long",
)
var skipCount = 0

for (match in regex.findAll(canonicalText)) {
val fullMatch = match.value
val optionalSkip = optionalTypeMap.contains(fullMatch)
if (optionalSkip) skipCount++

val typeName = match.groups[2]?.value ?: fullMatch
val startIndex = match.range.first
val endIndex = match.range.last + 1

if (lastIndex < startIndex) {
if (lastIndex < startIndex && !optionalSkip) {
result.append(canonicalText.substring(lastIndex, startIndex))
}
result.append("<a href=\"psi_element://$fullMatch\">$typeName</a>")
if (!optionalSkip) result.append("<a href=\"psi_element://$fullMatch\">$typeName</a>")
lastIndex = endIndex
}

if (lastIndex < canonicalText.length) {
result.append(canonicalText.substring(lastIndex))
if (lastIndex + skipCount < canonicalText.length) {
result.append(canonicalText.substring(lastIndex + skipCount))
}

return result.toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import org.domaframework.doma.intellij.psi.SqlElStaticFieldAccessExpr

class DocumentStaticFieldGenerator(
val originalElement: PsiElement,
val project: Project,
override val project: Project,
val result: MutableList<String?>,
val staticFieldAccessExpr: SqlElStaticFieldAccessExpr,
val file: PsiFile,
) : DocumentGenerator() {
) : DocumentGenerator(project) {
override fun generateDocument() {
val fieldAccessBlocks = staticFieldAccessExpr.accessElements
val staticElement = PsiStaticElement(fieldAccessBlocks, file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ fun PsiParameter.getIterableClazz(useListParam: Boolean): PsiParentClass {
val immediate = this.type as? PsiClassType
val classType = immediate?.let { PsiClassTypeUtil.getPsiTypeByList(it, this.project, useListParam) }
if (classType != null) {
return PsiParentClass(classType)
val convertOptional = PsiClassTypeUtil.convertOptionalType(classType, this.project)
return PsiParentClass(convertOptional)
}
return PsiParentClass(this.type)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class SqlInspectionVisitor(
if (forItem != null) {
val result = ForDirectiveUtil.getForDirectiveItemClassType(project, forDirectiveBlocks)
if (result == null) {
// TODO Add an error message when the type of element used in the For directory is not a List type.
errorHighlight(topElement, daoMethod, holder)
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class SqlCompleteTest : DomaSqlTest() {
"$testDapName/completeCallStaticPropertyClass.sql",
"$testDapName/completeForItemHasNext.sql",
"$testDapName/completeForItemIndex.sql",
"$testDapName/completeOptionalDaoParam.sql",
"$testDapName/completeOptionalStaticProperty.sql",
"$testDapName/completeOptionalByForItem.sql",
)
myFixture.enableInspections(SqlBindVariableValidInspector())
}
Expand Down Expand Up @@ -329,6 +332,30 @@ class SqlCompleteTest : DomaSqlTest() {
)
}

fun testCompleteOptionalDaoParam() {
innerDirectiveCompleteTest(
"$testDapName/completeOptionalDaoParam.sql",
listOf("manager", "projectNumber", "getFirstEmployee()"),
listOf("get()", "orElseGet()", "isPresent()"),
)
}

fun testCompleteOptionalStaticProperty() {
innerDirectiveCompleteTest(
"$testDapName/completeOptionalStaticProperty.sql",
listOf("userId", "userName", "email", "getUserNameFormat()"),
listOf("get()", "orElseGet()", "isPresent()"),
)
}

fun testCompleteOptionalByForItem() {
innerDirectiveCompleteTest(
"$testDapName/completeOptionalByForItem.sql",
listOf("manager", "projectNumber", "getFirstEmployee()"),
listOf("get()", "orElseGet()", "isPresent()"),
)
}

private fun innerDirectiveCompleteTest(
sqlFileName: String,
expectedSuggestions: List<String>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
*/
package org.domaframework.doma.intellij.document

import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.PsiElement
import org.domaframework.doma.intellij.DomaSqlTest
import org.domaframework.doma.intellij.psi.SqlBlockComment
import org.domaframework.doma.intellij.psi.SqlElFieldAccessExpr
import org.domaframework.doma.intellij.psi.SqlElForDirective
import org.domaframework.doma.intellij.psi.SqlElIdExpr

class SqlSymbolDocumentTestCase : DomaSqlTest() {
Expand All @@ -37,6 +41,8 @@ class SqlSymbolDocumentTestCase : DomaSqlTest() {
addSqlFile("$testPackage/$testDaoName/documentForItemStaticProperty.sql")
addSqlFile("$testPackage/$testDaoName/documentForItemHasNext.sql")
addSqlFile("$testPackage/$testDaoName/documentForItemIndex.sql")
addSqlFile("$testPackage/$testDaoName/documentForItemOptionalForItem.sql")
addSqlFile("$testPackage/$testDaoName/documentForItemOptionalProperty.sql")
}

fun testDocumentForItemDaoParam() {
Expand All @@ -55,6 +61,22 @@ class SqlSymbolDocumentTestCase : DomaSqlTest() {
documentationTest(sqlName, result)
}

fun testDocumentForItemOptionalForItem() {
val sqlName = "documentForItemOptionalForItem"
val result =
"<a href=\"psi_element://java.util.List\">List</a><<a href=\"psi_element://doma.example.entity.Project\">Project</a>> optionalProjects"

documentationTest(sqlName, result)
}

fun testDocumentForItemOptionalForItemProperty() {
val sqlName = "documentForItemOptionalProperty"
val result =
"<a href=\"psi_element://java.util.List\">List</a><<a href=\"psi_element://java.lang.Integer\">Integer</a>> optionalIds"

documentationFindTextTest(sqlName, "optionalIds", result)
}

fun testDocumentForItemElement() {
val sqlName = "documentForItemElement"
val result =
Expand Down Expand Up @@ -146,8 +168,34 @@ class SqlSymbolDocumentTestCase : DomaSqlTest() {
if (sqlFile == null) return

myFixture.configureFromExistingVirtualFile(sqlFile)
var originalElement: PsiElement = myFixture.findElementByText(originalElementName, SqlElIdExpr::class.java)
var originalElement: PsiElement? =
myFixture.findElementByText(originalElementName, SqlElIdExpr::class.java)
?: fundForDirectiveDeclarationElement(sqlFile, originalElementName)
assertNotNull("Not Found Element [$originalElementName]", originalElement)
if (originalElement == null) return

val resultDocument = myDocumentationProvider.generateDoc(originalElement, originalElement)
assertEquals("Documentation should contain expected text", result, resultDocument)
}

private fun fundForDirectiveDeclarationElement(
sqlFile: VirtualFile,
searchElementName: String,
): PsiElement? {
myFixture.configureFromExistingVirtualFile(sqlFile)
val topElement = myFixture.findElementByText(searchElementName, PsiElement::class.java)
val forDirectiveBlock =
topElement.children
.firstOrNull { it is SqlBlockComment && it.text.contains(searchElementName) }

val forDirective =
forDirectiveBlock?.children?.find { it is SqlElForDirective } as? SqlElForDirective
?: return null
val fieldAccessExpr = forDirective.elExprList[1] as? SqlElFieldAccessExpr
if (fieldAccessExpr == null) {
return forDirective.elExprList.firstOrNull { it.text == searchElementName }
}

return fieldAccessExpr.elExprList.firstOrNull { it.text == searchElementName }
}
}
Loading
Loading