Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.internal.util.collections;

import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;


/**
* Utility {@link Map} implementation that uses identity (==) for key comparison and preserves insertion order
*/
public class LinkedIdentityHashMap<K, V> extends AbstractMap<K, V> implements Map<K, V> {
private static final int DEFAULT_INITIAL_CAPACITY = 16; // must be power of two

static final class Node<K, V> implements Map.Entry<K, V> {
final K key;
V value;
Node<K, V> next;
Node<K, V> before;
Node<K, V> after;

Node(K key, V value, Node<K, V> next) {
this.key = key;
this.value = value;
this.next = next;
}

@Override
public K getKey() {
return key;
}

@Override
public V getValue() {
return value;
}

@Override
public V setValue(V newValue) {
final V old = value;
value = newValue;
return old;
}

@Override
public boolean equals(Object o) {
return o instanceof Node<?, ?> node && key == node.key && Objects.equals( value, node.value );
}

@Override
public int hashCode() {
int result = System.identityHashCode( key );
result = 31 * result + Objects.hashCode( value );
return result;
}

@Override
public String toString() {
return key + "=" + value;
}
}

private Node<K, V>[] table;
private int size;

private Node<K, V> head;
private Node<K, V> tail;

private transient Set<Map.Entry<K, V>> entrySetView;

public LinkedIdentityHashMap() {
this( DEFAULT_INITIAL_CAPACITY );
}

public LinkedIdentityHashMap(int initialCapacity) {
if ( initialCapacity < 0 ) {
throw new IllegalArgumentException( "Illegal initial capacity: " + initialCapacity );
}
int cap = 1;
while ( cap < initialCapacity ) {
cap <<= 1;
}
//noinspection unchecked
table = (Node<K, V>[]) new Node[cap];
}

private static int indexFor(int hash, int length) {
return hash & (length - 1);
}

@Override
public V get(Object key) {
final Node<K, V> e = getNode( key );
return e != null ? e.value : null;
}

private Node<K, V> getNode(Object key) {
final int hash = System.identityHashCode( key );
final int idx = indexFor( hash, table.length );
for ( Node<K, V> e = table[idx]; e != null; e = e.next ) {
if ( e.key == key ) {
return e;
}
}
return null;
}

@Override
public boolean containsKey(Object key) {
return getNode( key ) != null;
}

@Override
public boolean containsValue(Object value) {
for ( Node<K, V> e = head; e != null; e = e.after ) {
if ( Objects.equals( e.value, value ) ) {
return true;
}
}
return false;
}

@Override
public V put(K key, V value) {
final int hash = System.identityHashCode( key );
final int idx = indexFor( hash, table.length );
for ( Node<K, V> e = table[idx]; e != null; e = e.next ) {
if ( e.key == key ) {
final V old = e.value;
e.value = value;
return old;
}
}
// not found -> insert
final Node<K, V> newNode = new Node<>( key, value, table[idx] );
table[idx] = newNode;
linkLast( newNode );
size++;
if ( size == table.length ) {
resize();
}
return null;
}

private void linkLast(Node<K, V> node) {
if ( tail == null ) {
head = tail = node;
}
else {
tail.after = node;
node.before = tail;
tail = node;
}
}

@Override
public V remove(Object key) {
final int hash = System.identityHashCode( key );
final int idx = indexFor( hash, table.length );
Node<K, V> prev = null;
for ( Node<K, V> e = table[idx]; e != null; prev = e, e = e.next ) {
if ( e.key == key ) {
// remove from bucket chain
if ( prev == null ) {
table[idx] = e.next;
}
else {
prev.next = e.next;
}
// unlink from insertion-order list
final Node<K, V> b = e.before;
final Node<K, V> a = e.after;
if ( b == null ) {
head = a;
}
else {
b.after = a;
}
if ( a == null ) {
tail = b;
}
else {
a.before = b;
}
size--;
return e.value;
}
}
return null;
}

@Override
public void clear() {
Arrays.fill( table, null );
head = tail = null;
size = 0;
}

@Override
public int size() {
return size;
}

private void resize() {
final int oldCap = table.length;
final int newCap = oldCap << 1;
//noinspection unchecked
final Node<K, V>[] newTable = (Node<K, V>[]) new Node[newCap];
for ( int i = 0; i < oldCap; i++ ) {
Node<K, V> e = table[i];
while ( e != null ) {
final Node<K, V> next = e.next;
final int idx = indexFor( System.identityHashCode( e.key ), newCap );
e.next = newTable[idx];
newTable[idx] = e;
e = next;
}
}
table = newTable;
}

@Override
public Set<Map.Entry<K, V>> entrySet() {
if ( entrySetView == null ) {
entrySetView = new EntrySet();
}
return entrySetView;
}

private final class EntrySet extends AbstractSet<Entry<K, V>> {
@Override
public Iterator<Entry<K, V>> iterator() {
return new Iterator<>() {
private Node<K, V> next = head;
private Node<K, V> lastReturned = null;

@Override
public boolean hasNext() {
return next != null;
}

@Override
public Map.Entry<K, V> next() {
if ( next == null ) {
throw new NoSuchElementException();
}
lastReturned = next;
next = next.after;
return lastReturned;
}

@Override
public void remove() {
if ( lastReturned == null ) {
throw new IllegalStateException();
}
LinkedIdentityHashMap.this.remove( lastReturned.key );
lastReturned = null;
}
};
}

@Override
public int size() {
return LinkedIdentityHashMap.this.size;
}

@Override
public void clear() {
LinkedIdentityHashMap.this.clear();
}

@Override
public boolean contains(Object o) {
if ( !(o instanceof Entry<?, ?> e) ) {
return false;
}
final Node<K, V> n = getNode( e.getKey() );
return n != null && Objects.equals( n.value, e.getValue() );
}

@Override
public boolean remove(Object o) {
if ( !(o instanceof Entry<?, ?> e) ) {
return false;
}
return LinkedIdentityHashMap.this.remove( e.getKey() ) != null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.hibernate.internal.util.collections.LinkedIdentityHashMap;
import org.hibernate.query.internal.QueryParameterNamedImpl;
import org.hibernate.query.internal.QueryParameterPositionalImpl;
import org.hibernate.query.spi.QueryParameterImplementor;
Expand All @@ -31,7 +31,7 @@
public class DomainParameterXref {

public static final DomainParameterXref EMPTY = new DomainParameterXref(
new LinkedHashMap<>( 0 ),
new LinkedIdentityHashMap<>( 0 ),
new IdentityHashMap<>( 0 ),
SqmStatement.ParameterResolutions.empty()
);
Expand All @@ -46,8 +46,8 @@ public static DomainParameterXref from(SqmStatement<?> sqmStatement) {
}
else {
final int sqmParamCount = parameterResolutions.getSqmParameters().size();
final LinkedHashMap<QueryParameterImplementor<?>, List<SqmParameter<?>>> sqmParamsByQueryParam =
new LinkedHashMap<>( sqmParamCount );
final Map<QueryParameterImplementor<?>, List<SqmParameter<?>>> sqmParamsByQueryParam =
new LinkedIdentityHashMap<>( sqmParamCount );
final IdentityHashMap<SqmParameter<?>, QueryParameterImplementor<?>> queryParamBySqmParam =
new IdentityHashMap<>( sqmParamCount );

Expand Down Expand Up @@ -118,13 +118,13 @@ else if ( sqmParameter.getExpressible() != null

private final SqmStatement.ParameterResolutions parameterResolutions;

private final LinkedHashMap<QueryParameterImplementor<?>, List<SqmParameter<?>>> sqmParamsByQueryParam;
private final Map<QueryParameterImplementor<?>, List<SqmParameter<?>>> sqmParamsByQueryParam;
private final IdentityHashMap<SqmParameter<?>, QueryParameterImplementor<?>> queryParamBySqmParam;

private Map<SqmParameter<?>,List<SqmParameter<?>>> expansions;

private DomainParameterXref(
LinkedHashMap<QueryParameterImplementor<?>, List<SqmParameter<?>>> sqmParamsByQueryParam,
Map<QueryParameterImplementor<?>, List<SqmParameter<?>>> sqmParamsByQueryParam,
IdentityHashMap<SqmParameter<?>, QueryParameterImplementor<?>> queryParamBySqmParam,
SqmStatement.ParameterResolutions parameterResolutions) {
this.sqmParamsByQueryParam = sqmParamsByQueryParam;
Expand All @@ -148,7 +148,9 @@ public boolean hasParameters() {
}

/**
* Get all of the QueryParameters mapped by this xref
* Get all the QueryParameters mapped by this xref.
* Note that order of parameters is important - parameters are
* included in cache keys for query results caching.
*/
public Map<QueryParameterImplementor<?>, List<SqmParameter<?>>> getQueryParameters() {
return sqmParamsByQueryParam;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.hibernate.query.sqm.tree.SqmCopyContext;
import org.hibernate.query.sqm.tree.SqmRenderContext;

import java.util.Objects;


/**
Expand Down Expand Up @@ -59,12 +60,13 @@ public int compareTo(SqmParameter<T> parameter) {
}

@Override
public boolean equals(Object object) {
return this == object;
public boolean equals(Object obj) {
return obj instanceof ValueBindJpaCriteriaParameter<?> that
&& Objects.equals( value, that.value );
}

@Override
public int hashCode() {
return super.hashCode();
return Objects.hashCode( value );
}
}
Loading
Loading