Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.hibernate.SessionFactory;
import org.hibernate.SessionFactoryObserver;
import org.hibernate.boot.internal.ClassLoaderAccessImpl;
import org.hibernate.boot.registry.classloading.spi.ClassLoaderService;
import org.hibernate.engine.spi.SessionFactoryImplementor;
Expand All @@ -31,11 +32,9 @@

import jakarta.validation.ConstraintViolation;
import jakarta.validation.ConstraintViolationException;
import jakarta.validation.TraversableResolver;
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;

import static jakarta.validation.Validation.buildDefaultValidatorFactory;
import static org.hibernate.internal.util.NullnessUtil.castNonNull;
import static org.hibernate.internal.util.collections.CollectionHelper.setOfSize;

Expand All @@ -47,40 +46,34 @@
*/
//FIXME review exception model
public class BeanValidationEventListener
implements PreInsertEventListener, PreUpdateEventListener, PreDeleteEventListener, PreUpsertEventListener, PreCollectionUpdateEventListener {
implements PreInsertEventListener, PreUpdateEventListener, PreDeleteEventListener, PreUpsertEventListener, PreCollectionUpdateEventListener,
SessionFactoryObserver {

private static final CoreMessageLogger LOG = Logger.getMessageLogger(
MethodHandles.lookup(),
CoreMessageLogger.class,
BeanValidationEventListener.class.getName()
);

private ValidatorFactory factory;
private final ConcurrentHashMap<EntityPersister, Set<String>> associationsPerEntityPersister = new ConcurrentHashMap<>();
private HibernateTraversableResolver traversableResolver;
private Validator validator;
private GroupsPerOperation groupsPerOperation;
boolean initialized;

/**
* Constructor used in an environment where validator factory is injected (JPA2).
*
* @param factory The {@code ValidatorFactory} to use to create {@code Validator} instance(s)
* @param settings Configured properties
*/
public BeanValidationEventListener(
ValidatorFactory factory, Map<String,Object> settings, ClassLoaderService classLoaderService) {
init( factory, settings, classLoaderService );
}

public void initialize(Map<String,Object> settings, ClassLoaderService classLoaderService) {
if ( !initialized ) {
init( buildDefaultValidatorFactory(), settings, classLoaderService );
}
public BeanValidationEventListener(
ValidatorFactory factory, Map<String, Object> settings, ClassLoaderService classLoaderService) {
traversableResolver = new HibernateTraversableResolver();
validator = factory.usingContext()
.traversableResolver( traversableResolver )
.getValidator();
groupsPerOperation = GroupsPerOperation.from( settings, new ClassLoaderAccessImpl( classLoaderService ) );
}

private void init(ValidatorFactory factory, Map<String,Object> settings, ClassLoaderService classLoaderService) {
this.factory = factory;
groupsPerOperation = GroupsPerOperation.from( settings, new ClassLoaderAccessImpl( classLoaderService ) );
initialized = true;
@Override
public void sessionFactoryCreated(SessionFactory factory) {
SessionFactoryImplementor implementor = factory.unwrap( SessionFactoryImplementor.class );
implementor
.getMappingMetamodel()
.forEachEntityDescriptor( entityPersister -> traversableResolver.addPersister( entityPersister, implementor ) );
}

public boolean onPreInsert(PreInsertEvent event) {
Expand Down Expand Up @@ -143,10 +136,6 @@ private <T> void validate(
if ( object == null || persister.getRepresentationStrategy().getMode() != RepresentationMode.POJO ) {
return;
}
TraversableResolver tr = new HibernateTraversableResolver( persister, associationsPerEntityPersister, sessionFactory );
Validator validator = factory.usingContext()
.traversableResolver( tr )
.getValidator();
final Class<?>[] groups = groupsPerOperation.get( operation );
if ( groups.length > 0 ) {
final Set<ConstraintViolation<T>> constraintViolations = validator.validate( object, groups );
Expand All @@ -167,7 +156,7 @@ private <T> void validate(
builder.append( toString( groups ) );
builder.append( "\nList of constraint violations:[\n" );
for ( ConstraintViolation<?> violation : constraintViolations ) {
builder.append( "\t" ).append( violation.toString() ).append("\n");
builder.append( "\t" ).append( violation.toString() ).append( "\n" );
}
builder.append( "]" );

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
package org.hibernate.boot.beanvalidation;

import java.lang.annotation.ElementType;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.hibernate.AssertionFailure;
import org.hibernate.Hibernate;
Expand All @@ -30,32 +31,27 @@
* @author Emmanuel Bernard
*/
public class HibernateTraversableResolver implements TraversableResolver {
private Set<String> associations;
private final Map<Class<?>, Set<String>> associationsPerEntityClass = new HashMap<>();

public HibernateTraversableResolver(
EntityPersister persister,
ConcurrentHashMap<EntityPersister, Set<String>> associationsPerEntityPersister,
SessionFactoryImplementor factory) {
associations = associationsPerEntityPersister.get( persister );
if ( associations == null ) {
associations = new HashSet<>();
addAssociationsToTheSetForAllProperties( persister.getPropertyNames(), persister.getPropertyTypes(), "", factory );
associationsPerEntityPersister.put( persister, associations );
}
public void addPersister(EntityPersister persister, SessionFactoryImplementor factory) {
Class<?> javaTypeClass = persister.getEntityMappingType().getMappedJavaType().getJavaTypeClass();
Set<String> associations = new HashSet<>();
addAssociationsToTheSetForAllProperties( persister.getPropertyNames(), persister.getPropertyTypes(), "", factory, associations );
associationsPerEntityClass.put( javaTypeClass, associations );
}

private void addAssociationsToTheSetForAllProperties(
String[] names, Type[] types, String prefix, SessionFactoryImplementor factory) {
private static void addAssociationsToTheSetForAllProperties(
String[] names, Type[] types, String prefix, SessionFactoryImplementor factory, Set<String> associations) {
final int length = names.length;
for( int index = 0 ; index < length; index++ ) {
addAssociationsToTheSetForOneProperty( names[index], types[index], prefix, factory );
addAssociationsToTheSetForOneProperty( names[index], types[index], prefix, factory, associations );
}
}

private void addAssociationsToTheSetForOneProperty(
String name, Type type, String prefix, SessionFactoryImplementor factory) {
private static void addAssociationsToTheSetForOneProperty(
String name, Type type, String prefix, SessionFactoryImplementor factory, Set<String> associations) {
if ( type instanceof CollectionType collectionType ) {
addAssociationsToTheSetForOneProperty( name, collectionType.getElementType( factory ), prefix, factory );
addAssociationsToTheSetForOneProperty( name, collectionType.getElementType( factory ), prefix, factory, associations );
}
//ToOne association
else if ( type instanceof EntityType || type instanceof AnyType ) {
Expand All @@ -66,7 +62,8 @@ else if ( type instanceof ComponentType componentType ) {
componentType.getPropertyNames(),
componentType.getSubtypes(),
( prefix.isEmpty() ? name : prefix + name ) + '.',
factory
factory,
associations
);
}
}
Expand Down Expand Up @@ -102,6 +99,6 @@ public boolean isCascadable(Object traversableObject,
Class<?> rootBeanType,
Path pathToTraversableObject,
ElementType elementType) {
return !associations.contains( getStringBasedPath( traversableProperty, pathToTraversableObject ) );
return !associationsPerEntityClass.getOrDefault( rootBeanType, Set.of() ).contains( getStringBasedPath( traversableProperty, pathToTraversableObject ) );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.hibernate.engine.config.spi.ConfigurationService;
import org.hibernate.engine.config.spi.StandardConverters;
import org.hibernate.engine.jdbc.spi.JdbcServices;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.event.service.spi.EventListenerRegistry;
import org.hibernate.event.spi.EventType;
import org.hibernate.internal.CoreMessageLogger;
Expand Down Expand Up @@ -115,7 +116,7 @@ else if ( validationModes.contains( ValidationMode.DDL ) ) {
public static void applyCallbackListeners(ValidatorFactory validatorFactory, ActivationContext context) {
if ( isValidationEnabled( context ) ) {
disableNullabilityChecking( context );
setupListener( validatorFactory, context.getServiceRegistry() );
setupListener( validatorFactory, context.getServiceRegistry(), context.getSessionFactory() );
}
}

Expand All @@ -141,7 +142,7 @@ private static boolean isCheckNullabilityExplicit(ActivationContext context) {
.getSettings().get( CHECK_NULLABILITY ) == null;
}

private static void setupListener(ValidatorFactory validatorFactory, SessionFactoryServiceRegistry serviceRegistry) {
private static void setupListener(ValidatorFactory validatorFactory, SessionFactoryServiceRegistry serviceRegistry, SessionFactoryImplementor sessionFactory) {
final ClassLoaderService classLoaderService = serviceRegistry.requireService( ClassLoaderService.class );
final ConfigurationService cfgService = serviceRegistry.requireService( ConfigurationService.class );
final BeanValidationEventListener listener =
Expand All @@ -153,7 +154,7 @@ private static void setupListener(ValidatorFactory validatorFactory, SessionFact
listenerRegistry.appendListeners( EventType.PRE_DELETE, listener );
listenerRegistry.appendListeners( EventType.PRE_UPSERT, listener );
listenerRegistry.appendListeners( EventType.PRE_COLLECTION_UPDATE, listener );
listener.initialize( cfgService.getSettings(), classLoaderService );
sessionFactory.addObserver( listener );
}

private static boolean isConstraintBasedValidationEnabled(ActivationContext context) {
Expand Down