diff --git a/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/BeanValidationEventListener.java b/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/BeanValidationEventListener.java index 2e20bbae47eb..2cab523720b2 100644 --- a/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/BeanValidationEventListener.java +++ b/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/BeanValidationEventListener.java @@ -8,8 +8,9 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; +import org.hibernate.SessionFactory; +import org.hibernate.SessionFactoryObserver; import org.hibernate.boot.internal.ClassLoaderAccessImpl; import org.hibernate.boot.registry.classloading.spi.ClassLoaderService; import org.hibernate.engine.spi.SessionFactoryImplementor; @@ -31,11 +32,9 @@ import jakarta.validation.ConstraintViolation; import jakarta.validation.ConstraintViolationException; -import jakarta.validation.TraversableResolver; import jakarta.validation.Validator; import jakarta.validation.ValidatorFactory; -import static jakarta.validation.Validation.buildDefaultValidatorFactory; import static org.hibernate.internal.util.NullnessUtil.castNonNull; import static org.hibernate.internal.util.collections.CollectionHelper.setOfSize; @@ -47,7 +46,8 @@ */ //FIXME review exception model public class BeanValidationEventListener - implements PreInsertEventListener, PreUpdateEventListener, PreDeleteEventListener, PreUpsertEventListener, PreCollectionUpdateEventListener { + implements PreInsertEventListener, PreUpdateEventListener, PreDeleteEventListener, PreUpsertEventListener, PreCollectionUpdateEventListener, + SessionFactoryObserver { private static final CoreMessageLogger LOG = Logger.getMessageLogger( MethodHandles.lookup(), @@ -55,32 +55,25 @@ public class BeanValidationEventListener BeanValidationEventListener.class.getName() ); - private ValidatorFactory factory; - private final ConcurrentHashMap> associationsPerEntityPersister = new ConcurrentHashMap<>(); + private HibernateTraversableResolver traversableResolver; + private Validator validator; private GroupsPerOperation groupsPerOperation; - boolean initialized; - - /** - * Constructor used in an environment where validator factory is injected (JPA2). - * - * @param factory The {@code ValidatorFactory} to use to create {@code Validator} instance(s) - * @param settings Configured properties - */ - public BeanValidationEventListener( - ValidatorFactory factory, Map settings, ClassLoaderService classLoaderService) { - init( factory, settings, classLoaderService ); - } - public void initialize(Map settings, ClassLoaderService classLoaderService) { - if ( !initialized ) { - init( buildDefaultValidatorFactory(), settings, classLoaderService ); - } + public BeanValidationEventListener( + ValidatorFactory factory, Map settings, ClassLoaderService classLoaderService) { + traversableResolver = new HibernateTraversableResolver(); + validator = factory.usingContext() + .traversableResolver( traversableResolver ) + .getValidator(); + groupsPerOperation = GroupsPerOperation.from( settings, new ClassLoaderAccessImpl( classLoaderService ) ); } - private void init(ValidatorFactory factory, Map settings, ClassLoaderService classLoaderService) { - this.factory = factory; - groupsPerOperation = GroupsPerOperation.from( settings, new ClassLoaderAccessImpl( classLoaderService ) ); - initialized = true; + @Override + public void sessionFactoryCreated(SessionFactory factory) { + SessionFactoryImplementor implementor = factory.unwrap( SessionFactoryImplementor.class ); + implementor + .getMappingMetamodel() + .forEachEntityDescriptor( entityPersister -> traversableResolver.addPersister( entityPersister, implementor ) ); } public boolean onPreInsert(PreInsertEvent event) { @@ -143,10 +136,6 @@ private void validate( if ( object == null || persister.getRepresentationStrategy().getMode() != RepresentationMode.POJO ) { return; } - TraversableResolver tr = new HibernateTraversableResolver( persister, associationsPerEntityPersister, sessionFactory ); - Validator validator = factory.usingContext() - .traversableResolver( tr ) - .getValidator(); final Class[] groups = groupsPerOperation.get( operation ); if ( groups.length > 0 ) { final Set> constraintViolations = validator.validate( object, groups ); @@ -167,7 +156,7 @@ private void validate( builder.append( toString( groups ) ); builder.append( "\nList of constraint violations:[\n" ); for ( ConstraintViolation violation : constraintViolations ) { - builder.append( "\t" ).append( violation.toString() ).append("\n"); + builder.append( "\t" ).append( violation.toString() ).append( "\n" ); } builder.append( "]" ); diff --git a/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/HibernateTraversableResolver.java b/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/HibernateTraversableResolver.java index 33dab6e92338..333819b5d056 100644 --- a/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/HibernateTraversableResolver.java +++ b/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/HibernateTraversableResolver.java @@ -5,9 +5,10 @@ package org.hibernate.boot.beanvalidation; import java.lang.annotation.ElementType; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import org.hibernate.AssertionFailure; import org.hibernate.Hibernate; @@ -30,32 +31,27 @@ * @author Emmanuel Bernard */ public class HibernateTraversableResolver implements TraversableResolver { - private Set associations; + private final Map, Set> associationsPerEntityClass = new HashMap<>(); - public HibernateTraversableResolver( - EntityPersister persister, - ConcurrentHashMap> associationsPerEntityPersister, - SessionFactoryImplementor factory) { - associations = associationsPerEntityPersister.get( persister ); - if ( associations == null ) { - associations = new HashSet<>(); - addAssociationsToTheSetForAllProperties( persister.getPropertyNames(), persister.getPropertyTypes(), "", factory ); - associationsPerEntityPersister.put( persister, associations ); - } + public void addPersister(EntityPersister persister, SessionFactoryImplementor factory) { + Class javaTypeClass = persister.getEntityMappingType().getMappedJavaType().getJavaTypeClass(); + Set associations = new HashSet<>(); + addAssociationsToTheSetForAllProperties( persister.getPropertyNames(), persister.getPropertyTypes(), "", factory, associations ); + associationsPerEntityClass.put( javaTypeClass, associations ); } - private void addAssociationsToTheSetForAllProperties( - String[] names, Type[] types, String prefix, SessionFactoryImplementor factory) { + private static void addAssociationsToTheSetForAllProperties( + String[] names, Type[] types, String prefix, SessionFactoryImplementor factory, Set associations) { final int length = names.length; for( int index = 0 ; index < length; index++ ) { - addAssociationsToTheSetForOneProperty( names[index], types[index], prefix, factory ); + addAssociationsToTheSetForOneProperty( names[index], types[index], prefix, factory, associations ); } } - private void addAssociationsToTheSetForOneProperty( - String name, Type type, String prefix, SessionFactoryImplementor factory) { + private static void addAssociationsToTheSetForOneProperty( + String name, Type type, String prefix, SessionFactoryImplementor factory, Set associations) { if ( type instanceof CollectionType collectionType ) { - addAssociationsToTheSetForOneProperty( name, collectionType.getElementType( factory ), prefix, factory ); + addAssociationsToTheSetForOneProperty( name, collectionType.getElementType( factory ), prefix, factory, associations ); } //ToOne association else if ( type instanceof EntityType || type instanceof AnyType ) { @@ -66,7 +62,8 @@ else if ( type instanceof ComponentType componentType ) { componentType.getPropertyNames(), componentType.getSubtypes(), ( prefix.isEmpty() ? name : prefix + name ) + '.', - factory + factory, + associations ); } } @@ -102,6 +99,6 @@ public boolean isCascadable(Object traversableObject, Class rootBeanType, Path pathToTraversableObject, ElementType elementType) { - return !associations.contains( getStringBasedPath( traversableProperty, pathToTraversableObject ) ); + return !associationsPerEntityClass.getOrDefault( rootBeanType, Set.of() ).contains( getStringBasedPath( traversableProperty, pathToTraversableObject ) ); } } diff --git a/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/TypeSafeActivator.java b/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/TypeSafeActivator.java index e07bd7a32cb1..ac36b9fcffee 100644 --- a/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/TypeSafeActivator.java +++ b/hibernate-core/src/main/java/org/hibernate/boot/beanvalidation/TypeSafeActivator.java @@ -34,6 +34,7 @@ import org.hibernate.engine.config.spi.ConfigurationService; import org.hibernate.engine.config.spi.StandardConverters; import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.engine.spi.SessionFactoryImplementor; import org.hibernate.event.service.spi.EventListenerRegistry; import org.hibernate.event.spi.EventType; import org.hibernate.internal.CoreMessageLogger; @@ -115,7 +116,7 @@ else if ( validationModes.contains( ValidationMode.DDL ) ) { public static void applyCallbackListeners(ValidatorFactory validatorFactory, ActivationContext context) { if ( isValidationEnabled( context ) ) { disableNullabilityChecking( context ); - setupListener( validatorFactory, context.getServiceRegistry() ); + setupListener( validatorFactory, context.getServiceRegistry(), context.getSessionFactory() ); } } @@ -141,7 +142,7 @@ private static boolean isCheckNullabilityExplicit(ActivationContext context) { .getSettings().get( CHECK_NULLABILITY ) == null; } - private static void setupListener(ValidatorFactory validatorFactory, SessionFactoryServiceRegistry serviceRegistry) { + private static void setupListener(ValidatorFactory validatorFactory, SessionFactoryServiceRegistry serviceRegistry, SessionFactoryImplementor sessionFactory) { final ClassLoaderService classLoaderService = serviceRegistry.requireService( ClassLoaderService.class ); final ConfigurationService cfgService = serviceRegistry.requireService( ConfigurationService.class ); final BeanValidationEventListener listener = @@ -153,7 +154,7 @@ private static void setupListener(ValidatorFactory validatorFactory, SessionFact listenerRegistry.appendListeners( EventType.PRE_DELETE, listener ); listenerRegistry.appendListeners( EventType.PRE_UPSERT, listener ); listenerRegistry.appendListeners( EventType.PRE_COLLECTION_UPDATE, listener ); - listener.initialize( cfgService.getSettings(), classLoaderService ); + sessionFactory.addObserver( listener ); } private static boolean isConstraintBasedValidationEnabled(ActivationContext context) {