diff --git a/hibernate-core/src/main/java/org/hibernate/internal/util/collections/LinkedIdentityHashMap.java b/hibernate-core/src/main/java/org/hibernate/internal/util/collections/LinkedIdentityHashMap.java new file mode 100644 index 000000000000..564bf1009e73 --- /dev/null +++ b/hibernate-core/src/main/java/org/hibernate/internal/util/collections/LinkedIdentityHashMap.java @@ -0,0 +1,297 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.internal.util.collections; + +import java.util.AbstractMap; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Set; + + +/** + * Utility {@link Map} implementation that uses identity (==) for key comparison and preserves insertion order + */ +public class LinkedIdentityHashMap extends AbstractMap implements Map { + private static final int DEFAULT_INITIAL_CAPACITY = 16; // must be power of two + + static final class Node implements Map.Entry { + final K key; + V value; + Node next; + Node before; + Node after; + + Node(K key, V value, Node next) { + this.key = key; + this.value = value; + this.next = next; + } + + @Override + public K getKey() { + return key; + } + + @Override + public V getValue() { + return value; + } + + @Override + public V setValue(V newValue) { + final V old = value; + value = newValue; + return old; + } + + @Override + public boolean equals(Object o) { + return o instanceof Node node && key == node.key && Objects.equals( value, node.value ); + } + + @Override + public int hashCode() { + int result = System.identityHashCode( key ); + result = 31 * result + Objects.hashCode( value ); + return result; + } + + @Override + public String toString() { + return key + "=" + value; + } + } + + private Node[] table; + private int size; + + private Node head; + private Node tail; + + private transient Set> entrySet; + + public LinkedIdentityHashMap() { + this( DEFAULT_INITIAL_CAPACITY ); + } + + public LinkedIdentityHashMap(int initialCapacity) { + if ( initialCapacity < 0 ) { + throw new IllegalArgumentException( "Illegal initial capacity: " + initialCapacity ); + } + int cap = 1; + while ( cap < initialCapacity ) { + cap <<= 1; + } + //noinspection unchecked + table = (Node[]) new Node[cap]; + } + + private static int indexFor(int hash, int length) { + return hash & (length - 1); + } + + @Override + public V get(Object key) { + final Node e = getNode( key ); + return e != null ? e.value : null; + } + + private Node getNode(Object key) { + final int hash = System.identityHashCode( key ); + final int idx = indexFor( hash, table.length ); + for ( Node e = table[idx]; e != null; e = e.next ) { + if ( e.key == key ) { + return e; + } + } + return null; + } + + @Override + public boolean containsKey(Object key) { + return getNode( key ) != null; + } + + @Override + public boolean containsValue(Object value) { + for ( Node e = head; e != null; e = e.after ) { + if ( Objects.equals( e.value, value ) ) { + return true; + } + } + return false; + } + + @Override + public V put(K key, V value) { + final int hash = System.identityHashCode( key ); + final int idx = indexFor( hash, table.length ); + for ( Node e = table[idx]; e != null; e = e.next ) { + if ( e.key == key ) { + final V old = e.value; + e.value = value; + return old; + } + } + // not found -> insert + final Node newNode = new Node<>( key, value, table[idx] ); + table[idx] = newNode; + linkLast( newNode ); + size++; + if ( size == table.length ) { + resize(); + } + return null; + } + + private void linkLast(Node node) { + if ( tail == null ) { + head = tail = node; + } + else { + tail.after = node; + node.before = tail; + tail = node; + } + } + + @Override + public V remove(Object key) { + final int hash = System.identityHashCode( key ); + final int idx = indexFor( hash, table.length ); + Node prev = null; + for ( Node e = table[idx]; e != null; prev = e, e = e.next ) { + if ( e.key == key ) { + // remove from bucket chain + if ( prev == null ) { + table[idx] = e.next; + } + else { + prev.next = e.next; + } + // unlink from insertion-order list + final Node b = e.before; + final Node a = e.after; + if ( b == null ) { + head = a; + } + else { + b.after = a; + } + if ( a == null ) { + tail = b; + } + else { + a.before = b; + } + size--; + return e.value; + } + } + return null; + } + + @Override + public void clear() { + Arrays.fill( table, null ); + head = tail = null; + size = 0; + } + + @Override + public int size() { + return size; + } + + private void resize() { + final int oldCap = table.length; + final int newCap = oldCap << 1; + //noinspection unchecked + final Node[] newTable = (Node[]) new Node[newCap]; + for ( int i = 0; i < oldCap; i++ ) { + Node e = table[i]; + while ( e != null ) { + final Node next = e.next; + final int idx = indexFor( System.identityHashCode( e.key ), newCap ); + e.next = newTable[idx]; + newTable[idx] = e; + e = next; + } + } + table = newTable; + } + + final class EntryIterator implements Iterator> { + private Node next = head; + private Node current = null; + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public Node next() { + Node e = next; + if ( e == null ) { + throw new NoSuchElementException(); + } + current = e; + next = e.after; + return e; + } + + @Override + public void remove() { + Node e = current; + if ( e == null ) { + throw new IllegalStateException(); + } + LinkedIdentityHashMap.this.remove( e.key ); + current = null; + } + } + + final class EntrySet extends AbstractSet> { + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return LinkedIdentityHashMap.this.size; + } + + @Override + public void clear() { + LinkedIdentityHashMap.this.clear(); + } + + @Override + public boolean contains(Object o) { + if ( o instanceof Entry e ) { + final Node n = getNode( e.getKey() ); + return n != null && Objects.equals( n.value, e.getValue() ); + } + return false; + } + + @Override + public boolean remove(Object o) { + return o instanceof Entry e && LinkedIdentityHashMap.this.remove( e.getKey() ) != null; + } + } + + @Override + public Set> entrySet() { + Set> es; + return (es = entrySet) == null ? (entrySet = new EntrySet()) : es; + } +} diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/DomainParameterXref.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/DomainParameterXref.java index dc5c8a1868d9..79d7e2929f10 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/DomainParameterXref.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/internal/DomainParameterXref.java @@ -6,11 +6,11 @@ import java.util.ArrayList; import java.util.IdentityHashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.TreeMap; +import org.hibernate.internal.util.collections.LinkedIdentityHashMap; import org.hibernate.query.internal.QueryParameterNamedImpl; import org.hibernate.query.internal.QueryParameterPositionalImpl; import org.hibernate.query.spi.QueryParameterImplementor; @@ -31,7 +31,7 @@ public class DomainParameterXref { public static final DomainParameterXref EMPTY = new DomainParameterXref( - new LinkedHashMap<>( 0 ), + new LinkedIdentityHashMap<>( 0 ), new IdentityHashMap<>( 0 ), SqmStatement.ParameterResolutions.empty() ); @@ -46,8 +46,8 @@ public static DomainParameterXref from(SqmStatement sqmStatement) { } else { final int sqmParamCount = parameterResolutions.getSqmParameters().size(); - final LinkedHashMap, List>> sqmParamsByQueryParam = - new LinkedHashMap<>( sqmParamCount ); + final Map, List>> sqmParamsByQueryParam = + new LinkedIdentityHashMap<>( sqmParamCount ); final IdentityHashMap, QueryParameterImplementor> queryParamBySqmParam = new IdentityHashMap<>( sqmParamCount ); @@ -118,13 +118,13 @@ else if ( sqmParameter.getExpressible() != null private final SqmStatement.ParameterResolutions parameterResolutions; - private final LinkedHashMap, List>> sqmParamsByQueryParam; + private final Map, List>> sqmParamsByQueryParam; private final IdentityHashMap, QueryParameterImplementor> queryParamBySqmParam; private Map,List>> expansions; private DomainParameterXref( - LinkedHashMap, List>> sqmParamsByQueryParam, + Map, List>> sqmParamsByQueryParam, IdentityHashMap, QueryParameterImplementor> queryParamBySqmParam, SqmStatement.ParameterResolutions parameterResolutions) { this.sqmParamsByQueryParam = sqmParamsByQueryParam; @@ -148,7 +148,9 @@ public boolean hasParameters() { } /** - * Get all of the QueryParameters mapped by this xref + * Get all the QueryParameters mapped by this xref. + * Note that order of parameters is important - parameters are + * included in cache keys for query results caching. */ public Map, List>> getQueryParameters() { return sqmParamsByQueryParam; diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/ValueBindJpaCriteriaParameter.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/ValueBindJpaCriteriaParameter.java index 551041b1fbc1..880520028197 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/ValueBindJpaCriteriaParameter.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/expression/ValueBindJpaCriteriaParameter.java @@ -10,6 +10,7 @@ import org.hibernate.query.sqm.tree.SqmCopyContext; import org.hibernate.query.sqm.tree.SqmRenderContext; +import java.util.Objects; /** @@ -59,12 +60,34 @@ public int compareTo(SqmParameter parameter) { } @Override - public boolean equals(Object object) { - return this == object; + public boolean equals(Object obj) { + if ( this == obj ) { + return true; + } + if ( obj instanceof ValueBindJpaCriteriaParameter that ) { + if ( value == null ) { + return that.value == null && Objects.equals( getNodeType(), that.getNodeType() ); + } + final var javaType = getJavaTypeDescriptor(); + if ( that.value != null ) { + if ( javaType != null ) { + //noinspection unchecked + return javaType.equals( that.getJavaTypeDescriptor() ) && javaType.areEqual( value, (T) that.value ); + } + else { + return that.getJavaTypeDescriptor() == null && value.equals( that.value ); + } + } + } + return false; } @Override public int hashCode() { - return super.hashCode(); + if ( value == null ) { + return 0; + } + final var javaType = getJavaTypeDescriptor(); + return javaType == null ? value.hashCode() : javaType.extractHashCode( value ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/jpa/ParameterCollector.java b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/jpa/ParameterCollector.java index ea302c20026a..633e5f2451a2 100644 --- a/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/jpa/ParameterCollector.java +++ b/hibernate-core/src/main/java/org/hibernate/query/sqm/tree/jpa/ParameterCollector.java @@ -5,7 +5,7 @@ package org.hibernate.query.sqm.tree.jpa; import java.util.Collections; -import java.util.HashSet; +import java.util.IdentityHashMap; import java.util.Set; import java.util.function.Consumer; @@ -132,7 +132,7 @@ private BindableType getInferredParameterType(JpaCriteriaParameter exp private > T visitParameter(T param) { if ( parameterExpressions == null ) { - parameterExpressions = new HashSet<>(); + parameterExpressions = Collections.newSetFromMap( new IdentityHashMap<>() ); } parameterExpressions.add( param ); consumer.accept( param ); @@ -141,7 +141,7 @@ private > T visitParameter(T param) { private SqmJpaCriteriaParameterWrapper visitParameter(SqmJpaCriteriaParameterWrapper param) { if ( parameterExpressions == null ) { - parameterExpressions = new HashSet<>(); + parameterExpressions = Collections.newSetFromMap( new IdentityHashMap<>() ); } parameterExpressions.add( param.getJpaCriteriaParameter() ); consumer.accept( param );