Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jakarta.persistence.Table;
import org.hibernate.cfg.QuerySettings;
import org.hibernate.query.hql.HqlTranslator;
import org.hibernate.query.sqm.tree.SqmStatement;
import org.hibernate.testing.memory.MemoryUsageUtil;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.Jira;
import org.hibernate.testing.orm.junit.ServiceRegistry;
Expand Down Expand Up @@ -80,65 +80,15 @@ SELECT AVG(p3.price) FROM Product p3
@Test
public void testParserMemoryUsage(SessionFactoryScope scope) {
final HqlTranslator hqlTranslator = scope.getSessionFactory().getQueryEngine().getHqlTranslator();
final Runtime runtime = Runtime.getRuntime();

// Ensure classes and basic stuff is initialized in case this is the first test run
hqlTranslator.translate( "from AppUser", AppUser.class );
runtime.gc();
runtime.gc();

// Track memory usage before execution
long totalMemoryBefore = runtime.totalMemory();
long usedMemoryBefore = totalMemoryBefore - runtime.freeMemory();

System.out.println("Memory Usage Before Create Query:");
System.out.println("----------------------------");
System.out.println("Total Memory: " + (totalMemoryBefore / 1024) + " KB");
System.out.println("Used Memory : " + (usedMemoryBefore / 1024) + " KB");
System.out.println();

// Create query
SqmStatement<Long> statement = hqlTranslator.translate( HQL, Long.class );

// Track memory usage after execution
long totalMemoryAfter = runtime.totalMemory();
long usedMemoryAfter = totalMemoryAfter - runtime.freeMemory();

System.out.println("Memory Usage After Create Query:");
System.out.println("----------------------------");
System.out.println("Total Memory: " + (totalMemoryAfter / 1024) + " KB");
System.out.println("Used Memory : " + (usedMemoryAfter / 1024) + " KB");
System.out.println();

System.out.println("Memory increase After Parsing:");
System.out.println("----------------------------");
System.out.println("Total Memory increase: " + ((totalMemoryAfter - totalMemoryBefore) / 1024) + " KB");
System.out.println("Used Memory increase : " + ((usedMemoryAfter - usedMemoryBefore) / 1024) + " KB");
System.out.println();

runtime.gc();
runtime.gc();

// Track memory usage after execution
long totalMemoryAfterGc = runtime.totalMemory();
long usedMemoryAfterGc = totalMemoryAfterGc - runtime.freeMemory();

System.out.println("Memory Usage After Create Query and GC:");
System.out.println("----------------------------");
System.out.println("Total Memory: " + (totalMemoryAfterGc / 1024) + " KB");
System.out.println("Used Memory : " + (usedMemoryAfterGc / 1024) + " KB");
System.out.println();

System.out.println("Memory overhead of Parsing:");
System.out.println("----------------------------");
System.out.println("Total Memory increase: " + ((totalMemoryAfter - totalMemoryAfterGc) / 1024) + " KB");
System.out.println("Used Memory increase : " + ((usedMemoryAfter - usedMemoryAfterGc) / 1024) + " KB");
System.out.println();

// During testing, before the fix for HHH-19240, the allocation was around 500+ MB,
// and after the fix it dropped to 170 - 250 MB
final long memoryConsumption = usedMemoryAfter - usedMemoryAfterGc;
assertTrue( usedMemoryAfter - usedMemoryAfterGc < 256_000_000, "Parsing of queries consumes too much memory (" + ( memoryConsumption / 1024 ) + " KB), when at most 256 MB are expected" );
final long memoryUsage = MemoryUsageUtil.estimateMemoryUsage( () -> hqlTranslator.translate( HQL, Long.class ) );
System.out.println( "Memory Consumption: " + (memoryUsage / 1024) + " KB" );
assertTrue( memoryUsage < 256_000_000, "Parsing of queries consumes too much memory (" + ( memoryUsage / 1024 ) + " KB), when at most 256 MB are expected" );
}

@Entity(name = "Address")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.testing.memory;

import java.lang.management.ManagementFactory;
import java.lang.management.MemoryPoolMXBean;
import java.util.List;

