diff --git a/hibernate-core/src/main/java/org/hibernate/query/criteria/JpaQueryStructure.java b/hibernate-core/src/main/java/org/hibernate/query/criteria/JpaQueryStructure.java index 9d2c4b2b4bae..9820da5dbdfd 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/criteria/JpaQueryStructure.java +++ b/hibernate-core/src/main/java/org/hibernate/query/criteria/JpaQueryStructure.java @@ -60,6 +60,8 @@ public interface JpaQueryStructure extends JpaQueryPart { JpaQueryStructure setRestriction(Predicate... restrictions); + JpaQueryStructure setRestriction(List restrictions); + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Grouping (group-by / having) clause @@ -78,6 +80,8 @@ public interface JpaQueryStructure extends JpaQueryPart { JpaQueryStructure setGroupRestriction(Predicate... restrictions); + JpaQueryStructure setGroupRestriction(List restrictions); + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Covariant overrides diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/NodeBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/NodeBuilder.java index a1b1d9cd8121..b2a73a1b6a94 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/NodeBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/NodeBuilder.java @@ -972,9 +972,11 @@ SqmJsonValueExpression jsonValue( @Override SqmPredicate wrap(Expression expression); - @Override + @Override @SuppressWarnings("unchecked") SqmPredicate wrap(Expression... expressions); + SqmPredicate wrap(List> restrictions); + @Override SqmExpression fk(Path path); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java index 56b0d85b7e19..4c898dcc802a 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/SqmCriteriaNodeBuilder.java @@ -598,6 +598,19 @@ public final SqmPredicate wrap(Expression... expressions) { return new SqmJunctionPredicate( Predicate.BooleanOperator.AND, predicates, this ); } + @Override + public SqmPredicate wrap(List> restrictions) { + if ( restrictions.size() == 1 ) { + return wrap( restrictions.get( 0 ) ); + } + + final List predicates = new ArrayList<>( restrictions.size() ); + for ( Expression expression : restrictions ) { + predicates.add( wrap( expression ) ); + } + return new SqmJunctionPredicate( Predicate.BooleanOperator.AND, predicates, this ); + } + @Override @SuppressWarnings("unchecked") public T unwrap(Class clazz) { return (T) extensions.get( clazz ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/predicate/SqmWhereClause.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/predicate/SqmWhereClause.java index 9b2df0fe3647..92cd9abe03ed 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/predicate/SqmWhereClause.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/predicate/SqmWhereClause.java @@ -46,12 +46,10 @@ public void setPredicate(SqmPredicate predicate) { @Override public void applyPredicate(SqmPredicate predicate) { - if ( this.predicate == null ) { - this.predicate = predicate; - } - else { - this.predicate = nodeBuilder.and( this.predicate, predicate ); - } + this.predicate = + this.predicate == null + ? predicate + : nodeBuilder.and( this.predicate, predicate ); } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/AbstractSqmSelectQuery.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/AbstractSqmSelectQuery.java index 1bb92691d511..e3d6165dd25c 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/AbstractSqmSelectQuery.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/AbstractSqmSelectQuery.java @@ -377,6 +377,12 @@ public SqmSelectQuery where(Predicate... restrictions) { return this; } + @Override + public SqmSelectQuery where(List restrictions) { + getQuerySpec().setRestriction( restrictions ); + return this; + } + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Grouping @@ -388,7 +394,7 @@ public List> getGroupList() { @Override public SqmSelectQuery groupBy(Expression... expressions) { - getQuerySpec().setGroupingExpressions( List.of( expressions ) ); + getQuerySpec().setGroupingExpressions( expressions ); return this; } @@ -405,13 +411,19 @@ public SqmPredicate getGroupRestriction() { @Override public SqmSelectQuery having(Expression booleanExpression) { - getQuerySpec().setGroupRestriction( nodeBuilder().wrap( booleanExpression ) ); + getQuerySpec().setGroupRestriction( booleanExpression ); return this; } @Override public SqmSelectQuery having(Predicate... predicates) { - getQuerySpec().setGroupRestriction( nodeBuilder().wrap( predicates ) ); + getQuerySpec().setGroupRestriction( predicates ); + return this; + } + + @Override + public AbstractQuery having(List restrictions) { + getQuerySpec().setGroupRestriction( restrictions ); return this; } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmQuerySpec.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmQuerySpec.java index beb49acdd297..a6f2fbf6dd00 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmQuerySpec.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmQuerySpec.java @@ -374,13 +374,24 @@ else if ( restrictions.length == 0 ) { setWhereClause( null ); } else { - SqmWhereClause whereClause = getWhereClause(); - if ( whereClause == null ) { - setWhereClause( whereClause = new SqmWhereClause( nodeBuilder() ) ); - } - else { - whereClause.setPredicate( null ); + final SqmWhereClause whereClause = resetWhereClause(); + for ( Predicate restriction : restrictions ) { + whereClause.applyPredicate( (SqmPredicate) restriction ); } + } + return this; + } + + @Override + public SqmQuerySpec setRestriction(List restrictions) { + if ( restrictions == null ) { + throw new IllegalArgumentException( "The predicate list cannot be null" ); + } + else if ( restrictions.isEmpty() ) { + setWhereClause( null ); + } + else { + final SqmWhereClause whereClause = resetWhereClause(); for ( Predicate restriction : restrictions ) { whereClause.applyPredicate( (SqmPredicate) restriction ); } @@ -388,6 +399,19 @@ else if ( restrictions.length == 0 ) { return this; } + private SqmWhereClause resetWhereClause() { + final SqmWhereClause whereClause = getWhereClause(); + if ( whereClause == null ) { + final SqmWhereClause newWhereClause = new SqmWhereClause( nodeBuilder() ); + setWhereClause( newWhereClause ); + return newWhereClause; + } + else { + whereClause.setPredicate( null ); + return whereClause; + } + } + @Override public List> getGroupingExpressions() { return groupByClauseExpressions; @@ -442,6 +466,12 @@ public SqmQuerySpec setGroupRestriction(Predicate... restrictions) { return this; } + @Override + public SqmQuerySpec setGroupRestriction(List restrictions) { + havingClausePredicate = nodeBuilder().wrap( restrictions ); + return this; + } + @Override public SqmQuerySpec setSortSpecifications(List sortSpecifications) { super.setSortSpecifications( sortSpecifications ); diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSelectStatement.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSelectStatement.java index 41532f27bfe1..9c18c21b62b8 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSelectStatement.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSelectStatement.java @@ -26,7 +26,6 @@ import org.hibernate.query.sqm.tree.cte.SqmCteStatement; import org.hibernate.query.sqm.tree.expression.SqmParameter; import org.hibernate.query.sqm.tree.from.SqmFromClause; -import org.hibernate.query.sqm.tree.predicate.SqmPredicate; import org.hibernate.query.sqm.tree.from.SqmRoot; import jakarta.persistence.Tuple; @@ -41,7 +40,6 @@ import static java.util.Collections.emptySet; import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableSet; -import static org.hibernate.query.sqm.spi.SqmCreationHelper.combinePredicates; import static org.hibernate.query.sqm.SqmQuerySource.CRITERIA; import static org.hibernate.query.sqm.tree.SqmCopyContext.noParamCopyContext; import static org.hibernate.query.sqm.tree.jpa.ParameterCollector.collectParameters; @@ -291,7 +289,8 @@ protected JpaCteCriteria withInternal( @Override public SqmSelectStatement distinct(boolean distinct) { - return (SqmSelectStatement) super.distinct( distinct ); + super.distinct( distinct ); + return this; } @Override @@ -308,21 +307,6 @@ public SqmSubQuery subquery(EntityType type) { return new SqmSubQuery<>( this, type, nodeBuilder() ); } - @Override - public SqmSelectStatement where(List restrictions) { - //noinspection rawtypes,unchecked - getQuerySpec().getWhereClause().applyPredicates( (List) restrictions ); - return this; - } - - @Override - public SqmSelectStatement having(List restrictions) { - final SqmPredicate combined = - combinePredicates( getQuerySpec().getHavingClausePredicate(), restrictions ); - getQuerySpec().setHavingClausePredicate( combined ); - return this; - } - @Override @SuppressWarnings("unchecked") public SqmSelectStatement select(Selection selection) { @@ -413,32 +397,50 @@ public SqmSubQuery subquery(Class type) { @Override public SqmSelectStatement where(Expression restriction) { - return (SqmSelectStatement) super.where( restriction ); + super.where( restriction ); + return this; } @Override public SqmSelectStatement where(Predicate... restrictions) { - return (SqmSelectStatement) super.where( restrictions ); + super.where( restrictions ); + return this; + } + + @Override + public SqmSelectStatement where(List restrictions) { + super.where( restrictions ); + return this; } @Override public SqmSelectStatement groupBy(Expression... expressions) { - return (SqmSelectStatement) super.groupBy( expressions ); + super.groupBy( expressions ); + return this; } @Override public SqmSelectStatement groupBy(List> grouping) { - return (SqmSelectStatement) super.groupBy( grouping ); + super.groupBy( grouping ); + return this; } @Override public SqmSelectStatement having(Expression booleanExpression) { - return (SqmSelectStatement) super.having( booleanExpression ); + super.having( booleanExpression ); + return this; } @Override public SqmSelectStatement having(Predicate... predicates) { - return (SqmSelectStatement) super.having( predicates ); + super.having( predicates ); + return this; + } + + @Override + public SqmSelectStatement having(List restrictions) { + super.having( restrictions ); + return this; } @Override diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java index d3acbd952ba1..c47f330a5870 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/select/SqmSubQuery.java @@ -77,7 +77,6 @@ import jakarta.persistence.metamodel.EntityType; import static java.util.Collections.emptySet; -import static org.hibernate.query.sqm.spi.SqmCreationHelper.combinePredicates; /** * @author Steve Ebersole @@ -345,37 +344,56 @@ public List> getCompoundSelectionItems() { @Override public SqmSubQuery distinct(boolean distinct) { - return (SqmSubQuery) super.distinct( distinct ); + super.distinct( distinct ); + return this; } @Override public SqmSubQuery where(Expression restriction) { - return (SqmSubQuery) super.where( restriction ); + super.where( restriction ); + return this; } @Override public SqmSubQuery where(Predicate... restrictions) { - return (SqmSubQuery) super.where( restrictions ); + super.where( restrictions ); + return this; + } + + @Override + public SqmSubQuery where(List restrictions) { + super.where( restrictions ); + return this; } @Override public SqmSubQuery groupBy(Expression... expressions) { - return (SqmSubQuery) super.groupBy( expressions ); + super.groupBy( expressions ); + return this; } @Override public SqmSubQuery groupBy(List> grouping) { - return (SqmSubQuery) super.groupBy( grouping ); + super.groupBy( grouping ); + return this; } @Override public SqmSubQuery having(Expression booleanExpression) { - return (SqmSubQuery) super.having( booleanExpression ); + super.having( booleanExpression ); + return this; } @Override public SqmSubQuery having(Predicate... predicates) { - return (SqmSubQuery) super.having( predicates ); + super.having( predicates ); + return this; + } + + @Override + public SqmSubQuery having(List restrictions) { + super.having( restrictions ); + return this; } @Override @@ -721,21 +739,6 @@ public Subquery subquery(EntityType type) { return new SqmSubQuery<>( this, type, nodeBuilder() ); } - @Override - public Subquery where(List restrictions) { - //noinspection rawtypes,unchecked - getQuerySpec().getWhereClause().applyPredicates( (List) restrictions ); - return this; - } - - @Override - public Subquery having(List restrictions) { - final SqmPredicate combined = - combinePredicates( getQuerySpec().getHavingClausePredicate(), restrictions ); - getQuerySpec().setHavingClausePredicate( combined ); - return this; - } - @Override public Set> getParameters() { return emptySet(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/jpa/criteria/CriteriaRestrictionTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/jpa/criteria/CriteriaRestrictionTest.java new file mode 100644 index 000000000000..2e18854ab1c0 --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/jpa/criteria/CriteriaRestrictionTest.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.orm.test.jpa.criteria; + +import jakarta.persistence.Entity; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.Id; +import org.hibernate.testing.orm.junit.EntityManagerFactoryScope; +import org.hibernate.testing.orm.junit.JiraKey; +import org.hibernate.testing.orm.junit.Jpa; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Jpa(annotatedClasses = CriteriaRestrictionTest.Doc.class) +class CriteriaRestrictionTest { + @JiraKey( "HHH-19572" ) + @Test void test(EntityManagerFactoryScope scope) { + scope.inTransaction( + entityManager -> { + Doc doc1 = new Doc(); + doc1.title = "Hibernate ORM"; + doc1.author = "Gavin King"; + doc1.text = "Hibernate ORM is a Java Persistence API implementation"; + entityManager.persist( doc1 ); + Doc doc2 = new Doc(); + doc2.title = "Hibernate ORM"; + doc2.author = "Hibernate Team"; + doc2.text = "Hibernate ORM is a Jakarta Persistence implementation"; + entityManager.persist( doc2 ); + } + ); + scope.inTransaction( + entityManager -> { + var builder = entityManager.getCriteriaBuilder(); + var query = builder.createQuery( Doc.class ); + var d = query.from( Doc.class ); + // test with list + query.where( List.of( + builder.like( d.get( "title" ), "Hibernate%" ), + builder.equal( d.get( "author" ), "Gavin King" ) + ) ); + var resultList = entityManager.createQuery( query ).getResultList(); + assertEquals( 1, resultList.size() ); + assertEquals( "Hibernate ORM is a Java Persistence API implementation", + resultList.get( 0 ).text ); + } + ); + scope.inTransaction( + entityManager -> { + var builder = entityManager.getCriteriaBuilder(); + var query = builder.createQuery( Doc.class ); + var d = query.from( Doc.class ); + // test with varargs + query.where( + builder.like( d.get( "title" ), "Hibernate%" ), + builder.equal( d.get( "author" ), "Hibernate Team" ) + ); + var resultList = entityManager.createQuery( query ).getResultList(); + assertEquals( 1, resultList.size() ); + assertEquals( "Hibernate ORM is a Jakarta Persistence implementation", + resultList.get( 0 ).text ); + } + ); + } + @Entity static class Doc { + @Id + @GeneratedValue + UUID uuid; + private String title; + private String author; + private String text; + } +}