Skip to content

Commit 45d59ac

Browse files
committed
Support @datatype for DAO method parameter types
1 parent 0d8cdc7 commit 45d59ac

File tree

4 files changed

+72
-5
lines changed

4 files changed

+72
-5
lines changed

src/main/kotlin/org/domaframework/doma/intellij/extension/psi/PsiClassExtension.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ fun PsiClass.isEntity(): Boolean = this.getClassAnnotation(DomaClassName.ENTITY.
5050

5151
fun PsiClass.isDomain(): Boolean = this.getClassAnnotation(DomaClassName.DOMAIN.className) != null
5252

53-
fun PsiClass.isDataType(): Boolean = this.getClassAnnotation(DomaClassName.DATATYPE.className) != null
53+
fun PsiClass.isDataType(): Boolean = this.getClassAnnotation(DomaClassName.DATATYPE.className) != null && this.isRecord
5454

5555
fun PsiClassType.getSuperType(superClassName: String): PsiClassType? {
5656
var parent: PsiClassType? = this

src/main/kotlin/org/domaframework/doma/intellij/inspection/dao/processor/paramtype/BatchParamTypeCheckProcessor.kt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,19 @@ class BatchParamTypeCheckProcessor(
8686

8787
val iterableClassType = param.type as? PsiClassType
8888
iterableClassType?.parameters?.firstOrNull()?.let { iterableParam ->
89-
if (!TypeUtil.isEntity(iterableParam, project)) {
90-
resultParamType.highlightElement(holder)
89+
// Check if @Sql annotation is present or sqlFile=true
90+
if (psiDaoMethod.useSqlAnnotation() || psiDaoMethod.sqlFileOption) {
91+
if (!TypeUtil.isEntity(iterableParam, project) &&
92+
!TypeUtil.isDomain(iterableParam, project) &&
93+
!TypeUtil.isDataType(iterableParam, project)
94+
) {
95+
resultParamType.highlightElement(holder)
96+
}
97+
} else {
98+
// When @Sql annotation is present or sqlFile=true, only Entity types are allowed
99+
if (!TypeUtil.isEntity(iterableParam, project)) {
100+
resultParamType.highlightElement(holder)
101+
}
91102
}
92103
return
93104
}

src/test/kotlin/org/domaframework/doma/intellij/inspection/dao/AnnotationParamTypeCheckInspectionTest.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class AnnotationParamTypeCheckInspectionTest : DomaSqlTest() {
3232
"ScriptParamTestDao",
3333
"SqlProcessorParamTestDao",
3434
"FactoryParamTestDao",
35+
"DataTypeParamTypeTestDao",
3536
)
3637
private val daoPackage = "inspection/paramtype"
3738

@@ -43,6 +44,7 @@ class AnnotationParamTypeCheckInspectionTest : DomaSqlTest() {
4344
addOtherJavaFile("collector", "HogeCollector.java")
4445
addOtherJavaFile("function", "HogeFunction.java")
4546
addOtherJavaFile("function", "HogeBiFunction.java")
47+
addOtherJavaFile("domain", "Salary.java")
4648
testDaoNames.forEach { daoName ->
4749
addDaoJavaFile("$daoPackage/$daoName.java")
4850
}
@@ -87,4 +89,9 @@ class AnnotationParamTypeCheckInspectionTest : DomaSqlTest() {
8789
val dao = findDaoClass("$daoPackage.FactoryParamTestDao")
8890
myFixture.testHighlighting(false, false, false, dao.containingFile.virtualFile)
8991
}
92+
93+
fun testDataTypeParam() {
94+
val dao = findDaoClass("$daoPackage.DataTypeParamTypeTestDao")
95+
myFixture.testHighlighting(false, false, false, dao.containingFile.virtualFile)
96+
}
9097
}

src/test/testData/src/main/java/doma/example/dao/inspection/returntype/DataTypeReturnTypeTestDao.java

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import org.seasar.doma.*;
44
import doma.example.domain.Salary;
55
import org.seasar.doma.jdbc.Reference;
6+
import org.seasar.doma.jdbc.Result;
67

78
import java.util.List;
89
import java.util.Optional;
@@ -18,9 +19,57 @@ public interface DataTypeReturnTypeTestDao {
1819
@Sql("select * from salary where id = /*salary.value*/0")
1920
Optional<Salary> selectOptSalary(Salary salary) ;
2021

22+
@Update
23+
@Sql("UPDATE salary SET val = /*salary.value*/0")
24+
int updateSalaryWithSql1(Salary salary);
25+
26+
@Update
27+
@Sql("UPDATE salary SET val = /*salary.value*/0")
28+
Result<Salary> <error descr="The return type must be \"int\"">updateSalaryWithSql2</error>(Salary salary);
29+
30+
@Insert
31+
@Sql("INSERT INTO salary (val) VALUES (/*salary.value*/0)")
32+
int insertSalaryWithSql(Salary salary);
33+
34+
@Insert
35+
@Sql("INSERT INTO salary (val) VALUES (/*salary.value*/0)")
36+
Result<Salary> <error descr="The return type must be \"int\"">insertSalaryWithSql2</error>(Salary salary);
37+
38+
@Delete
39+
@Sql("DELETE FROM salary WHERE val = /*salary.value*/0 ")
40+
int deleteSalaryWithSql(Salary salary);
41+
42+
@Delete
43+
@Sql("DELETE FROM salary WHERE val = /*salary.value*/0 ")
44+
Result<Salary> <error descr="The return type must be \"int\"">deleteSalaryWithSql2</error>(Salary salary);
45+
46+
@BatchUpdate
47+
@Sql("UPDATE salary SET val = /*salary.value*/0")
48+
int[] batchUpdateSalaryWithSql(List<Salary> salary);
49+
50+
@BatchUpdate
51+
@Sql("UPDATE salary SET val = /*salary.value*/0")
52+
<error descr="Cannot resolve symbol 'BatchResult'">BatchResult</error><Salary> <error descr="The return type must be \"int[]\"">batchUpdateSalaryWithSql2</error>(List<Salary> salary);
53+
54+
@BatchInsert
55+
@Sql("INSERT INTO salary (val) VALUES (/*salary.value*/0)")
56+
int[] batchInsertSalaryWithSql(List<Salary> salary);
57+
58+
@BatchInsert
59+
@Sql("INSERT INTO salary (val) VALUES (/*salary.value*/0)")
60+
<error descr="Cannot resolve symbol 'BatchResult'">BatchResult</error><Salary> <error descr="The return type must be \"int[]\"">batchInsertSalaryWithSql2</error>(List<Salary> salary);
61+
62+
@BatchDelete
63+
@Sql("DELETE FROM salary WHERE val = /*salary.value*/0 ")
64+
int[] batchDeleteSalaryWithSql(List<Salary> salary);
65+
66+
@BatchDelete
67+
@Sql("DELETE FROM salary WHERE val = /*salary.value*/0 ")
68+
<error descr="Cannot resolve symbol 'BatchResult'">BatchResult</error><Salary> <error descr="The return type must be \"int[]\"">batchDeleteSalaryWithSql2</error>(List<Salary> salary);
69+
2170
@Function
22-
Salary calculateAverageSalary();
23-
71+
List<Salary> getTopSalaries(@In Salary limit);
72+
2473
@Function
2574
Optional<Salary> getMaxSalary(@InOut Reference<Salary> percentage, @ResultSet List<Salary> resultSet);
2675

0 commit comments

Comments
 (0)