final class GlobalMemoryUsageSnapshotter implements MemoryAllocationSnapshotter {

private static final GlobalMemoryUsageSnapshotter INSTANCE = new GlobalMemoryUsageSnapshotter(
ManagementFactory.getMemoryPoolMXBeans()
);

private final List<MemoryPoolMXBean> heapPoolBeans;
private final Runnable gcAndWait;

private GlobalMemoryUsageSnapshotter(List<MemoryPoolMXBean> heapPoolBeans) {
this.heapPoolBeans = heapPoolBeans;
this.gcAndWait = () -> {
for (int i = 0; i < 3; i++) {
System.gc();
try { Thread.sleep(50); } catch (InterruptedException ignored) {}
}
};
}

public static GlobalMemoryUsageSnapshotter getInstance() {
return INSTANCE;
}

@Override
public MemoryAllocationSnapshot snapshot() {
final long peakUsage = heapPoolBeans.stream().mapToLong(p -> p.getPeakUsage().getUsed()).sum();
gcAndWait.run();
final long retainedUsage = heapPoolBeans.stream().mapToLong(p -> p.getUsage().getUsed()).sum();
heapPoolBeans.forEach(MemoryPoolMXBean::resetPeakUsage);
return new GlobalMemoryAllocationSnapshot( peakUsage, retainedUsage );
}

record GlobalMemoryAllocationSnapshot(long peakUsage, long retainedUsage) implements MemoryAllocationSnapshot {

@Override
public long difference(MemoryAllocationSnapshot before) {
// When doing the "before" snapshot, the peak usage is reset.
// Since this object is the "after" snapshot, we can simply estimate the memory usage of an operation
// to be the peak usage of that operation minus the usage after GC
return peakUsage - retainedUsage;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.testing.memory;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.lang.reflect.Method;
import java.util.HashMap;

record HotspotPerThreadAllocationSnapshotter(ThreadMXBean threadMXBean) implements MemoryAllocationSnapshotter {

private static final @Nullable HotspotPerThreadAllocationSnapshotter INSTANCE;
private static final Method GET_THREAD_ALLOCATED_BYTES;

static {
ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
Method method = null;
try {
@SuppressWarnings("unchecked")
Class<? extends ThreadMXBean> hotspotInterface =
(Class<? extends ThreadMXBean>) Class.forName( "com.sun.management.ThreadMXBean" );
try {
method = hotspotInterface.getMethod( "getThreadAllocatedBytes", long[].class );
}
catch (Exception e) {
// Ignore
}

if ( !hotspotInterface.isInstance( threadMXBean ) ) {
threadMXBean = ManagementFactory.getPlatformMXBean( hotspotInterface );
}
}
catch (Throwable e) {
// Ignore
}

GET_THREAD_ALLOCATED_BYTES = method;

HotspotPerThreadAllocationSnapshotter instance = null;
if ( method != null && threadMXBean != null ) {
try {
instance = new HotspotPerThreadAllocationSnapshotter( threadMXBean );
instance.snapshot();
}
catch (Exception e) {
instance = null;
}
}
INSTANCE = instance;
}

public static @Nullable HotspotPerThreadAllocationSnapshotter getInstance() {
return INSTANCE;
}

@Override
public MemoryAllocationSnapshot snapshot() {
long[] threadIds = threadMXBean.getAllThreadIds();
try {
return new PerThreadMemoryAllocationSnapshot(
threadIds,
(long[]) GET_THREAD_ALLOCATED_BYTES.invoke( threadMXBean, (Object) threadIds )
);
}
catch (Exception e) {
throw new RuntimeException( e );
}
}

record PerThreadMemoryAllocationSnapshot(long[] threadIds, long[] threadAllocatedBytes)
implements MemoryAllocationSnapshot {

@Override
public long difference(MemoryAllocationSnapshot before) {
final PerThreadMemoryAllocationSnapshot other = (PerThreadMemoryAllocationSnapshot) before;
final HashMap<Long, Integer> previousThreadIdToIndexMap = new HashMap<>();
for ( int i = 0; i < other.threadIds.length; i++ ) {
previousThreadIdToIndexMap.put( other.threadIds[i], i );
}
long allocatedBytes = 0;
for ( int i = 0; i < threadIds.length; i++ ) {
allocatedBytes += threadAllocatedBytes[i];
final Integer previousThreadIndex = previousThreadIdToIndexMap.get( threadIds[i] );
if ( previousThreadIndex != null ) {
allocatedBytes -= other.threadAllocatedBytes[previousThreadIndex];
}
}
return allocatedBytes;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.testing.memory;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.lang.reflect.Method;

record HotspotTotalThreadBytesSnapshotter(ThreadMXBean threadMXBean) implements MemoryAllocationSnapshotter {

private static final @Nullable HotspotTotalThreadBytesSnapshotter INSTANCE;
private static final Method GET_TOTAL_THREAD_ALLOCATED_BYTES;

static {
ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
Method method = null;
try {
@SuppressWarnings("unchecked")
Class<? extends ThreadMXBean> hotspotInterface =
(Class<? extends ThreadMXBean>) Class.forName( "com.sun.management.ThreadMXBean" );
try {
method = hotspotInterface.getMethod( "getTotalThreadAllocatedBytes" );
}
catch (Exception e) {
// Ignore
}

if ( !hotspotInterface.isInstance( threadMXBean ) ) {
threadMXBean = ManagementFactory.getPlatformMXBean( hotspotInterface );
}
}
catch (Throwable e) {
// Ignore
}

GET_TOTAL_THREAD_ALLOCATED_BYTES = method;

HotspotTotalThreadBytesSnapshotter instance = null;
if ( method != null && threadMXBean != null ) {
try {
instance = new HotspotTotalThreadBytesSnapshotter( threadMXBean );
instance.snapshot();
}
catch (Exception e) {
instance = null;
}
}
INSTANCE = instance;
}

public static @Nullable HotspotTotalThreadBytesSnapshotter getInstance() {
return INSTANCE;
}

@Override
public MemoryAllocationSnapshot snapshot() {
try {
return new GlobalMemoryAllocationSnapshot( (long) GET_TOTAL_THREAD_ALLOCATED_BYTES.invoke( threadMXBean ) );
}
catch (Exception e) {
throw new RuntimeException( e );
}
}

record GlobalMemoryAllocationSnapshot(long allocatedBytes) implements MemoryAllocationSnapshot {

GlobalMemoryAllocationSnapshot {
if ( allocatedBytes == -1L ) {
throw new IllegalArgumentException( "getTotalThreadAllocatedBytes is disabled" );
}
}

@Override
public long difference(MemoryAllocationSnapshot before) {
final GlobalMemoryAllocationSnapshot other = (GlobalMemoryAllocationSnapshot) before;
return Math.max( allocatedBytes - other.allocatedBytes, 0L );
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.testing.memory;

interface MemoryAllocationSnapshot {
long difference(MemoryAllocationSnapshot before);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.testing.memory;

interface MemoryAllocationSnapshotter {
MemoryAllocationSnapshot snapshot();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.testing.memory;

public class MemoryUsageUtil {

private static final MemoryAllocationSnapshotter SNAPSHOTTER;

static {
MemoryAllocationSnapshotter snapshotter = HotspotTotalThreadBytesSnapshotter.getInstance();
if ( snapshotter == null ) {
snapshotter = HotspotPerThreadAllocationSnapshotter.getInstance();
}
if ( snapshotter == null ) {
snapshotter = GlobalMemoryUsageSnapshotter.getInstance();
}
SNAPSHOTTER = snapshotter;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be useful to log which snapshotter implementation is selected, in case one proves more prone to failures than others.

System.out.println("MemoryUsageUtil: Using " + snapshotter.getClass().getSimpleName());

}

public static long estimateMemoryUsage(Runnable runnable) {
final MemoryAllocationSnapshot beforeSnapshot = SNAPSHOTTER.snapshot();
runnable.run();
return SNAPSHOTTER.snapshot().difference( beforeSnapshot );
}
}