diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java index 7213d5d2fe44..9a0c6aef1a6f 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/sql/BaseSqmToSqlAstConverter.java @@ -5248,19 +5248,12 @@ private Expression createCaseExpression(SqmPath lhs, EntityDomainType trea if ( treatTypeRestriction == null ) { return expression; } - final BasicValuedMapping mappingModelExpressible = (BasicValuedMapping) expression.getExpressionType(); - final List whenFragments = new ArrayList<>( 1 ); - whenFragments.add( - new CaseSearchedExpression.WhenFragment( - treatTypeRestriction, - expression - ) - ); - return new CaseSearchedExpression( - mappingModelExpressible, - whenFragments, - null + var caseSearchedExpression = CaseSearchedExpression.ofType( + expression.getExpressionType(), + lhs::toHqlString ); + caseSearchedExpression.when( treatTypeRestriction, expression ); + return caseSearchedExpression; } private Predicate consumeConjunctTreatTypeRestrictions() { @@ -7168,34 +7161,37 @@ public CaseSearchedExpression visitSearchedCaseExpression(SqmCaseSearched exp final boolean oldInNestedContext = inNestedContext; inNestedContext = true; - MappingModelExpressible resolved = determineCurrentExpressible( expression ); + JdbcMappingContainer resolved = determineCurrentExpressible( expression ); Expression otherwise = null; for ( SqmCaseSearched.WhenFragment whenFragment : expression.getWhenFragments() ) { final Predicate whenPredicate = visitNestedTopLevelPredicate( whenFragment.getPredicate() ); - final MappingModelExpressible alreadyKnown = resolved; + final JdbcMappingContainer alreadyKnown = resolved; inferrableTypeAccessStack.push( () -> alreadyKnown == null && inferenceSupplier != null ? inferenceSupplier.get() : alreadyKnown ); final Expression resultExpression = (Expression) whenFragment.getResult().accept( this ); inferrableTypeAccessStack.pop(); - resolved = (MappingModelExpressible) highestPrecedence( resolved, resultExpression.getExpressionType() ); + resolved = highestPrecedence( resolved, resultExpression.getExpressionType() ); whenFragments.add( new CaseSearchedExpression.WhenFragment( whenPredicate, resultExpression ) ); } if ( expression.getOtherwise() != null ) { - final MappingModelExpressible alreadyKnown = resolved; + final JdbcMappingContainer alreadyKnown = resolved; inferrableTypeAccessStack.push( () -> alreadyKnown == null && inferenceSupplier != null ? inferenceSupplier.get() : alreadyKnown ); otherwise = (Expression) expression.getOtherwise().accept( this ); inferrableTypeAccessStack.pop(); - resolved = (MappingModelExpressible) highestPrecedence( resolved, otherwise.getExpressionType() ); + resolved = highestPrecedence( resolved, otherwise.getExpressionType() ); } + var caseSearchedExpression = CaseSearchedExpression.ofType( resolved, expression::toHqlString ); + caseSearchedExpression.getWhenFragments().addAll( whenFragments ); + caseSearchedExpression.otherwise( otherwise ); inNestedContext = oldInNestedContext; - return new CaseSearchedExpression( resolved, whenFragments, otherwise ); + return caseSearchedExpression; } private MappingModelExpressible determineCurrentExpressible(SqmTypedNode expression) { diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java index 46b3a75d00fd..4eafdd861d26 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/spi/AbstractSqlAstTranslator.java @@ -1963,11 +1963,13 @@ private void renderPredicatedSetAssignments(List assignments, Predic final Expression assignedValue = assignment.getAssignedValue(); final Expression expression; if ( assignable.getColumnReferences().size() == 1 ) { - expression = new CaseSearchedExpression( - (MappingModelExpressible) assignedValue.getExpressionType(), - List.of( new CaseSearchedExpression.WhenFragment( predicate, assignedValue ) ), - assignable.getColumnReferences().get( 0 ) + CaseSearchedExpression caseSearchedExpression = CaseSearchedExpression.ofType( + assignedValue.getExpressionType(), + this::getSql ); + caseSearchedExpression.when( predicate, assignedValue ); + caseSearchedExpression.otherwise( assignable.getColumnReferences().get( 0 ) ); + expression = caseSearchedExpression; } else { assert assignedValue instanceof SqlTupleContainer; @@ -1975,16 +1977,13 @@ private void renderPredicatedSetAssignments(List assignments, Predic ( (SqlTupleContainer) assignedValue ).getSqlTuple().getExpressions(); final List tupleExpressions = new ArrayList<>( expressions.size() ); for ( int i = 0; i < expressions.size(); i++ ) { - tupleExpressions.add( - new CaseSearchedExpression( - (MappingModelExpressible) expressions.get( i ).getExpressionType(), - List.of( new CaseSearchedExpression.WhenFragment( - predicate, - expressions.get( i ) - ) ), - assignable.getColumnReferences().get( i ) - ) + CaseSearchedExpression caseSearchedExpression = CaseSearchedExpression.ofType( + expressions.get( i ).getExpressionType(), + this::getSql ); + caseSearchedExpression.when( predicate, expressions.get( i ) ); + caseSearchedExpression.otherwise( assignable.getColumnReferences().get( i ) ); + tupleExpressions.add(caseSearchedExpression); } expression = new SqlTuple( tupleExpressions, diff --git a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CaseSearchedExpression.java b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CaseSearchedExpression.java index 96be3c130fa0..28c192e57d98 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CaseSearchedExpression.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/ast/tree/expression/CaseSearchedExpression.java @@ -4,12 +4,10 @@ */ package org.hibernate.sql.ast.tree.expression; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - import org.hibernate.metamodel.mapping.BasicValuedMapping; +import org.hibernate.metamodel.mapping.JdbcMappingContainer; import org.hibernate.metamodel.mapping.MappingModelExpressible; +import org.hibernate.query.SemanticException; import org.hibernate.query.sqm.sql.internal.DomainResultProducer; import org.hibernate.sql.ast.SqlAstWalker; import org.hibernate.sql.ast.spi.SqlExpressionResolver; @@ -19,6 +17,11 @@ import org.hibernate.sql.results.graph.DomainResultCreationState; import org.hibernate.sql.results.graph.basic.BasicResult; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + /** * @author Steve Ebersole */ @@ -28,16 +31,23 @@ public class CaseSearchedExpression implements Expression, DomainResultProducer private List whenFragments = new ArrayList<>(); private Expression otherwise; - public CaseSearchedExpression(MappingModelExpressible type) { - this.type = (BasicValuedMapping) type; + public CaseSearchedExpression(BasicValuedMapping type) { + this.type = type; } - public CaseSearchedExpression(MappingModelExpressible type, List whenFragments, Expression otherwise) { - this.type = (BasicValuedMapping) type; + public CaseSearchedExpression(BasicValuedMapping type, List whenFragments, Expression otherwise) { + this.type = type; this.whenFragments = whenFragments; this.otherwise = otherwise; } + public static CaseSearchedExpression ofType(JdbcMappingContainer type, Supplier contextSupplier) { + if (type instanceof BasicValuedMapping basicValuedMapping) { + return new CaseSearchedExpression( basicValuedMapping ); + } + throw new SemanticException( "CASE only supports returning basic values, but not " + type, contextSupplier.get() ); + } + public List getWhenFragments() { return whenFragments; } diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/query/hhh12225/CaseToOneAssociationTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hhh12225/CaseToOneAssociationTest.java new file mode 100644 index 000000000000..5041c9752777 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/query/hhh12225/CaseToOneAssociationTest.java @@ -0,0 +1,114 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.orm.test.query.hhh12225; + +import jakarta.persistence.Entity; +import jakarta.persistence.FetchType; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; +import jakarta.persistence.Id; +import jakarta.persistence.ManyToOne; +import org.hibernate.query.SemanticException; +import org.hibernate.testing.orm.junit.EntityManagerFactoryScope; +import org.hibernate.testing.orm.junit.JiraKey; +import org.hibernate.testing.orm.junit.Jpa; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +@Jpa(annotatedClasses = { + CaseToOneAssociationTest.Subject.class, + CaseToOneAssociationTest.Link.class +}) +@JiraKey("HHH-16018") +public class CaseToOneAssociationTest { + + private Subject persistedFrom; + + @BeforeAll + public void setUp(EntityManagerFactoryScope scope) { + persistedFrom = scope.fromTransaction( em -> { + final Subject from = new Subject(); + final Subject to = new Subject(); + em.persist( from ); + em.persist( to ); + em.persist( new Link( from, to ) ); + return from; + } ); + } + + @Test + public void testUnsupportedCaseForEntity(EntityManagerFactoryScope scope) { + // Won't catch the AssertionError from assertFound, because it is an Error, not an Exception. + var iae = Assertions.assertThrows( IllegalArgumentException.class, () -> { + assertFound( + scope.fromTransaction( em -> em + .createQuery( + "select case when l.from = :from then l.to else l.from end from Link l where from = :from", + Subject.class + ) + .setParameter( "from", persistedFrom ) + .getSingleResult() ) + ); + } ); + Assertions.assertInstanceOf( SemanticException.class, iae.getCause(), "IllegalArgumentException wraps a SemanticException" ); + Assertions.assertTrue( + iae.getCause().getMessage().contains( "CASE only supports returning basic values" ), + "SE#message talks about CASE" + ); + } + + private void assertFound(Subject found) { + Assertions.assertNotEquals( persistedFrom.id, found.id, "Found itself" ); + } + + @Test + public void testSupportedCaseForJoin(EntityManagerFactoryScope scope) { + assertFound( + scope.fromTransaction( em -> em + .createQuery( + """ + select s + from Link l + join Subject s ON s.id = case + when l.from = :from + then l.to.id + else l.from.id + end + where from = :from + """, + Subject.class + ) + .setParameter( "from", persistedFrom ) + .getSingleResult() ) + ); + } + + @Entity(name = "Subject") + public static class Subject { + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + public Long id; + } + + @Entity(name = "Link") + public static class Link { + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + public Long id; + @ManyToOne(fetch = FetchType.LAZY) + public Subject from; + @ManyToOne(fetch = FetchType.LAZY) + public Subject to; + + public Link() { + } + + public Link(Subject from, Subject to) { + this.from = from; + this.to = to; + } + } +}