Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,50 @@ class PsiClassTypeUtil {
type = type.parameters.firstOrNull()
count++
}
return type as? PsiClassType
val convertOptional = type?.let { convertOptionalType(it, project) }
return convertOptional as? PsiClassType
}
return null
}

/**
* Check if daoParamType is an instance of PsiClassType representing Optional or its primitive variants
*/
fun convertOptionalType(
daoParamType: PsiType,
project: Project,
): PsiType {
if (daoParamType is PsiClassType) {
val resolved = daoParamType.resolve()
val optionalTypeMap =
mapOf(
"java.util.OptionalInt" to "java.lang.Integer",
"java.util.OptionalDouble" to "java.lang.Double",
"java.util.OptionalLong" to "java.lang.Long",
)
if (resolved != null) {
when (resolved.qualifiedName) {
// If the type is java.util.Optional, return its parameter type if available;
// otherwise, return the original daoParamType.
"java.util.Optional" -> return daoParamType.parameters.firstOrNull()
?: daoParamType

// For primitive Optional types (e.g., OptionalInt, OptionalDouble),
// map them to their corresponding wrapper types (e.g., Integer, Double).
else ->
optionalTypeMap[resolved.qualifiedName]?.let { optionalType ->
val newType =
PsiType.getTypeByName(
optionalType,
project,
GlobalSearchScope.allScope(project),
)
return newType
}
}
}
}
return daoParamType
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class ForDirectiveUtil {

val matchParam = daoMethod.findParameter(cleanString(topElementText))
val daoParamType = matchParam?.type ?: return null
fieldAccessTopParentClass = PsiParentClass(daoParamType)
fieldAccessTopParentClass = PsiParentClass(PsiClassTypeUtil.convertOptionalType(daoParamType, project))
}
fieldAccessTopParentClass?.let {
getFieldAccessLastPropertyClassType(
Expand Down Expand Up @@ -329,14 +329,15 @@ class ForDirectiveUtil {
): ValidationResult? {
var parent =
if (isBatchAnnotation) {
val parentType = topParent.type
val parentType = PsiClassTypeUtil.convertOptionalType(topParent.type, project)
val nextClassType = parentType as? PsiClassType ?: return null
val nestType = nextClassType.parameters.firstOrNull() ?: return null
PsiParentClass(nestType)
PsiParentClass(PsiClassTypeUtil.convertOptionalType(nestType, project))
} else {
topParent
val convertOptional = PsiClassTypeUtil.convertOptionalType(topParent.type, project)
PsiParentClass(convertOptional)
}
val parentType = parent.type
val parentType = PsiClassTypeUtil.convertOptionalType(parent.type, project)
val classType = parentType as? PsiClassType ?: return null

var competeResult: ValidationCompleteResult? = null
Expand All @@ -353,8 +354,8 @@ class ForDirectiveUtil {
// When a List type element is used as the parent,
// the original declared type is retained and the referenced type is obtained by nesting.
var parentListBaseType: PsiType? =
if (PsiClassTypeUtil.Companion.isIterableType(classType, project)) {
parentType
if (PsiClassTypeUtil.isIterableType(classType, project)) {
PsiClassTypeUtil.convertOptionalType(parentType, project)
} else {
null
}
Expand All @@ -375,24 +376,25 @@ class ForDirectiveUtil {
parent
.findField(searchElm)
?.let { match ->
val convertOptional = PsiClassTypeUtil.convertOptionalType(match.type, project)
val type =
parentListBaseType?.let {
PsiClassTypeUtil.Companion.getParameterType(
PsiClassTypeUtil.getParameterType(
project,
match.type,
convertOptional,
it,
nestIndex,
)
}
?: match.type
?: convertOptional
val classType = type as? PsiClassType
if (classType != null &&
PsiClassTypeUtil.Companion.isIterableType(
PsiClassTypeUtil.isIterableType(
classType,
element.project,
)
) {
parentListBaseType = type
parentListBaseType = PsiClassTypeUtil.convertOptionalType(type, project)
nestIndex = 0
}
findFieldMethod?.invoke(type)
Expand All @@ -402,19 +404,20 @@ class ForDirectiveUtil {
.findMethod(searchElm)
?.let { match ->
val returnType = match.returnType ?: return null
val convertOptionalType = PsiClassTypeUtil.convertOptionalType(returnType, project)
val methodReturnType =
parentListBaseType?.let {
PsiClassTypeUtil.Companion.getParameterType(
PsiClassTypeUtil.getParameterType(
project,
returnType,
convertOptionalType,
it,
nestIndex,
)
}
?: returnType
?: convertOptionalType
val classType = methodReturnType as? PsiClassType
if (classType != null &&
PsiClassTypeUtil.Companion.isIterableType(
PsiClassTypeUtil.isIterableType(
classType,
element.project,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.intellij.codeInsight.completion.CompletionProvider
import com.intellij.codeInsight.completion.CompletionResultSet
import com.intellij.codeInsight.lookup.LookupElementBuilder
import com.intellij.codeInsight.lookup.VariableLookupItem
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiDirectory
import com.intellij.psi.PsiElement
Expand All @@ -33,8 +34,10 @@ import com.intellij.psi.util.elementType
import com.intellij.psi.util.prevLeafs
import com.intellij.util.ProcessingContext
import org.domaframework.doma.intellij.common.dao.findDaoMethod
import org.domaframework.doma.intellij.common.psi.PsiDaoMethod
import org.domaframework.doma.intellij.common.psi.PsiParentClass
import org.domaframework.doma.intellij.common.psi.PsiPatternUtil
import org.domaframework.doma.intellij.common.sql.PsiClassTypeUtil
import org.domaframework.doma.intellij.common.sql.cleanString
import org.domaframework.doma.intellij.common.sql.directive.DirectiveCompletion
import org.domaframework.doma.intellij.common.sql.validator.result.ValidationCompleteResult
Expand Down Expand Up @@ -270,66 +273,49 @@ class SqlParameterCompletionProvider : CompletionProvider<CompletionParameters>(
originalFile: PsiFile,
result: CompletionResultSet,
) {
val daoMethod = findDaoMethod(originalFile)
val searchText = cleanString(getSearchElementText(position))
var topElementType: PsiType? = null
if (elements.isEmpty()) {
getElementTypeByFieldAccess(originalFile, elements, result)
if (elements.isEmpty() && daoMethod != null) {
getElementTypeByFieldAccess(originalFile, elements, daoMethod, result)
return
}
val top = elements.first()

val topText = cleanString(getSearchElementText(top))
val prevWord = PsiPatternUtil.getBindSearchWord(originalFile, elements.last(), " ")
if (prevWord.startsWith("@") && prevWord.endsWith("@")) {
setStaticFieldAccess(top, prevWord, topText, result)
setCompletionStaticFieldAccess(top, prevWord, topText, result)
return
}

var isBatchAnnotation = false
if (top.parent !is PsiFile && top.parent?.parent !is PsiDirectory) {
val staticDirective = top.findNodeParent(SqlTypes.EL_STATIC_FIELD_ACCESS_EXPR)
staticDirective?.let {
topElementType = getElementTypeByStaticFieldAccess(top, it, topText) ?: return
}
}

if (daoMethod == null) return
val project = originalFile.project
val psiDaoMethod = PsiDaoMethod(project, daoMethod)
if (topElementType == null) {
if (isFieldAccessByForItem(top, elements, searchText, result)) return
isBatchAnnotation = psiDaoMethod.daoType.isBatchAnnotation()
if (isFieldAccessByForItem(top, elements, searchText, isBatchAnnotation, result)) return
topElementType =
getElementTypeByFieldAccess(originalFile, elements, result) ?: return
getElementTypeByFieldAccess(originalFile, elements, daoMethod, result) ?: return
}

var psiParentClass = PsiParentClass(topElementType)
// FieldAccess Completion
ForDirectiveUtil.getFieldAccessLastPropertyClassType(
setCompletionFieldAccess(
topElementType,
originalFile.project,
isBatchAnnotation,
elements,
top.project,
psiParentClass,
shortName = "",
dropLastIndex = 1,
complete = { lastType ->
val searchWord = cleanString(getSearchElementText(position))
setFieldsAndMethodsCompletionResultSet(
lastType.searchField(searchWord)?.toTypedArray() ?: emptyArray(),
lastType.searchMethod(searchWord)?.toTypedArray() ?: emptyArray(),
result,
)
},
searchText,
result,
)
}

private fun setStaticFieldAccess(
top: PsiElement,
prevWord: String,
topText: String,
result: CompletionResultSet,
) {
val clazz = getRefClazz(top) { prevWord.replace("@", "") } ?: return
val matchFields = clazz.searchStaticField(topText)
val matchMethod = clazz.searchStaticMethod(topText)

// When you enter here, it is the top element, so return static fields and methods.
setFieldsAndMethodsCompletionResultSet(matchFields, matchMethod, result)
}

private fun getSearchElementText(elm: PsiElement?): String =
if (elm is SqlElIdExpr || elm.elementType == SqlTypes.EL_IDENTIFIER) {
elm?.text ?: ""
Expand Down Expand Up @@ -368,9 +354,9 @@ class SqlParameterCompletionProvider : CompletionProvider<CompletionParameters>(
private fun getElementTypeByFieldAccess(
originalFile: PsiFile,
elements: List<PsiElement>,
daoMethod: PsiMethod,
result: CompletionResultSet,
): PsiType? {
val daoMethod = findDaoMethod(originalFile) ?: return null
val topText = cleanString(getSearchElementText(elements.firstOrNull()))
val matchParams = daoMethod.searchParameter(topText)
val findParam = matchParams.find { it.name == topText }
Expand All @@ -385,7 +371,7 @@ class SqlParameterCompletionProvider : CompletionProvider<CompletionParameters>(
return null
}
val immediate = findParam.getIterableClazz(daoMethod.getDomaAnnotationType())
return immediate.type
return PsiClassTypeUtil.convertOptionalType(immediate.type, originalFile.project)
}

private fun getRefClazz(
Expand Down Expand Up @@ -416,10 +402,10 @@ class SqlParameterCompletionProvider : CompletionProvider<CompletionParameters>(
private fun isFieldAccessByForItem(
top: PsiElement,
elements: List<PsiElement>,
positionText: String,
searchWord: String,
isBatchAnnotation: Boolean = false,
result: CompletionResultSet,
): Boolean {
val searchWord = cleanString(positionText)
val project = top.project
val forDirectiveBlocks = ForDirectiveUtil.getForDirectiveBlocks(top)
ForDirectiveUtil.findForItem(top, forDirectives = forDirectiveBlocks) ?: return false
Expand All @@ -438,6 +424,7 @@ class SqlParameterCompletionProvider : CompletionProvider<CompletionParameters>(
elements,
project,
topClassType,
isBatchAnnotation = isBatchAnnotation,
shortName = "",
dropLastIndex = 1,
complete = { lastType ->
Expand All @@ -450,4 +437,46 @@ class SqlParameterCompletionProvider : CompletionProvider<CompletionParameters>(
)
return result is ValidationCompleteResult
}

private fun setCompletionFieldAccess(
topElementType: PsiType,
project: Project,
isBatchAnnotation: Boolean,
elements: List<PsiElement>,
searchWord: String,
result: CompletionResultSet,
) {
var psiParentClass = PsiParentClass(topElementType)

// FieldAccess Completion
ForDirectiveUtil.getFieldAccessLastPropertyClassType(
elements,
project,
psiParentClass,
isBatchAnnotation = isBatchAnnotation,
shortName = "",
dropLastIndex = 1,
complete = { lastType ->
setFieldsAndMethodsCompletionResultSet(
lastType.searchField(searchWord)?.toTypedArray() ?: emptyArray(),
lastType.searchMethod(searchWord)?.toTypedArray() ?: emptyArray(),
result,
)
},
)
}

private fun setCompletionStaticFieldAccess(
top: PsiElement,
prevWord: String,
topText: String,
result: CompletionResultSet,
) {
val clazz = getRefClazz(top) { prevWord.replace("@", "") } ?: return
val matchFields = clazz.searchStaticField(topText)
val matchMethod = clazz.searchStaticMethod(topText)

// When you enter here, it is the top element, so return static fields and methods.
setFieldsAndMethodsCompletionResultSet(matchFields, matchMethod, result)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import org.domaframework.doma.intellij.psi.SqlElFieldAccessExpr

class DocumentDaoParameterGenerator(
val originalElement: PsiElement,
val project: Project,
override val project: Project,
val result: MutableList<String?>,
) : DocumentGenerator() {
) : DocumentGenerator(project) {
override fun generateDocument() {
var topParentType: PsiParentClass? = null
val selfSkip = isSelfSkip(originalElement)
Expand Down
Loading