diff --git a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 index 9c413db73e21..17716d416a3c 100644 --- a/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 +++ b/hibernate-core/src/main/antlr/org/hibernate/grammars/hql/HqlParser.g4 @@ -151,8 +151,7 @@ cycleClause * A toplevel query of subquery, which may be a union or intersection of subqueries */ queryExpression - : withClause? orderedQuery # SimpleQueryGroup - | withClause? orderedQuery (setOperator orderedQuery)+ # SetQueryGroup + : withClause? orderedQuery (setOperator orderedQuery)* ; /** @@ -427,8 +426,6 @@ pathContinuation * * VALUE( path ) * * KEY( path ) * * path[ selector ] - * * ARRAY_GET( embeddableArrayPath, index ).path - * * COALESCE( array1, array2 )[ selector ].path */ syntacticDomainPath : treatedNavigablePath @@ -436,10 +433,6 @@ syntacticDomainPath | mapKeyNavigablePath | simplePath indexedPathAccessFragment | simplePath slicedPathAccessFragment - | toOneFkReference - | function pathContinuation - | function indexedPathAccessFragment pathContinuation? - | function slicedPathAccessFragment ; /** @@ -661,19 +654,21 @@ whereClause predicate //highest to lowest precedence : LEFT_PAREN predicate RIGHT_PAREN # GroupedPredicate - | expression IS NOT? NULL # IsNullPredicate - | expression IS NOT? EMPTY # IsEmptyPredicate - | expression IS NOT? TRUE # IsTruePredicate - | expression IS NOT? FALSE # IsFalsePredicate - | expression IS NOT? DISTINCT FROM expression # IsDistinctFromPredicate + | expression IS NOT? (NULL|EMPTY|TRUE|FALSE) # UnaryIsPredicate | expression NOT? MEMBER OF? path # MemberOfPredicate | expression NOT? IN inList # InPredicate | expression NOT? BETWEEN expression AND expression # BetweenPredicate | expression NOT? (LIKE | ILIKE) REGEXP? expression likeEscape? # LikePredicate - | expression NOT? CONTAINS expression # ContainsPredicate - | expression NOT? INCLUDES expression # IncludesPredicate - | expression NOT? INTERSECTS expression # IntersectsPredicate - | expression comparisonOperator expression # ComparisonPredicate + | expression + ( NOT? (CONTAINS | INCLUDES | INTERSECTS) + | IS NOT? DISTINCT FROM + | EQUAL + | NOT_EQUAL + | GREATER + | GREATER_EQUAL + | LESS + | LESS_EQUAL + ) expression # BinaryExpressionPredicate | EXISTS collectionQuantifier LEFT_PAREN simplePath RIGHT_PAREN # ExistsCollectionPartPredicate | EXISTS expression # ExistsPredicate | NOT predicate # NegatedPredicate @@ -682,18 +677,6 @@ predicate | expression # BooleanExpressionPredicate ; -/** - * An operator which compares values for equality or order - */ -comparisonOperator - : EQUAL - | NOT_EQUAL - | GREATER - | GREATER_EQUAL - | LESS - | LESS_EQUAL - ; - /** * Any right operand of the 'in' operator * @@ -748,7 +731,14 @@ primaryExpression | entityVersionReference # EntityVersionExpression | entityNaturalIdReference # EntityNaturalIdExpression | syntacticDomainPath pathContinuation? # SyntacticPathExpression - | function # FunctionExpression + // ARRAY_GET( embeddableArrayPath, index ).path + // COALESCE( array1, array2 )[ selector ].path + // COALESCE( array1, array2 )[ start : end ] + | function ( + pathContinuation + | slicedPathAccessFragment + | indexedPathAccessFragment pathContinuation? + )? # FunctionExpression | generalPathFragment # GeneralPathExpression ; @@ -1108,6 +1098,7 @@ function | columnFunction | jsonFunction | xmlFunction + | toOneFkReference | genericFunction ; diff --git a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java index b1344d38e757..43c784348e90 100644 --- a/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/metamodel/mapping/ordering/OrderByFragmentTranslator.java @@ -73,8 +73,13 @@ private static OrderingParser.OrderByFragmentContext buildParseTree(String fragm return parser.orderByFragment(); } catch (ParseCancellationException e) { + // When resetting the parser, its CommonTokenStream will seek(0) i.e. restart emitting buffered tokens. + // This is enough when reusing the lexer and parser, and it would be wrong to also reset the lexer. + // Resetting the lexer causes it to hand out tokens again from the start, which will then append to the + // CommonTokenStream and cause a wrong parse + // lexer.reset(); + // reset the input token stream and parser state - lexer.reset(); parser.reset(); // fall back to LL(k)-based parsing diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java index 88a0496ba548..302122174ce0 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/SemanticQueryBuilder.java @@ -771,12 +771,10 @@ public Object visitCte(HqlParser.CteContext ctx) { final JpaCteCriteria oldCte = currentPotentialRecursiveCte; try { currentPotentialRecursiveCte = null; - if ( queryExpressionContext instanceof HqlParser.SetQueryGroupContext setContext ) { - // A recursive query is only possible if the child count is lower than 5 e.g. `withClause? q1 op q2` - if ( setContext.getChildCount() < 5 ) { - if ( handleRecursive( ctx, setContext, cteContainer, name, cte, materialization ) ) { - return null; - } + // A recursive query is only possible if there are 2 ordered queries e.g. `q1 op q2` + if ( queryExpressionContext.orderedQuery().size() == 2 ) { + if ( handleRecursive( ctx, queryExpressionContext, cteContainer, name, cte, materialization ) ) { + return null; } } queryExpressionContext.accept( this ); @@ -794,7 +792,7 @@ public Object visitCte(HqlParser.CteContext ctx) { private boolean handleRecursive( HqlParser.CteContext cteContext, - HqlParser.SetQueryGroupContext setContext, + HqlParser.QueryExpressionContext setContext, SqmCteContainer cteContainer, String name, SqmSelectQuery cte, @@ -963,15 +961,6 @@ private static CteSearchClauseKind getCteSearchClauseKind(HqlParser.SearchClause return ctx.BREADTH() != null ? CteSearchClauseKind.BREADTH_FIRST : CteSearchClauseKind.DEPTH_FIRST; } - @Override - public SqmQueryPart visitSimpleQueryGroup(HqlParser.SimpleQueryGroupContext ctx) { - final var withClauseContext = ctx.withClause(); - if ( withClauseContext != null ) { - withClauseContext.accept( this ); - } - return (SqmQueryPart) ctx.orderedQuery().accept( this ); - } - @Override public SqmQueryPart visitQueryOrderExpression(HqlParser.QueryOrderExpressionContext ctx) { final SqmQuerySpec sqmQuerySpec = currentQuerySpec(); @@ -1008,25 +997,28 @@ public SqmQueryPart visitNestedQueryExpression(HqlParser.NestedQueryExpressio } @Override - public SqmQueryGroup visitSetQueryGroup(HqlParser.SetQueryGroupContext ctx) { + public SqmQueryPart visitQueryExpression(HqlParser.QueryExpressionContext ctx) { var withClauseContext = ctx.withClause(); if ( withClauseContext != null ) { withClauseContext.accept( this ); } + final var orderedQueryContexts = ctx.orderedQuery(); final SqmQueryPart firstQueryPart = - (SqmQueryPart) ctx.orderedQuery(0).accept( this ); + (SqmQueryPart) orderedQueryContexts.get( 0 ).accept( this ); + if ( orderedQueryContexts.size() == 1 ) { + return firstQueryPart; + } SqmQueryGroup queryGroup = firstQueryPart instanceof SqmQueryGroup sqmQueryGroup ? sqmQueryGroup : new SqmQueryGroup<>( firstQueryPart ); setCurrentQueryPart( queryGroup ); - final var orderedQueryContexts = ctx.orderedQuery(); final var setOperatorContexts = ctx.setOperator(); final SqmCreationProcessingState firstProcessingState = processingStateStack.pop(); for ( int i = 0; i < setOperatorContexts.size(); i++ ) { queryGroup = getSqmQueryGroup( visitSetOperator( setOperatorContexts.get(i) ), - orderedQueryContexts.get( i+1 ), + orderedQueryContexts.get( i + 1 ), queryGroup, setOperatorContexts.size(), firstProcessingState, @@ -1856,8 +1848,34 @@ public Object visitGeneralPathExpression(HqlParser.GeneralPathExpressionContext } @Override - public SqmExpression visitFunctionExpression(HqlParser.FunctionExpressionContext ctx) { - return (SqmExpression) ctx.function().accept( this ); + public Object visitFunctionExpression(HqlParser.FunctionExpressionContext ctx) { + final var slicedFragmentsCtx = ctx.slicedPathAccessFragment(); + if ( slicedFragmentsCtx != null ) { + final List slicedFragments = slicedFragmentsCtx.expression(); + return getFunctionDescriptor( "array_slice" ).generateSqmExpression( + List.of( + (SqmTypedNode) visitFunction( ctx.function() ), + (SqmTypedNode) slicedFragments.get( 0 ).accept( this ), + (SqmTypedNode) slicedFragments.get( 1 ).accept( this ) + ), + null, + queryEngine() + ); + } + else { + final var function = (SqmExpression) visitFunction( ctx.function() ); + final var indexedPathAccessFragment = ctx.indexedPathAccessFragment(); + final var pathContinuation = ctx.pathContinuation(); + if ( indexedPathAccessFragment == null && pathContinuation == null ) { + return function; + } + else { + return visitPathContinuation( + visitIndexedPathAccessFragment( (SemanticPathPart) function, indexedPathAccessFragment ), + pathContinuation + ); + } + } } @Override @@ -2415,83 +2433,131 @@ public SqmBetweenPredicate visitBetweenPredicate(HqlParser.BetweenPredicateConte ); } - @Override - public SqmNullnessPredicate visitIsNullPredicate(HqlParser.IsNullPredicateContext ctx) { - return new SqmNullnessPredicate( - (SqmExpression) ctx.expression().accept( this ), - ctx.NOT() != null, - nodeBuilder() - ); + public SqmPredicate visitUnaryIsPredicate(HqlParser.UnaryIsPredicateContext ctx) { + final var expression = (SqmExpression) ctx.expression().accept( this ); + final var negated = ctx.NOT() != null; + final var nodeBuilder = nodeBuilder(); + return switch ( ((TerminalNode) ctx.getChild( ctx.getChildCount() - 1 )).getSymbol().getType() ) { + case HqlParser.NULL -> new SqmNullnessPredicate( expression, negated, nodeBuilder ); + case HqlParser.EMPTY -> { + if ( expression instanceof SqmPluralValuedSimplePath pluralValuedSimplePath ) { + yield new SqmEmptinessPredicate( pluralValuedSimplePath, negated, nodeBuilder ); + } + else { + throw new SemanticException( "Operand of 'is empty' operator must be a plural path", query ); + } + } + case HqlParser.TRUE -> new SqmTruthnessPredicate( expression, true, negated, nodeBuilder ); + case HqlParser.FALSE -> new SqmTruthnessPredicate( expression, false, negated, nodeBuilder ); + default -> throw new AssertionError( "Unknown unary is predicate: " + ctx.getChild( ctx.getChildCount() - 1 ) ); + }; } @Override - public SqmEmptinessPredicate visitIsEmptyPredicate(HqlParser.IsEmptyPredicateContext ctx) { - SqmExpression expression = (SqmExpression) ctx.expression().accept(this); - if ( expression instanceof SqmPluralValuedSimplePath pluralValuedSimplePath ) { - return new SqmEmptinessPredicate( - pluralValuedSimplePath, - ctx.NOT() != null, - nodeBuilder() - ); + public SqmPredicate visitBinaryExpressionPredicate(HqlParser.BinaryExpressionPredicateContext ctx) { + final var firstSymbol = ((TerminalNode) ctx.getChild( 1 )).getSymbol(); + final boolean negated; + final Token operationSymbol; + if ( firstSymbol.getType() == HqlParser.NOT ) { + negated = true; + operationSymbol = ((TerminalNode) ctx.getChild( 2 )).getSymbol(); } else { - throw new SemanticException( "Operand of 'is empty' operator must be a plural path", query ); - } - } - - @Override - public Object visitIsTruePredicate(HqlParser.IsTruePredicateContext ctx) { - return new SqmTruthnessPredicate( - (SqmExpression) ctx.expression().accept( this ), - true, - ctx.NOT() != null, - nodeBuilder() - ); - } - - @Override - public Object visitIsFalsePredicate(HqlParser.IsFalsePredicateContext ctx) { - return new SqmTruthnessPredicate( - (SqmExpression) ctx.expression().accept( this ), - false, - ctx.NOT() != null, - nodeBuilder() - ); - } - - @Override - public Object visitComparisonOperator(HqlParser.ComparisonOperatorContext ctx) { - final TerminalNode firstToken = (TerminalNode) ctx.getChild( 0 ); - return switch ( firstToken.getSymbol().getType() ) { - case HqlLexer.EQUAL -> ComparisonOperator.EQUAL; - case HqlLexer.NOT_EQUAL -> ComparisonOperator.NOT_EQUAL; - case HqlLexer.LESS -> ComparisonOperator.LESS_THAN; - case HqlLexer.LESS_EQUAL -> ComparisonOperator.LESS_THAN_OR_EQUAL; - case HqlLexer.GREATER -> ComparisonOperator.GREATER_THAN; - case HqlLexer.GREATER_EQUAL -> ComparisonOperator.GREATER_THAN_OR_EQUAL; - default -> throw new ParsingException( "Unrecognized comparison operator" ); + negated = firstSymbol.getType() == HqlParser.IS + && ((TerminalNode) ctx.getChild( 2 )).getSymbol().getType() == HqlParser.NOT; + operationSymbol = firstSymbol; + } + final var expressions = ctx.expression(); + final var lhsCtx = expressions.get( 0 ); + final var rhsCtx = expressions.get( 1 ); + return switch ( operationSymbol.getType() ) { + case HqlParser.CONTAINS -> { + final var lhs = (SqmExpression) lhsCtx.accept( this ); + final var rhs = (SqmExpression) rhsCtx.accept( this ); + final var lhsExpressible = lhs.getExpressible(); + if ( lhsExpressible != null && !(lhsExpressible.getSqmType() instanceof BasicPluralType) ) { + throw new SemanticException( + "First operand for contains predicate must be a basic plural type expression, but found: " + lhsExpressible.getSqmType(), + query + ); + } + final SelfRenderingSqmFunction contains = getFunctionDescriptor( + "array_contains" ).generateSqmExpression( + asList( lhs, rhs ), + null, + queryEngine() + ); + yield new SqmBooleanExpressionPredicate( contains, negated, nodeBuilder() ); + } + case HqlParser.INCLUDES -> { + final var lhs = (SqmExpression) lhsCtx.accept( this ); + final var rhs = (SqmExpression) rhsCtx.accept( this ); + final var lhsExpressible = lhs.getExpressible(); + final var rhsExpressible = rhs.getExpressible(); + if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { + throw new SemanticException( + "First operand for includes predicate must be a basic plural type expression, but found: " + + lhsExpressible.getSqmType(), + query + ); + } + if ( rhsExpressible != null && !( rhsExpressible.getSqmType() instanceof BasicPluralType) ) { + throw new SemanticException( + "Second operand for includes predicate must be a basic plural type expression, but found: " + + rhsExpressible.getSqmType(), + query + ); + } + final SelfRenderingSqmFunction contains = getFunctionDescriptor( "array_includes" ).generateSqmExpression( + asList( lhs, rhs ), + null, + queryEngine() + ); + yield new SqmBooleanExpressionPredicate( contains, negated, nodeBuilder() ); + } + case HqlParser.INTERSECTS -> { + final var lhs = (SqmExpression) lhsCtx.accept( this ); + final var rhs = (SqmExpression) rhsCtx.accept( this ); + final var lhsExpressible = lhs.getExpressible(); + if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { + throw new SemanticException( + "First operand for intersects predicate must be a basic plural type expression, but found: " + + lhsExpressible.getSqmType(), + query + ); + } + final SelfRenderingSqmFunction contains = + getFunctionDescriptor( "array_intersects" ) + .generateSqmExpression( + asList( lhs, rhs ), + null, + queryEngine() + ); + yield new SqmBooleanExpressionPredicate( contains, negated, nodeBuilder() ); + } + case HqlParser.EQUAL -> + createComparisonPredicate( ComparisonOperator.EQUAL, lhsCtx, rhsCtx ); + case HqlParser.NOT_EQUAL -> + createComparisonPredicate( ComparisonOperator.NOT_EQUAL, lhsCtx, rhsCtx ); + case HqlParser.LESS -> + createComparisonPredicate( ComparisonOperator.LESS_THAN, lhsCtx, rhsCtx ); + case HqlParser.LESS_EQUAL -> + createComparisonPredicate( ComparisonOperator.LESS_THAN_OR_EQUAL, lhsCtx, rhsCtx ); + case HqlParser.GREATER -> + createComparisonPredicate( ComparisonOperator.GREATER_THAN, lhsCtx, rhsCtx ); + case HqlParser.GREATER_EQUAL -> + createComparisonPredicate( ComparisonOperator.GREATER_THAN_OR_EQUAL, lhsCtx, rhsCtx ); + case HqlParser.IS -> { + final ComparisonOperator comparisonOperator = !negated + ? ComparisonOperator.DISTINCT_FROM + : ComparisonOperator.NOT_DISTINCT_FROM; + yield createComparisonPredicate( comparisonOperator, lhsCtx, rhsCtx ); + } + default -> throw new AssertionError( "Unknown binary expression predicate: " + operationSymbol ); }; } - @Override - public SqmPredicate visitComparisonPredicate(HqlParser.ComparisonPredicateContext ctx) { - final ComparisonOperator comparisonOperator = (ComparisonOperator) ctx.comparisonOperator().accept( this ); - final var leftExpressionContext = ctx.expression( 0 ); - final var rightExpressionContext = ctx.expression( 1 ); - return createComparisonPredicate( comparisonOperator, leftExpressionContext, rightExpressionContext ); - } - - @Override - public SqmPredicate visitIsDistinctFromPredicate(HqlParser.IsDistinctFromPredicateContext ctx) { - final var leftExpressionContext = ctx.expression( 0 ); - final var rightExpressionContext = ctx.expression( 1 ); - final ComparisonOperator comparisonOperator = ctx.NOT() == null - ? ComparisonOperator.DISTINCT_FROM - : ComparisonOperator.NOT_DISTINCT_FROM; - return createComparisonPredicate( comparisonOperator, leftExpressionContext, rightExpressionContext ); - } - private SqmComparisonPredicate createComparisonPredicate( ComparisonOperator comparisonOperator, HqlParser.ExpressionContext leftExpressionContext, @@ -2614,26 +2680,6 @@ private String getPossibleEnumValue(HqlParser.ExpressionContext expressionContex return null; } - @Override - public SqmPredicate visitContainsPredicate(HqlParser.ContainsPredicateContext ctx) { - final boolean negated = ctx.NOT() != null; - final SqmExpression lhs = (SqmExpression) ctx.expression( 0 ).accept( this ); - final SqmExpression rhs = (SqmExpression) ctx.expression( 1 ).accept( this ); - final SqmExpressible lhsExpressible = lhs.getExpressible(); - if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "First operand for contains predicate must be a basic plural type expression, but found: " + lhsExpressible.getSqmType(), - query - ); - } - final SelfRenderingSqmFunction contains = getFunctionDescriptor( "array_contains" ).generateSqmExpression( - asList( lhs, rhs ), - null, - queryEngine() - ); - return new SqmBooleanExpressionPredicate( contains, negated, nodeBuilder() ); - } - @Override public SqmExpression visitJsonValueFunction(HqlParser.JsonValueFunctionContext ctx) { checkJsonFunctionsEnabled( ctx ); @@ -3168,58 +3214,6 @@ private void checkXmlFunctionsEnabled(ParserRuleContext ctx) { } } - @Override - public SqmPredicate visitIncludesPredicate(HqlParser.IncludesPredicateContext ctx) { - final boolean negated = ctx.NOT() != null; - final SqmExpression lhs = (SqmExpression) ctx.expression( 0 ).accept( this ); - final SqmExpression rhs = (SqmExpression) ctx.expression( 1 ).accept( this ); - final SqmExpressible lhsExpressible = lhs.getExpressible(); - final SqmExpressible rhsExpressible = rhs.getExpressible(); - if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "First operand for includes predicate must be a basic plural type expression, but found: " - + lhsExpressible.getSqmType(), - query - ); - } - if ( rhsExpressible != null && !( rhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "Second operand for includes predicate must be a basic plural type expression, but found: " - + rhsExpressible.getSqmType(), - query - ); - } - final SelfRenderingSqmFunction contains = getFunctionDescriptor( "array_includes" ).generateSqmExpression( - asList( lhs, rhs ), - null, - queryEngine() - ); - return new SqmBooleanExpressionPredicate( contains, negated, nodeBuilder() ); - } - - @Override - public SqmPredicate visitIntersectsPredicate(HqlParser.IntersectsPredicateContext ctx) { - final boolean negated = ctx.NOT() != null; - final SqmExpression lhs = (SqmExpression) ctx.expression( 0 ).accept( this ); - final SqmExpression rhs = (SqmExpression) ctx.expression( 1 ).accept( this ); - final SqmExpressible lhsExpressible = lhs.getExpressible(); - if ( lhsExpressible != null && !( lhsExpressible.getSqmType() instanceof BasicPluralType) ) { - throw new SemanticException( - "First operand for intersects predicate must be a basic plural type expression, but found: " - + lhsExpressible.getSqmType(), - query - ); - } - final SelfRenderingSqmFunction contains = - getFunctionDescriptor( "array_intersects" ) - .generateSqmExpression( - asList( lhs, rhs ), - null, - queryEngine() - ); - return new SqmBooleanExpressionPredicate( contains, negated, nodeBuilder() ); - } - @Override public SqmPredicate visitLikePredicate(HqlParser.LikePredicateContext ctx) { final boolean negated = ctx.NOT() != null; @@ -3568,11 +3562,6 @@ else if ( attributes.size() >1 ) { throw new FunctionArgumentException( "Argument '" + sqmPath.getNavigablePath() + "' of 'naturalid()' does not resolve to an entity type" ); } -// -// @Override -// public Object visitToOneFkExpression(HqlParser.ToOneFkExpressionContext ctx) { -// return visitToOneFkReference( (HqlParser.ToOneFkReferenceContext) ctx.getChild( 0 ) ); -// } @Override public SqmFkExpression visitToOneFkReference(HqlParser.ToOneFkReferenceContext ctx) { @@ -5738,33 +5727,6 @@ else if ( ctx.collectionValueNavigablePath() != null ) { else if ( ctx.mapKeyNavigablePath() != null ) { return visitMapKeyNavigablePath( ctx.mapKeyNavigablePath() ); } - else if ( ctx.toOneFkReference() != null ) { - return visitToOneFkReference( ctx.toOneFkReference() ); - } - else if ( ctx.function() != null ) { - final var slicedFragmentsCtx = ctx.slicedPathAccessFragment(); - if ( slicedFragmentsCtx != null ) { - final List slicedFragments = slicedFragmentsCtx.expression(); - return getFunctionDescriptor( "array_slice" ).generateSqmExpression( - List.of( - (SqmTypedNode) visitFunction( ctx.function() ), - (SqmTypedNode) slicedFragments.get( 0 ).accept( this ), - (SqmTypedNode) slicedFragments.get( 1 ).accept( this ) - ), - null, - queryEngine() - ); - } - else { - return visitPathContinuation( - visitIndexedPathAccessFragment( - (SemanticPathPart) visitFunction( ctx.function() ), - ctx.indexedPathAccessFragment() - ), - ctx.pathContinuation() - ); - } - } else if ( ctx.simplePath() != null && ctx.indexedPathAccessFragment() != null ) { return visitIndexedPathAccessFragment( visitSimplePath( ctx.simplePath() ), ctx.indexedPathAccessFragment() ); } diff --git a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java index a40c1589fdc1..3b6eef1846c7 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java +++ b/hibernate-core/src/main/java/org/hibernate/query/hql/internal/StandardHqlTranslator.java @@ -104,9 +104,14 @@ private HqlParser.StatementContext parseHql(String hql) { try { return hqlParser.statement(); } - catch ( ParseCancellationException e) { + catch (ParseCancellationException e) { + // When resetting the parser, its CommonTokenStream will seek(0) i.e. restart emitting buffered tokens. + // This is enough when reusing the lexer and parser, and it would be wrong to also reset the lexer. + // Resetting the lexer causes it to hand out tokens again from the start, which will then append to the + // CommonTokenStream and cause a wrong parse + // hqlLexer.reset(); + // reset the input token stream and parser state - hqlLexer.reset(); hqlParser.reset(); // fall back to LL(k)-based parsing diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlParserMemoryUsageTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlParserMemoryUsageTest.java new file mode 100644 index 000000000000..495a00eeba0d --- /dev/null +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/hql/HqlParserMemoryUsageTest.java @@ -0,0 +1,166 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.orm.test.hql; + +import jakarta.persistence.Entity; +import jakarta.persistence.FetchType; +import jakarta.persistence.Id; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.OneToMany; +import jakarta.persistence.Table; +import org.hibernate.cfg.QuerySettings; +import org.hibernate.query.hql.HqlTranslator; +import org.hibernate.testing.memory.MemoryUsageUtil; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.Jira; +import org.hibernate.testing.orm.junit.ServiceRegistry; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.testing.orm.junit.Setting; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +@DomainModel( + annotatedClasses = { + HqlParserMemoryUsageTest.Address.class, + HqlParserMemoryUsageTest.AppUser.class, + HqlParserMemoryUsageTest.Category.class, + HqlParserMemoryUsageTest.Discount.class, + HqlParserMemoryUsageTest.Order.class, + HqlParserMemoryUsageTest.OrderItem.class, + HqlParserMemoryUsageTest.Product.class + } +) +@SessionFactory +@ServiceRegistry(settings = @Setting(name = QuerySettings.QUERY_PLAN_CACHE_ENABLED, value = "false")) +@Jira("https://hibernate.atlassian.net/browse/HHH-19240") +public class HqlParserMemoryUsageTest { + + private static final String HQL = """ + SELECT DISTINCT u.id + FROM AppUser u + LEFT JOIN u.addresses a + LEFT JOIN u.orders o + LEFT JOIN o.orderItems oi + LEFT JOIN oi.product p + LEFT JOIN p.discounts d + WHERE u.id = :userId + AND ( + CASE + WHEN u.name = 'SPECIAL_USER' THEN TRUE + ELSE ( + CASE + WHEN a.city = 'New York' THEN TRUE + ELSE ( + p.category.name = 'Electronics' + OR d.code LIKE '%DISC%' + OR u.id IN ( + SELECT u2.id + FROM AppUser u2 + JOIN u2.orders o2 + JOIN o2.orderItems oi2 + JOIN oi2.product p2 + WHERE p2.price > ( + SELECT AVG(p3.price) FROM Product p3 + ) + ) + ) + END + ) + END + ) + """; + + + @Test + public void testParserMemoryUsage(SessionFactoryScope scope) { + final HqlTranslator hqlTranslator = scope.getSessionFactory().getQueryEngine().getHqlTranslator(); + + // Ensure classes and basic stuff is initialized in case this is the first test run + hqlTranslator.translate( "from AppUser", AppUser.class ); + + // During testing, before the fix for HHH-19240, the allocation was around 500+ MB, + // and after the fix it dropped to 170 - 250 MB + final long memoryUsage = MemoryUsageUtil.estimateMemoryUsage( () -> hqlTranslator.translate( HQL, Long.class ) ); + System.out.println( "Memory Consumption: " + (memoryUsage / 1024) + " KB" ); + assertTrue( memoryUsage < 256_000_000, "Parsing of queries consumes too much memory (" + ( memoryUsage / 1024 ) + " KB), when at most 256 MB are expected" ); + } + + @Entity(name = "Address") + @Table(name = "addresses") + public static class Address { + @Id + private Long id; + private String city; + @ManyToOne(fetch = FetchType.LAZY) + private AppUser user; + } + @Entity(name = "AppUser") + @Table(name = "app_users") + public static class AppUser { + @Id + private Long id; + private String name; + @OneToMany(mappedBy = "user") + private Set
addresses; + @OneToMany(mappedBy = "user") + private Set orders; + } + + @Entity(name = "Category") + @Table(name = "categories") + public static class Category { + @Id + private Long id; + private String name; + } + + @Entity(name = "Discount") + @Table(name = "discounts") + public static class Discount { + @Id + private Long id; + private String code; + @ManyToOne(fetch = FetchType.LAZY) + private Product product; + } + + @Entity(name = "Order") + @Table(name = "orders") + public static class Order { + @Id + private Long id; + @ManyToOne(fetch = FetchType.LAZY) + private AppUser user; + @OneToMany(mappedBy = "order") + private Set orderItems; + } + @Entity(name = "OrderItem") + @Table(name = "order_items") + public static class OrderItem { + @Id + private Long id; + @ManyToOne(fetch = FetchType.LAZY) + private Order order; + @ManyToOne(fetch = FetchType.LAZY) + private Product product; + } + + @Entity(name = "Product") + @Table(name = "products") + public static class Product { + @Id + private Long id; + private String name; + private Double price; + @ManyToOne(fetch = FetchType.LAZY) + private Category category; + @OneToMany(mappedBy = "product") + private Set discounts; + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/GlobalMemoryUsageSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/GlobalMemoryUsageSnapshotter.java new file mode 100644 index 000000000000..055feea7bf30 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/GlobalMemoryUsageSnapshotter.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.testing.memory; + +import java.lang.management.ManagementFactory; +import java.lang.management.MemoryPoolMXBean; +import java.util.List; + +final class GlobalMemoryUsageSnapshotter implements MemoryAllocationSnapshotter { + + private static final GlobalMemoryUsageSnapshotter INSTANCE = new GlobalMemoryUsageSnapshotter( + ManagementFactory.getMemoryPoolMXBeans() + ); + + private final List heapPoolBeans; + private final Runnable gcAndWait; + + private GlobalMemoryUsageSnapshotter(List heapPoolBeans) { + this.heapPoolBeans = heapPoolBeans; + this.gcAndWait = () -> { + for (int i = 0; i < 3; i++) { + System.gc(); + try { Thread.sleep(50); } catch (InterruptedException ignored) {} + } + }; + } + + public static GlobalMemoryUsageSnapshotter getInstance() { + return INSTANCE; + } + + @Override + public MemoryAllocationSnapshot snapshot() { + final long peakUsage = heapPoolBeans.stream().mapToLong(p -> p.getPeakUsage().getUsed()).sum(); + gcAndWait.run(); + final long retainedUsage = heapPoolBeans.stream().mapToLong(p -> p.getUsage().getUsed()).sum(); + heapPoolBeans.forEach(MemoryPoolMXBean::resetPeakUsage); + return new GlobalMemoryAllocationSnapshot( peakUsage, retainedUsage ); + } + + record GlobalMemoryAllocationSnapshot(long peakUsage, long retainedUsage) implements MemoryAllocationSnapshot { + + @Override + public long difference(MemoryAllocationSnapshot before) { + // When doing the "before" snapshot, the peak usage is reset. + // Since this object is the "after" snapshot, we can simply estimate the memory usage of an operation + // to be the peak usage of that operation minus the usage after GC + return peakUsage - retainedUsage; + } + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotPerThreadAllocationSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotPerThreadAllocationSnapshotter.java new file mode 100644 index 000000000000..e73d67e11fd8 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotPerThreadAllocationSnapshotter.java @@ -0,0 +1,95 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.testing.memory; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; +import java.lang.reflect.Method; +import java.util.HashMap; + +record HotspotPerThreadAllocationSnapshotter(ThreadMXBean threadMXBean) implements MemoryAllocationSnapshotter { + + private static final @Nullable HotspotPerThreadAllocationSnapshotter INSTANCE; + private static final Method GET_THREAD_ALLOCATED_BYTES; + + static { + ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean(); + Method method = null; + try { + @SuppressWarnings("unchecked") + Class hotspotInterface = + (Class) Class.forName( "com.sun.management.ThreadMXBean" ); + try { + method = hotspotInterface.getMethod( "getThreadAllocatedBytes", long[].class ); + } + catch (Exception e) { + // Ignore + } + + if ( !hotspotInterface.isInstance( threadMXBean ) ) { + threadMXBean = ManagementFactory.getPlatformMXBean( hotspotInterface ); + } + } + catch (Throwable e) { + // Ignore + } + + GET_THREAD_ALLOCATED_BYTES = method; + + HotspotPerThreadAllocationSnapshotter instance = null; + if ( method != null && threadMXBean != null ) { + try { + instance = new HotspotPerThreadAllocationSnapshotter( threadMXBean ); + instance.snapshot(); + } + catch (Exception e) { + instance = null; + } + } + INSTANCE = instance; + } + + public static @Nullable HotspotPerThreadAllocationSnapshotter getInstance() { + return INSTANCE; + } + + @Override + public MemoryAllocationSnapshot snapshot() { + long[] threadIds = threadMXBean.getAllThreadIds(); + try { + return new PerThreadMemoryAllocationSnapshot( + threadIds, + (long[]) GET_THREAD_ALLOCATED_BYTES.invoke( threadMXBean, (Object) threadIds ) + ); + } + catch (Exception e) { + throw new RuntimeException( e ); + } + } + + record PerThreadMemoryAllocationSnapshot(long[] threadIds, long[] threadAllocatedBytes) + implements MemoryAllocationSnapshot { + + @Override + public long difference(MemoryAllocationSnapshot before) { + final PerThreadMemoryAllocationSnapshot other = (PerThreadMemoryAllocationSnapshot) before; + final HashMap previousThreadIdToIndexMap = new HashMap<>(); + for ( int i = 0; i < other.threadIds.length; i++ ) { + previousThreadIdToIndexMap.put( other.threadIds[i], i ); + } + long allocatedBytes = 0; + for ( int i = 0; i < threadIds.length; i++ ) { + allocatedBytes += threadAllocatedBytes[i]; + final Integer previousThreadIndex = previousThreadIdToIndexMap.get( threadIds[i] ); + if ( previousThreadIndex != null ) { + allocatedBytes -= other.threadAllocatedBytes[previousThreadIndex]; + } + } + return allocatedBytes; + } + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotTotalThreadBytesSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotTotalThreadBytesSnapshotter.java new file mode 100644 index 000000000000..d55c1e63f923 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/HotspotTotalThreadBytesSnapshotter.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.testing.memory; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.lang.management.ManagementFactory; +import java.lang.management.ThreadMXBean; +import java.lang.reflect.Method; + +record HotspotTotalThreadBytesSnapshotter(ThreadMXBean threadMXBean) implements MemoryAllocationSnapshotter { + + private static final @Nullable HotspotTotalThreadBytesSnapshotter INSTANCE; + private static final Method GET_TOTAL_THREAD_ALLOCATED_BYTES; + + static { + ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean(); + Method method = null; + try { + @SuppressWarnings("unchecked") + Class hotspotInterface = + (Class) Class.forName( "com.sun.management.ThreadMXBean" ); + try { + method = hotspotInterface.getMethod( "getTotalThreadAllocatedBytes" ); + } + catch (Exception e) { + // Ignore + } + + if ( !hotspotInterface.isInstance( threadMXBean ) ) { + threadMXBean = ManagementFactory.getPlatformMXBean( hotspotInterface ); + } + } + catch (Throwable e) { + // Ignore + } + + GET_TOTAL_THREAD_ALLOCATED_BYTES = method; + + HotspotTotalThreadBytesSnapshotter instance = null; + if ( method != null && threadMXBean != null ) { + try { + instance = new HotspotTotalThreadBytesSnapshotter( threadMXBean ); + instance.snapshot(); + } + catch (Exception e) { + instance = null; + } + } + INSTANCE = instance; + } + + public static @Nullable HotspotTotalThreadBytesSnapshotter getInstance() { + return INSTANCE; + } + + @Override + public MemoryAllocationSnapshot snapshot() { + try { + return new GlobalMemoryAllocationSnapshot( (long) GET_TOTAL_THREAD_ALLOCATED_BYTES.invoke( threadMXBean ) ); + } + catch (Exception e) { + throw new RuntimeException( e ); + } + } + + record GlobalMemoryAllocationSnapshot(long allocatedBytes) implements MemoryAllocationSnapshot { + + GlobalMemoryAllocationSnapshot { + if ( allocatedBytes == -1L ) { + throw new IllegalArgumentException( "getTotalThreadAllocatedBytes is disabled" ); + } + } + + @Override + public long difference(MemoryAllocationSnapshot before) { + final GlobalMemoryAllocationSnapshot other = (GlobalMemoryAllocationSnapshot) before; + return Math.max( allocatedBytes - other.allocatedBytes, 0L ); + } + } +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshot.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshot.java new file mode 100644 index 000000000000..1aba66cef8db --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshot.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.testing.memory; + +interface MemoryAllocationSnapshot { + long difference(MemoryAllocationSnapshot before); +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshotter.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshotter.java new file mode 100644 index 000000000000..75d87fc9325f --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryAllocationSnapshotter.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.testing.memory; + +interface MemoryAllocationSnapshotter { + MemoryAllocationSnapshot snapshot(); +} diff --git a/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryUsageUtil.java b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryUsageUtil.java new file mode 100644 index 000000000000..55d458832bc6 --- /dev/null +++ b/hibernate-testing/src/main/java/org/hibernate/testing/memory/MemoryUsageUtil.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.testing.memory; + +public class MemoryUsageUtil { + + private static final MemoryAllocationSnapshotter SNAPSHOTTER; + + static { + MemoryAllocationSnapshotter snapshotter = HotspotTotalThreadBytesSnapshotter.getInstance(); + if ( snapshotter == null ) { + snapshotter = HotspotPerThreadAllocationSnapshotter.getInstance(); + } + if ( snapshotter == null ) { + snapshotter = GlobalMemoryUsageSnapshotter.getInstance(); + } + SNAPSHOTTER = snapshotter; + } + + public static long estimateMemoryUsage(Runnable runnable) { + final MemoryAllocationSnapshot beforeSnapshot = SNAPSHOTTER.snapshot(); + runnable.run(); + return SNAPSHOTTER.snapshot().difference( beforeSnapshot ); + } +}