Skip to content

Commit 84d1f0b

Browse files
committed
HHH-16018 SemanticException when trying to return an entity (or anything not of a basic type) from CASE.
Also include test with a supported way to choose between several entities.
1 parent 469ffda commit 84d1f0b

File tree

4 files changed

+158
-39
lines changed

4 files changed

+158
-39
lines changed

hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5248,19 +5248,12 @@ private Expression createCaseExpression(SqmPath<?> lhs, EntityDomainType<?> trea
52485248
if ( treatTypeRestriction == null ) {
52495249
return expression;
52505250
}
5251-
final BasicValuedMapping mappingModelExpressible = (BasicValuedMapping) expression.getExpressionType();
5252-
final List<CaseSearchedExpression.WhenFragment> whenFragments = new ArrayList<>( 1 );
5253-
whenFragments.add(
5254-
new CaseSearchedExpression.WhenFragment(
5255-
treatTypeRestriction,
5256-
expression
5257-
)
5258-
);
5259-
return new CaseSearchedExpression(
5260-
mappingModelExpressible,
5261-
whenFragments,
5262-
null
5251+
var caseSearchedExpression = CaseSearchedExpression.ofType(
5252+
expression.getExpressionType(),
5253+
lhs::toHqlString
52635254
);
5255+
caseSearchedExpression.when( treatTypeRestriction, expression );
5256+
return caseSearchedExpression;
52645257
}
52655258

52665259
private Predicate consumeConjunctTreatTypeRestrictions() {
@@ -7168,34 +7161,37 @@ public CaseSearchedExpression visitSearchedCaseExpression(SqmCaseSearched<?> exp
71687161
final boolean oldInNestedContext = inNestedContext;
71697162

71707163
inNestedContext = true;
7171-
MappingModelExpressible<?> resolved = determineCurrentExpressible( expression );
7164+
JdbcMappingContainer resolved = determineCurrentExpressible( expression );
71727165

71737166
Expression otherwise = null;
71747167
for ( SqmCaseSearched.WhenFragment<?> whenFragment : expression.getWhenFragments() ) {
71757168
final Predicate whenPredicate = visitNestedTopLevelPredicate( whenFragment.getPredicate() );
7176-
final MappingModelExpressible<?> alreadyKnown = resolved;
7169+
final JdbcMappingContainer alreadyKnown = resolved;
71777170
inferrableTypeAccessStack.push(
71787171
() -> alreadyKnown == null && inferenceSupplier != null ? inferenceSupplier.get() : alreadyKnown
71797172
);
71807173
final Expression resultExpression = (Expression) whenFragment.getResult().accept( this );
71817174
inferrableTypeAccessStack.pop();
7182-
resolved = (MappingModelExpressible<?>) highestPrecedence( resolved, resultExpression.getExpressionType() );
7175+
resolved = highestPrecedence( resolved, resultExpression.getExpressionType() );
71837176

71847177
whenFragments.add( new CaseSearchedExpression.WhenFragment( whenPredicate, resultExpression ) );
71857178
}
71867179

71877180
if ( expression.getOtherwise() != null ) {
7188-
final MappingModelExpressible<?> alreadyKnown = resolved;
7181+
final JdbcMappingContainer alreadyKnown = resolved;
71897182
inferrableTypeAccessStack.push(
71907183
() -> alreadyKnown == null && inferenceSupplier != null ? inferenceSupplier.get() : alreadyKnown
71917184
);
71927185
otherwise = (Expression) expression.getOtherwise().accept( this );
71937186
inferrableTypeAccessStack.pop();
7194-
resolved = (MappingModelExpressible<?>) highestPrecedence( resolved, otherwise.getExpressionType() );
7187+
resolved = highestPrecedence( resolved, otherwise.getExpressionType() );
71957188
}
71967189

7190+
var caseSearchedExpression = CaseSearchedExpression.ofType( resolved, expression::toHqlString );
7191+
caseSearchedExpression.getWhenFragments().addAll( whenFragments );
7192+
caseSearchedExpression.otherwise( otherwise );
71977193
inNestedContext = oldInNestedContext;
7198-
return new CaseSearchedExpression( resolved, whenFragments, otherwise );
7194+
return caseSearchedExpression;
71997195
}
72007196

72017197
private MappingModelExpressible<?> determineCurrentExpressible(SqmTypedNode<?> expression) {

hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,28 +1963,27 @@ private void renderPredicatedSetAssignments(List<Assignment> assignments, Predic
19631963
final Expression assignedValue = assignment.getAssignedValue();
19641964
final Expression expression;
19651965
if ( assignable.getColumnReferences().size() == 1 ) {
1966-
expression = new CaseSearchedExpression(
1967-
(MappingModelExpressible) assignedValue.getExpressionType(),
1968-
List.of( new CaseSearchedExpression.WhenFragment( predicate, assignedValue ) ),
1969-
assignable.getColumnReferences().get( 0 )
1966+
CaseSearchedExpression caseSearchedExpression = CaseSearchedExpression.ofType(
1967+
assignedValue.getExpressionType(),
1968+
this::getSql
19701969
);
1970+
caseSearchedExpression.when( predicate, assignedValue );
1971+
caseSearchedExpression.otherwise( assignable.getColumnReferences().get( 0 ) );
1972+
expression = caseSearchedExpression;
19711973
}
19721974
else {
19731975
assert assignedValue instanceof SqlTupleContainer;
19741976
final List<? extends Expression> expressions =
19751977
( (SqlTupleContainer) assignedValue ).getSqlTuple().getExpressions();
19761978
final List<Expression> tupleExpressions = new ArrayList<>( expressions.size() );
19771979
for ( int i = 0; i < expressions.size(); i++ ) {
1978-
tupleExpressions.add(
1979-
new CaseSearchedExpression(
1980-
(MappingModelExpressible<?>) expressions.get( i ).getExpressionType(),
1981-
List.of( new CaseSearchedExpression.WhenFragment(
1982-
predicate,
1983-
expressions.get( i )
1984-
) ),
1985-
assignable.getColumnReferences().get( i )
1986-
)
1980+
CaseSearchedExpression caseSearchedExpression = CaseSearchedExpression.ofType(
1981+
expressions.get( i ).getExpressionType(),
1982+
this::getSql
19871983
);
1984+
caseSearchedExpression.when( predicate, expressions.get( i ) );
1985+
caseSearchedExpression.otherwise( assignable.getColumnReferences().get( i ) );
1986+
tupleExpressions.add(caseSearchedExpression);
19881987
}
19891988
expression = new SqlTuple(
19901989
tupleExpressions,

hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CaseSearchedExpression.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
*/
55
package org.hibernate.sql.ast.tree.expression;
66

7-
import java.io.Serializable;
8-
import java.util.ArrayList;
9-
import java.util.List;
10-
117
import org.hibernate.metamodel.mapping.BasicValuedMapping;
8+
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
129
import org.hibernate.metamodel.mapping.MappingModelExpressible;
10+
import org.hibernate.query.SemanticException;
1311
import org.hibernate.query.sqm.sql.internal.DomainResultProducer;
1412
import org.hibernate.sql.ast.SqlAstWalker;
1513
import org.hibernate.sql.ast.spi.SqlExpressionResolver;
@@ -19,6 +17,11 @@
1917
import org.hibernate.sql.results.graph.DomainResultCreationState;
2018
import org.hibernate.sql.results.graph.basic.BasicResult;
2119

20+
import java.io.Serializable;
21+
import java.util.ArrayList;
22+
import java.util.List;
23+
import java.util.function.Supplier;
24+
2225
/**
2326
* @author Steve Ebersole
2427
*/
@@ -28,16 +31,23 @@ public class CaseSearchedExpression implements Expression, DomainResultProducer
2831
private List<WhenFragment> whenFragments = new ArrayList<>();
2932
private Expression otherwise;
3033

31-
public CaseSearchedExpression(MappingModelExpressible type) {
32-
this.type = (BasicValuedMapping) type;
34+
public CaseSearchedExpression(BasicValuedMapping type) {
35+
this.type = type;
3336
}
3437

35-
public CaseSearchedExpression(MappingModelExpressible type, List<WhenFragment> whenFragments, Expression otherwise) {
36-
this.type = (BasicValuedMapping) type;
38+
public CaseSearchedExpression(BasicValuedMapping type, List<WhenFragment> whenFragments, Expression otherwise) {
39+
this.type = type;
3740
this.whenFragments = whenFragments;
3841
this.otherwise = otherwise;
3942
}
4043

44+
public static CaseSearchedExpression ofType(JdbcMappingContainer type, Supplier<String> contextSupplier) {
45+
if (type instanceof BasicValuedMapping basicValuedMapping) {
46+
return new CaseSearchedExpression( basicValuedMapping );
47+
}
48+
throw new SemanticException( "CASE only supports returning basic values, but not " + type, contextSupplier.get() );
49+
}
50+
4151
public List<WhenFragment> getWhenFragments() {
4252
return whenFragments;
4353
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.orm.test.query.hhh12225;
6+
7+
import jakarta.persistence.Entity;
8+
import jakarta.persistence.FetchType;
9+
import jakarta.persistence.GeneratedValue;
10+
import jakarta.persistence.GenerationType;
11+
import jakarta.persistence.Id;
12+
import jakarta.persistence.ManyToOne;
13+
import org.hibernate.query.SemanticException;
14+
import org.hibernate.testing.orm.junit.EntityManagerFactoryScope;
15+
import org.hibernate.testing.orm.junit.JiraKey;
16+
import org.hibernate.testing.orm.junit.Jpa;
17+
import org.junit.jupiter.api.Assertions;
18+
import org.junit.jupiter.api.BeforeAll;
19+
import org.junit.jupiter.api.Test;
20+
21+
@Jpa(annotatedClasses = {
22+
CaseToOneAssociationTest.Subject.class,
23+
CaseToOneAssociationTest.Link.class
24+
})
25+
@JiraKey("HHH-16018")
26+
public class CaseToOneAssociationTest {
27+
28+
private Subject persistedFrom;
29+
30+
@BeforeAll
31+
public void setUp(EntityManagerFactoryScope scope) {
32+
persistedFrom = scope.fromTransaction( em -> {
33+
final Subject from = new Subject();
34+
final Subject to = new Subject();
35+
em.persist( from );
36+
em.persist( to );
37+
em.persist( new Link( from, to ) );
38+
return from;
39+
} );
40+
}
41+
42+
@Test
43+
public void testUnsupportedCaseForEntity(EntityManagerFactoryScope scope) {
44+
// Won't catch the AssertionError from assertFound, because it is an Error, not an Exception.
45+
var iae = Assertions.assertThrows( IllegalArgumentException.class, () -> {
46+
assertFound(
47+
scope.fromTransaction( em -> em
48+
.createQuery(
49+
"select case when l.from = :from then l.to else l.from end from Link l where from = :from",
50+
Subject.class
51+
)
52+
.setParameter( "from", persistedFrom )
53+
.getSingleResult() )
54+
);
55+
} );
56+
Assertions.assertInstanceOf( SemanticException.class, iae.getCause(), "IllegalArgumentException wraps a SemanticException" );
57+
Assertions.assertTrue(
58+
iae.getCause().getMessage().contains( "CASE only supports returning basic values" ),
59+
"SE#message talks about CASE"
60+
);
61+
}
62+
63+
private void assertFound(Subject found) {
64+
Assertions.assertNotEquals( persistedFrom.id, found.id, "Found itself" );
65+
}
66+
67+
@Test
68+
public void testSupportedCaseForJoin(EntityManagerFactoryScope scope) {
69+
assertFound(
70+
scope.fromTransaction( em -> em
71+
.createQuery(
72+
"""
73+
select s
74+
from Link l
75+
join Subject s ON s.id = case
76+
when l.from = :from
77+
then l.to.id
78+
else l.from.id
79+
end
80+
where from = :from
81+
""",
82+
Subject.class
83+
)
84+
.setParameter( "from", persistedFrom )
85+
.getSingleResult() )
86+
);
87+
}
88+
89+
@Entity(name = "Subject")
90+
public static class Subject {
91+
@Id
92+
@GeneratedValue(strategy = GenerationType.IDENTITY)
93+
public Long id;
94+
}
95+
96+
@Entity(name = "Link")
97+
public static class Link {
98+
@Id
99+
@GeneratedValue(strategy = GenerationType.IDENTITY)
100+
public Long id;
101+
@ManyToOne(fetch = FetchType.LAZY)
102+
public Subject from;
103+
@ManyToOne(fetch = FetchType.LAZY)
104+
public Subject to;
105+
106+
public Link() {
107+
}
108+
109+
public Link(Subject from, Subject to) {
110+
this.from = from;
111+
this.to = to;
112+
}
113+
}
114+
}

0 commit comments

Comments
 (0)