Skip to content

HHH-16018 Semantic exception for CASE returning an entity. #8865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5248,19 +5248,12 @@ private Expression createCaseExpression(SqmPath<?> lhs, EntityDomainType<?> trea
if ( treatTypeRestriction == null ) {
return expression;
}
final BasicValuedMapping mappingModelExpressible = (BasicValuedMapping) expression.getExpressionType();
final List<CaseSearchedExpression.WhenFragment> whenFragments = new ArrayList<>( 1 );
whenFragments.add(
new CaseSearchedExpression.WhenFragment(
treatTypeRestriction,
expression
)
);
return new CaseSearchedExpression(
mappingModelExpressible,
whenFragments,
null
var caseSearchedExpression = CaseSearchedExpression.ofType(
expression.getExpressionType(),
lhs::toHqlString
);
caseSearchedExpression.when( treatTypeRestriction, expression );
return caseSearchedExpression;
}

private Predicate consumeConjunctTreatTypeRestrictions() {
Expand Down Expand Up @@ -7168,34 +7161,37 @@ public CaseSearchedExpression visitSearchedCaseExpression(SqmCaseSearched<?> exp
final boolean oldInNestedContext = inNestedContext;

inNestedContext = true;
MappingModelExpressible<?> resolved = determineCurrentExpressible( expression );
JdbcMappingContainer resolved = determineCurrentExpressible( expression );

Expression otherwise = null;
for ( SqmCaseSearched.WhenFragment<?> whenFragment : expression.getWhenFragments() ) {
final Predicate whenPredicate = visitNestedTopLevelPredicate( whenFragment.getPredicate() );
final MappingModelExpressible<?> alreadyKnown = resolved;
final JdbcMappingContainer alreadyKnown = resolved;
inferrableTypeAccessStack.push(
() -> alreadyKnown == null && inferenceSupplier != null ? inferenceSupplier.get() : alreadyKnown
);
final Expression resultExpression = (Expression) whenFragment.getResult().accept( this );
inferrableTypeAccessStack.pop();
resolved = (MappingModelExpressible<?>) highestPrecedence( resolved, resultExpression.getExpressionType() );
resolved = highestPrecedence( resolved, resultExpression.getExpressionType() );

whenFragments.add( new CaseSearchedExpression.WhenFragment( whenPredicate, resultExpression ) );
}

if ( expression.getOtherwise() != null ) {
final MappingModelExpressible<?> alreadyKnown = resolved;
final JdbcMappingContainer alreadyKnown = resolved;
inferrableTypeAccessStack.push(
() -> alreadyKnown == null && inferenceSupplier != null ? inferenceSupplier.get() : alreadyKnown
);
otherwise = (Expression) expression.getOtherwise().accept( this );
inferrableTypeAccessStack.pop();
resolved = (MappingModelExpressible<?>) highestPrecedence( resolved, otherwise.getExpressionType() );
resolved = highestPrecedence( resolved, otherwise.getExpressionType() );
}

var caseSearchedExpression = CaseSearchedExpression.ofType( resolved, expression::toHqlString );
caseSearchedExpression.getWhenFragments().addAll( whenFragments );
caseSearchedExpression.otherwise( otherwise );
inNestedContext = oldInNestedContext;
return new CaseSearchedExpression( resolved, whenFragments, otherwise );
return caseSearchedExpression;
}

private MappingModelExpressible<?> determineCurrentExpressible(SqmTypedNode<?> expression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1963,28 +1963,27 @@ private void renderPredicatedSetAssignments(List<Assignment> assignments, Predic
final Expression assignedValue = assignment.getAssignedValue();
final Expression expression;
if ( assignable.getColumnReferences().size() == 1 ) {
expression = new CaseSearchedExpression(
(MappingModelExpressible) assignedValue.getExpressionType(),
List.of( new CaseSearchedExpression.WhenFragment( predicate, assignedValue ) ),
assignable.getColumnReferences().get( 0 )
CaseSearchedExpression caseSearchedExpression = CaseSearchedExpression.ofType(
assignedValue.getExpressionType(),
this::getSql
);
caseSearchedExpression.when( predicate, assignedValue );
caseSearchedExpression.otherwise( assignable.getColumnReferences().get( 0 ) );
expression = caseSearchedExpression;
}
else {
assert assignedValue instanceof SqlTupleContainer;
final List<? extends Expression> expressions =
( (SqlTupleContainer) assignedValue ).getSqlTuple().getExpressions();
final List<Expression> tupleExpressions = new ArrayList<>( expressions.size() );
for ( int i = 0; i < expressions.size(); i++ ) {
tupleExpressions.add(
new CaseSearchedExpression(
(MappingModelExpressible<?>) expressions.get( i ).getExpressionType(),
List.of( new CaseSearchedExpression.WhenFragment(
predicate,
expressions.get( i )
) ),
assignable.getColumnReferences().get( i )
)
CaseSearchedExpression caseSearchedExpression = CaseSearchedExpression.ofType(
expressions.get( i ).getExpressionType(),
this::getSql
);
caseSearchedExpression.when( predicate, expressions.get( i ) );
caseSearchedExpression.otherwise( assignable.getColumnReferences().get( i ) );
tupleExpressions.add(caseSearchedExpression);
}
expression = new SqlTuple(
tupleExpressions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
*/
package org.hibernate.sql.ast.tree.expression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.query.SemanticException;
import org.hibernate.query.sqm.sql.internal.DomainResultProducer;
import org.hibernate.sql.ast.SqlAstWalker;
import org.hibernate.sql.ast.spi.SqlExpressionResolver;
Expand All @@ -19,6 +17,11 @@
import org.hibernate.sql.results.graph.DomainResultCreationState;
import org.hibernate.sql.results.graph.basic.BasicResult;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

/**
* @author Steve Ebersole
*/
Expand All @@ -28,16 +31,23 @@ public class CaseSearchedExpression implements Expression, DomainResultProducer
private List<WhenFragment> whenFragments = new ArrayList<>();
private Expression otherwise;

public CaseSearchedExpression(MappingModelExpressible type) {
this.type = (BasicValuedMapping) type;
public CaseSearchedExpression(BasicValuedMapping type) {
this.type = type;
}

public CaseSearchedExpression(MappingModelExpressible type, List<WhenFragment> whenFragments, Expression otherwise) {
this.type = (BasicValuedMapping) type;
public CaseSearchedExpression(BasicValuedMapping type, List<WhenFragment> whenFragments, Expression otherwise) {
this.type = type;
this.whenFragments = whenFragments;
this.otherwise = otherwise;
}

public static CaseSearchedExpression ofType(JdbcMappingContainer type, Supplier<String> contextSupplier) {
if (type instanceof BasicValuedMapping basicValuedMapping) {
return new CaseSearchedExpression( basicValuedMapping );
}
throw new SemanticException( "CASE only supports returning basic values, but not " + type, contextSupplier.get() );
}

public List<WhenFragment> getWhenFragments() {
return whenFragments;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.orm.test.query.hhh12225;

import jakarta.persistence.Entity;
import jakarta.persistence.FetchType;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.ManyToOne;
import org.hibernate.query.SemanticException;
import org.hibernate.testing.orm.junit.EntityManagerFactoryScope;
import org.hibernate.testing.orm.junit.JiraKey;
import org.hibernate.testing.orm.junit.Jpa;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

@Jpa(annotatedClasses = {
CaseToOneAssociationTest.Subject.class,
CaseToOneAssociationTest.Link.class
})
@JiraKey("HHH-16018")
public class CaseToOneAssociationTest {

private Subject persistedFrom;

@BeforeAll
public void setUp(EntityManagerFactoryScope scope) {
persistedFrom = scope.fromTransaction( em -> {
final Subject from = new Subject();
final Subject to = new Subject();
em.persist( from );
em.persist( to );
em.persist( new Link( from, to ) );
return from;
} );
}

@Test
public void testUnsupportedCaseForEntity(EntityManagerFactoryScope scope) {
// Won't catch the AssertionError from assertFound, because it is an Error, not an Exception.
var iae = Assertions.assertThrows( IllegalArgumentException.class, () -> {
assertFound(
scope.fromTransaction( em -> em
.createQuery(
"select case when l.from = :from then l.to else l.from end from Link l where from = :from",
Subject.class
)
.setParameter( "from", persistedFrom )
.getSingleResult() )
);
} );
Assertions.assertInstanceOf( SemanticException.class, iae.getCause(), "IllegalArgumentException wraps a SemanticException" );
Assertions.assertTrue(
iae.getCause().getMessage().contains( "CASE only supports returning basic values" ),
"SE#message talks about CASE"
);
}

private void assertFound(Subject found) {
Assertions.assertNotEquals( persistedFrom.id, found.id, "Found itself" );
}

@Test
public void testSupportedCaseForJoin(EntityManagerFactoryScope scope) {
assertFound(
scope.fromTransaction( em -> em
.createQuery(
"""
select s
from Link l
join Subject s ON s.id = case
when l.from = :from
then l.to.id
else l.from.id
end
where from = :from
""",
Subject.class
)
.setParameter( "from", persistedFrom )
.getSingleResult() )
);
}

@Entity(name = "Subject")
public static class Subject {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
public Long id;
}

@Entity(name = "Link")
public static class Link {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
public Long id;
@ManyToOne(fetch = FetchType.LAZY)
public Subject from;
@ManyToOne(fetch = FetchType.LAZY)
public Subject to;

public Link() {
}

public Link(Subject from, Subject to) {
this.from = from;
this.to = to;
}
}
}