Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,21 @@ import org.domaframework.doma.intellij.extension.psi.isDataType
import org.domaframework.doma.intellij.extension.psi.isDomain
import org.domaframework.doma.intellij.extension.psi.isEntity
import org.domaframework.doma.intellij.formatter.block.SqlBlock
import org.domaframework.doma.intellij.formatter.block.comma.SqlCommaBlock
import org.domaframework.doma.intellij.formatter.block.group.keyword.create.SqlCreateViewGroupBlock
import org.domaframework.doma.intellij.formatter.block.group.keyword.with.SqlWithQuerySubGroupBlock
import org.domaframework.doma.intellij.formatter.block.group.subgroup.SqlSubGroupBlock
import kotlin.reflect.KClass

object TypeUtil {
private val TOP_LEVEL_EXPECTED_TYPES =
listOf(
SqlSubGroupBlock::class,
SqlCommaBlock::class,
SqlWithQuerySubGroupBlock::class,
SqlCreateViewGroupBlock::class,
)

/**
* Unwraps the type parameter from Optional if present, otherwise returns the original type.
*/
Expand Down Expand Up @@ -118,6 +130,8 @@ object TypeUtil {
return PsiTypeChecker.isBaseClassType(type) || DomaClassName.isOptionalWrapperType(type.canonicalText)
}

fun isTopLevelExpectedType(childBlock: SqlBlock?): Boolean = isExpectedClassType(TOP_LEVEL_EXPECTED_TYPES, childBlock)

/**
* Determines whether the specified class instance matches.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ open class SqlBlock(
protected fun isConditionLoopDirectiveRegisteredBeforeParent(): Boolean {
val firstPrevBlock = (prevBlocks.lastOrNull() as? SqlElConditionLoopCommentBlock)
parentBlock?.let { parent ->
return firstPrevBlock != null &&

return (childBlocks.firstOrNull() as? SqlElConditionLoopCommentBlock)?.isBeforeParentBlock() == true ||
firstPrevBlock != null &&
firstPrevBlock.conditionEnd != null &&
firstPrevBlock.node.startOffset > parent.node.startOffset
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ import com.intellij.formatting.Wrap
import com.intellij.lang.ASTNode
import com.intellij.psi.PsiWhiteSpace
import com.intellij.psi.formatter.common.AbstractBlock
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.psi.util.elementType
import com.intellij.psi.util.nextLeaf
import com.intellij.psi.util.nextLeafs
import org.domaframework.doma.intellij.common.util.TypeUtil
import org.domaframework.doma.intellij.formatter.block.comma.SqlCommaBlock
import org.domaframework.doma.intellij.formatter.block.comment.SqlCommentBlock
Expand Down Expand Up @@ -62,7 +58,6 @@ import org.domaframework.doma.intellij.formatter.block.other.SqlEscapeBlock
import org.domaframework.doma.intellij.formatter.block.other.SqlOtherBlock
import org.domaframework.doma.intellij.formatter.block.word.SqlAliasBlock
import org.domaframework.doma.intellij.formatter.block.word.SqlArrayWordBlock
import org.domaframework.doma.intellij.formatter.block.word.SqlFunctionGroupBlock
import org.domaframework.doma.intellij.formatter.block.word.SqlTableBlock
import org.domaframework.doma.intellij.formatter.block.word.SqlWordBlock
import org.domaframework.doma.intellij.formatter.builder.SqlBlockBuilder
Expand Down Expand Up @@ -152,152 +147,130 @@ open class SqlFileBlock(
child: ASTNode,
prevBlock: SqlBlock?,
): SqlBlock {
val defaultFormatCtx =
SqlBlockFormattingContext(
wrap,
alignment,
spacingBuilder,
isEnableFormat(),
formatMode,
)
val defaultFormatCtx = createDefaultFormattingContext()
val lastGroup = blockBuilder.getLastGroupTopNodeIndexHistory()
val lastGroupFilteredDirective = blockBuilder.getLastGroupFilterDirective()
return when (child.elementType) {
SqlTypes.KEYWORD -> {
return blockUtil.getKeywordBlock(
child,
blockBuilder.getLastGroupTopNodeIndexHistory(),
)
}

SqlTypes.DATATYPE -> {
SqlDataTypeBlock(
child,
defaultFormatCtx,
)
}

SqlTypes.LEFT_PAREN -> {
return blockUtil.getSubGroupBlock(lastGroup, child, blockBuilder.getGroupTopNodeIndexHistory())
}

SqlTypes.OTHER -> {
return if (lastGroup is SqlUpdateSetGroupBlock &&
lastGroup.columnDefinitionGroupBlock != null
) {
SqlUpdateColumnAssignmentSymbolBlock(
child,
defaultFormatCtx,
)
} else {
val escapeStrings = listOf("\"", "`", "[", "]")
if (escapeStrings.contains(child.text)) {
if (child.text == "[" && prevBlock is SqlArrayWordBlock) {
SqlArrayListGroupBlock(
child,
defaultFormatCtx,
)
} else {
SqlEscapeBlock(
child,
defaultFormatCtx,
)
}
} else {
SqlOtherBlock(
child,
defaultFormatCtx,
)
}
}
}

SqlTypes.RIGHT_PAREN -> return SqlRightPatternBlock(
child,
defaultFormatCtx,
)
return when (child.elementType) {
SqlTypes.KEYWORD -> createKeywordBlock(child, lastGroup)
SqlTypes.DATATYPE -> SqlDataTypeBlock(child, defaultFormatCtx)
SqlTypes.LEFT_PAREN -> blockUtil.getSubGroupBlock(lastGroup, child, blockBuilder.getGroupTopNodeIndexHistory())
SqlTypes.OTHER -> createOtherBlock(child, prevBlock, lastGroup, defaultFormatCtx)
SqlTypes.RIGHT_PAREN -> SqlRightPatternBlock(child, defaultFormatCtx)
SqlTypes.COMMA -> createCommaBlock(child, lastGroup, defaultFormatCtx)
SqlTypes.FUNCTION_NAME -> createFunctionNameBlock(child, lastGroup, defaultFormatCtx)
SqlTypes.WORD -> createWordBlock(child, lastGroup, defaultFormatCtx)
SqlTypes.BLOCK_COMMENT -> createBlockCommentBlock(child, lastGroup, lastGroupFilteredDirective, defaultFormatCtx)
SqlTypes.LINE_COMMENT -> SqlLineCommentBlock(child, defaultFormatCtx)
SqlTypes.PLUS, SqlTypes.MINUS, SqlTypes.ASTERISK, SqlTypes.SLASH -> SqlElSymbolBlock(child, defaultFormatCtx)
SqlTypes.LE, SqlTypes.LT, SqlTypes.EL_EQ, SqlTypes.EL_NE, SqlTypes.GE, SqlTypes.GT -> SqlElSymbolBlock(child, defaultFormatCtx)
SqlTypes.STRING, SqlTypes.NUMBER, SqlTypes.BOOLEAN -> SqlLiteralBlock(child, defaultFormatCtx)
else -> SqlUnknownBlock(child, defaultFormatCtx)
}
}

SqlTypes.COMMA -> {
return if (lastGroup is SqlWithQueryGroupBlock) {
SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
} else {
blockUtil.getCommaGroupBlock(lastGroup, child)
}
}
private fun createDefaultFormattingContext(): SqlBlockFormattingContext =
SqlBlockFormattingContext(
wrap,
alignment,
spacingBuilder,
isEnableFormat(),
formatMode,
)

SqlTypes.FUNCTION_NAME -> {
val notWhiteSpaceElement =
child.psi.nextLeafs
.takeWhile { it is PsiWhiteSpace }
.lastOrNull()
?.nextLeaf(true)
if (notWhiteSpaceElement?.elementType == SqlTypes.LEFT_PAREN ||
PsiTreeUtil.nextLeaf(child.psi)?.elementType == SqlTypes.LEFT_PAREN
) {
return SqlFunctionGroupBlock(child, defaultFormatCtx)
}
return SqlKeywordBlock(
child,
IndentType.ATTACHED,
defaultFormatCtx,
)
}
private fun createKeywordBlock(
child: ASTNode,
lastGroup: SqlBlock?,
): SqlBlock {
if (blockUtil.hasEscapeBeforeWhiteSpace(blocks.lastOrNull() as? SqlBlock?, child)) {
return blockUtil.getWordBlock(lastGroup, child)
}
return blockUtil.getKeywordBlock(
child,
blockBuilder.getLastGroupTopNodeIndexHistory(),
)
}

SqlTypes.WORD -> {
return if (lastGroup is SqlWithQueryGroupBlock) {
SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
} else {
blockUtil.getWordBlock(lastGroup, child)
}
}
private fun createOtherBlock(
child: ASTNode,
prevBlock: SqlBlock?,
lastGroup: SqlBlock?,
defaultFormatCtx: SqlBlockFormattingContext,
): SqlBlock {
if (lastGroup is SqlUpdateSetGroupBlock && lastGroup.columnDefinitionGroupBlock != null) {
return SqlUpdateColumnAssignmentSymbolBlock(child, defaultFormatCtx)
}

SqlTypes.BLOCK_COMMENT -> {
val tempBlock =
blockUtil.getBlockCommentBlock(
child,
createBlockDirectiveCommentSpacingBuilder(),
)
if (tempBlock !is SqlElConditionLoopCommentBlock) {
if (lastGroup is SqlWithQueryGroupBlock || lastGroupFilteredDirective is SqlWithQueryGroupBlock) {
return SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
}
}
return if (lastGroup is SqlWithCommonTableGroupBlock) {
SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
} else {
tempBlock
}
val escapeStrings = listOf("\"", "`", "[", "]")
if (escapeStrings.contains(child.text)) {
return if (child.text == "[" && prevBlock is SqlArrayWordBlock) {
SqlArrayListGroupBlock(child, defaultFormatCtx)
} else {
SqlEscapeBlock(child, defaultFormatCtx)
}
}
return SqlOtherBlock(child, defaultFormatCtx)
}

SqlTypes.LINE_COMMENT ->
return SqlLineCommentBlock(
child,
defaultFormatCtx,
)

SqlTypes.PLUS, SqlTypes.MINUS, SqlTypes.ASTERISK, SqlTypes.SLASH ->
return SqlElSymbolBlock(
child,
defaultFormatCtx,
)
private fun createCommaBlock(
child: ASTNode,
lastGroup: SqlBlock?,
defaultFormatCtx: SqlBlockFormattingContext,
): SqlBlock =
if (lastGroup is SqlWithQueryGroupBlock) {
SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
} else {
blockUtil.getCommaGroupBlock(lastGroup, child)
}

SqlTypes.LE, SqlTypes.LT, SqlTypes.EL_EQ, SqlTypes.EL_NE, SqlTypes.GE, SqlTypes.GT ->
return SqlElSymbolBlock(
child,
defaultFormatCtx,
)
private fun createFunctionNameBlock(
child: ASTNode,
lastGroup: SqlBlock?,
defaultFormatCtx: SqlBlockFormattingContext,
): SqlBlock {
val block = blockUtil.getFunctionName(child, defaultFormatCtx)
if (block != null) {
return block
}
// If it is not followed by a left parenthesis, treat it as a word block
return if (lastGroup is SqlWithQueryGroupBlock) {
SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
} else {
blockUtil.getWordBlock(lastGroup, child)
}
}

SqlTypes.STRING, SqlTypes.NUMBER, SqlTypes.BOOLEAN ->
return SqlLiteralBlock(
child,
defaultFormatCtx,
)
private fun createWordBlock(
child: ASTNode,
lastGroup: SqlBlock?,
defaultFormatCtx: SqlBlockFormattingContext,
): SqlBlock =
if (lastGroup is SqlWithQueryGroupBlock) {
SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
} else {
blockUtil.getWordBlock(lastGroup, child)
}

else ->
SqlUnknownBlock(
child,
defaultFormatCtx,
)
private fun createBlockCommentBlock(
child: ASTNode,
lastGroup: SqlBlock?,
lastGroupFilteredDirective: SqlBlock?,
defaultFormatCtx: SqlBlockFormattingContext,
): SqlBlock {
val tempBlock =
blockUtil.getBlockCommentBlock(
child,
createBlockDirectiveCommentSpacingBuilder(),
)
if (tempBlock !is SqlElConditionLoopCommentBlock) {
if (lastGroup is SqlWithQueryGroupBlock || lastGroupFilteredDirective is SqlWithQueryGroupBlock) {
return SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
}
}
return if (lastGroup is SqlWithCommonTableGroupBlock) {
SqlWithCommonTableGroupBlock(child, defaultFormatCtx)
} else {
tempBlock
}
}

Expand Down Expand Up @@ -554,7 +527,19 @@ open class SqlFileBlock(
SqlCustomSpacingBuilder.nonSpacing
}

else -> SqlCustomSpacingBuilder.normalSpacing
is SqlSubGroupBlock -> {
val includeSpaceRight = childBlock1.endPatternBlock?.isPreSpaceRight()

if (includeSpaceRight == false) {
SqlCustomSpacingBuilder.nonSpacing
} else {
SqlCustomSpacingBuilder.normalSpacing
}
}

else -> {
SqlCustomSpacingBuilder.normalSpacing
}
}
}

Expand All @@ -576,6 +561,13 @@ open class SqlFileBlock(
if (childBlock2.isEndEscape) {
return SqlCustomSpacingBuilder.nonSpacing
}

// When a column definition is enclosed in escape characters,
// calculate the indentation to match the formatting rules of a CREATE query.
CreateClauseHandler
.getColumnDefinitionRawGroupSpacing(childBlock1, childBlock2)
?.let { return it }

return SqlCustomSpacingBuilder().getSpacing(childBlock2)
}

Expand Down
Loading
Loading