diff --git a/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.cs b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.cs index 1dca978091e..ec5640460a7 100644 --- a/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.cs +++ b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.cs @@ -7,5 +7,10 @@ public class ChangeDefaultTypeClass public int Id { get; set; } public DateTime NormalDateTimeValue { get; set; } = DateTime.Today; + + public string StringTypeLengthInType25 { get; set; } + public string StringTypeExplicitLength20 { get; set; } + public decimal CurrencyTypePrecisionInType5And2 { get; set; } + public decimal CurrencyTypeExplicitPrecision6And3 { get; set; } } } diff --git a/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.hbm.xml b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.hbm.xml index f2cfaf24dec..129615cd2db 100644 --- a/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.hbm.xml +++ b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeClass.hbm.xml @@ -12,5 +12,9 @@ + + + + diff --git a/src/NHibernate.Test/TypesTest/ChangeDefaultTypeWithLengthFixture.cs b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeWithLengthFixture.cs new file mode 100644 index 00000000000..b499568f1da --- /dev/null +++ b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeWithLengthFixture.cs @@ -0,0 +1,113 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine; +using NHibernate.Impl; +using NHibernate.SqlTypes; +using NHibernate.Type; +using NUnit.Framework; + +namespace NHibernate.Test.TypesTest +{ + /// + /// TestFixtures for changing a default .Net type. + /// + [TestFixture] + public class ChangeDefaultTypeWithLengthFixture : TypeFixtureBase + { + public class CustomStringType : AbstractStringType + { + public CustomStringType() : base(new StringSqlType()) + { + } + + public CustomStringType(int length) : base(new StringSqlType(length)) + { + } + + public override string Name => "CustomStringType"; + } + + protected override string TypeName => "ChangeDefaultType"; + + private IType _originalDefaultStringType; + private IType _testDefaultStringType; + private static System.Type _replacedType = typeof(string); + + protected override void Configure(Configuration configuration) + { + _originalDefaultStringType = TypeFactory.GetDefaultTypeFor(_replacedType); + Assert.That(_originalDefaultStringType, Is.Not.Null); + _testDefaultStringType = new CustomStringType(); + + TypeFactory.RegisterType( + _replacedType, + _testDefaultStringType, + new[] {"string"}, + length => new CustomStringType(length)); + base.Configure(configuration); + } + + protected override void DropSchema() + { + base.DropSchema(); + TypeFactory.ClearCustomRegistrations(); + Assert.That(TypeFactory.GetDefaultTypeFor(_replacedType), Is.Not.EqualTo(_testDefaultStringType)); + } + + [Test] + public void DefaultType() + { + Assert.That(TypeFactory.GetDefaultTypeFor(_replacedType), Is.EqualTo(_testDefaultStringType)); + } + + [Test] + public void PropertyType() + { + var propertyType25 = Sfi.GetClassMetadata(typeof(ChangeDefaultTypeClass)) + .GetPropertyType(nameof(ChangeDefaultTypeClass.StringTypeLengthInType25)); + Assert.That( + propertyType25, + Is.EqualTo(_testDefaultStringType)); + Assert.That(propertyType25.SqlTypes(Sfi)[0].Length, Is.EqualTo(25)); + + var propertyType20 = Sfi.GetClassMetadata(typeof(ChangeDefaultTypeClass)) + .GetPropertyType(nameof(ChangeDefaultTypeClass.StringTypeExplicitLength20)); + + Assert.That( + propertyType20, + Is.EqualTo(_testDefaultStringType)); + Assert.That(propertyType20.SqlTypes(Sfi)[0].Length, Is.EqualTo(20)); + } + + [Test] + public void GuessType() + { + Assert.That(NHibernateUtil.GuessType(_replacedType), Is.EqualTo(_testDefaultStringType)); + } + + [Test] + public void ParameterType() + { + var namedParametersField = typeof(AbstractQueryImpl) + .GetField("namedParameters", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.That(namedParametersField, Is.Not.Null, "Missing internal field"); + + using (var s = OpenSession()) + { + // Query where the parameter type cannot be deducted from compared entity property + var q = s.CreateQuery($"from {nameof(ChangeDefaultTypeClass)} where :str1 = :str2 or :str1 = :str3") + .SetParameter("str1", "aaa") + .SetString("str2", "bbb") + .SetAnsiString("str3", "bbb"); + + var namedParameters = namedParametersField.GetValue(q) as Dictionary; + Assert.That(namedParameters, Is.Not.Null, "Unable to retrieve parameters internal field"); + Assert.That(namedParameters["str1"].Type, Is.EqualTo(_testDefaultStringType)); + Assert.That(namedParameters["str2"].Type, Is.EqualTo(NHibernateUtil.String)); + Assert.That(namedParameters["str3"].Type, Is.EqualTo(NHibernateUtil.AnsiString)); + } + } + } +} diff --git a/src/NHibernate.Test/TypesTest/ChangeDefaultTypeWithPrecisionFixture.cs b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeWithPrecisionFixture.cs new file mode 100644 index 00000000000..456a81b8f64 --- /dev/null +++ b/src/NHibernate.Test/TypesTest/ChangeDefaultTypeWithPrecisionFixture.cs @@ -0,0 +1,114 @@ +using System.Collections.Generic; +using System.Data; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine; +using NHibernate.Impl; +using NHibernate.SqlTypes; +using NHibernate.Type; +using NUnit.Framework; + +namespace NHibernate.Test.TypesTest +{ + /// + /// TestFixtures for changing a default .Net type. + /// + [TestFixture] + public class ChangeDefaultTypeWithPrecisionFixture : TypeFixtureBase + { + public class CustomCurrencyType : DecimalType + { + public CustomCurrencyType() : base(SqlTypeFactory.Currency) + { + } + + public CustomCurrencyType(byte precision, byte scale) : base(new SqlType(DbType.Currency, precision, scale)) + { + } + + public override string Name => "CustomCurrencyType"; + } + + protected override string TypeName => "ChangeDefaultType"; + + private IType _originalDefaultType; + private IType _testDefaultType; + private static System.Type _replacedType = typeof(decimal); + + protected override void Configure(Configuration configuration) + { + _originalDefaultType = TypeFactory.GetDefaultTypeFor(_replacedType); + _testDefaultType = new CustomCurrencyType(); + Assert.That(_originalDefaultType, Is.Not.Null); + Assert.That(_originalDefaultType, Is.Not.EqualTo(_testDefaultType)); + + TypeFactory.RegisterType( + _replacedType, + _testDefaultType, + new[] {"currency"}, + (precision, scale) => new CustomCurrencyType(precision, scale)); + base.Configure(configuration); + } + + protected override void DropSchema() + { + base.DropSchema(); + TypeFactory.ClearCustomRegistrations(); + Assert.That(TypeFactory.GetDefaultTypeFor(_replacedType), Is.Not.EqualTo(_testDefaultType)); + } + + [Test] + public void DefaultType() + { + Assert.That(TypeFactory.GetDefaultTypeFor(_replacedType), Is.EqualTo(_testDefaultType)); + } + + [Test] + public void PropertyType() + { + var propertyType1 = Sfi.GetClassMetadata(typeof(ChangeDefaultTypeClass)) + .GetPropertyType(nameof(ChangeDefaultTypeClass.CurrencyTypeExplicitPrecision6And3)); + Assert.That( + propertyType1, + Is.EqualTo(_testDefaultType)); + Assert.That(propertyType1.SqlTypes(Sfi)[0].Precision, Is.EqualTo(6)); + Assert.That(propertyType1.SqlTypes(Sfi)[0].Scale, Is.EqualTo(3)); + + var propertyType2 = Sfi.GetClassMetadata(typeof(ChangeDefaultTypeClass)) + .GetPropertyType(nameof(ChangeDefaultTypeClass.CurrencyTypePrecisionInType5And2)); + + Assert.That( + propertyType2, + Is.EqualTo(_testDefaultType)); + Assert.That(propertyType2.SqlTypes(Sfi)[0].Precision, Is.EqualTo(5)); + Assert.That(propertyType2.SqlTypes(Sfi)[0].Scale, Is.EqualTo(2)); + } + + [Test] + public void GuessType() + { + Assert.That(NHibernateUtil.GuessType(_replacedType), Is.EqualTo(_testDefaultType)); + } + + [Test] + public void ParameterType() + { + var namedParametersField = typeof(AbstractQueryImpl) + .GetField("namedParameters", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.That(namedParametersField, Is.Not.Null, "Missing internal field"); + + using (var s = OpenSession()) + { + // Query where the parameter type cannot be deducted from compared entity property + var q = s.CreateQuery($"from {nameof(ChangeDefaultTypeClass)} where :str1 = :str2") + .SetParameter("str1", 1.22m) + .SetDecimal("str2", 1m); + + var namedParameters = namedParametersField.GetValue(q) as Dictionary; + Assert.That(namedParameters, Is.Not.Null, "Unable to retrieve parameters internal field"); + Assert.That(namedParameters["str1"].Type, Is.EqualTo(_testDefaultType)); + Assert.That(namedParameters["str2"].Type, Is.EqualTo(NHibernateUtil.Decimal)); + } + } + } +} diff --git a/src/NHibernate/Type/TypeFactory.cs b/src/NHibernate/Type/TypeFactory.cs index 3ce5f04001f..c75cc18f45f 100644 --- a/src/NHibernate/Type/TypeFactory.cs +++ b/src/NHibernate/Type/TypeFactory.cs @@ -91,9 +91,9 @@ private enum TypeClassification private static readonly ConcurrentDictionary getTypeDelegatesWithPrecision = new ConcurrentDictionary(); - private delegate NullableType GetNullableTypeWithLengthOrScale(int lengthOrScale); // Func + public delegate NullableType GetNullableTypeWithLengthOrScale(int lengthOrScale); // Func - private delegate NullableType GetNullableTypeWithPrecision(byte precision, byte scale); + public delegate NullableType GetNullableTypeWithPrecision(byte precision, byte scale); private delegate NullableType NullableTypeCreatorDelegate(SqlType sqlType); @@ -113,22 +113,66 @@ public static void RegisterType(System.Type systemType, IType nhibernateType, IE RegisterType(nhibernateType, typeAliases); } - private static void RegisterType(System.Type systemType, IType nhibernateType, - IEnumerable aliases, GetNullableTypeWithLengthOrScale ctorLengthOrScale) + /// + /// Defines which NHibernate type should be chosen by default for handling a given .Net type. + /// This must be done before any operation on NHibernate, including building its + /// and building session factory. Otherwise the behavior will be undefined. + /// + /// The .Net type. + /// The NHibernate type. + /// The additional aliases to map to the type. Use if none. + /// The factory method to create the NHibernate type using length or scale. + public static void RegisterType( + System.Type systemType, + IType nhibernateType, + IEnumerable aliases, + GetNullableTypeWithLengthOrScale ctorLengthOrScale) + { + RegisterType(systemType, nhibernateType, aliases, ctorLengthOrScale, true); + } + + private static void RegisterType( + System.Type systemType, + IType nhibernateType, + IEnumerable aliases, + GetNullableTypeWithLengthOrScale ctorLengthOrScale, + bool @override) { var typeAliases = new List(aliases); typeAliases.AddRange(GetClrTypeAliases(systemType)); - RegisterType(nhibernateType, typeAliases, ctorLengthOrScale); + RegisterType(nhibernateType, typeAliases, ctorLengthOrScale, @override); } - private static void RegisterType(System.Type systemType, IType nhibernateType, - IEnumerable aliases, GetNullableTypeWithPrecision ctorPrecision) + /// + /// Defines which NHibernate type should be chosen by default for handling a given .Net type. + /// This must be done before any operation on NHibernate, including building its + /// and building session factory. Otherwise the behavior will be undefined. + /// + /// The .Net type. + /// The NHibernate type. + /// The additional aliases to map to the type. Use if none. + /// The factory method to create the NHibernate type using precision. + public static void RegisterType( + System.Type systemType, + IType nhibernateType, + IEnumerable aliases, + GetNullableTypeWithPrecision ctorPrecision) + { + RegisterType(systemType, nhibernateType, aliases, ctorPrecision, true); + } + + private static void RegisterType( + System.Type systemType, + IType nhibernateType, + IEnumerable aliases, + GetNullableTypeWithPrecision ctorPrecision, + bool @override) { var typeAliases = new List(aliases); typeAliases.AddRange(GetClrTypeAliases(systemType)); - RegisterType(nhibernateType, typeAliases, ctorPrecision); + RegisterType(nhibernateType, typeAliases, ctorPrecision, @override); } private static IEnumerable GetClrTypeAliases(System.Type systemType) @@ -158,28 +202,38 @@ private static void RegisterType(IType nhibernateType, IEnumerable alias } } - private static void RegisterType(IType nhibernateType, IEnumerable aliases, GetNullableTypeWithLengthOrScale ctorLengthOrScale) + private static void RegisterType(IType nhibernateType, IEnumerable aliases, GetNullableTypeWithLengthOrScale ctorLengthOrScale, bool @override = false) { var typeAliases = new List(aliases) { nhibernateType.Name }; foreach (var alias in typeAliases) { RegisterTypeAlias(nhibernateType, alias); - if (!_getTypeDelegatesWithLengthOrScale.TryAdd(alias, ctorLengthOrScale)) + if (@override) { - throw new HibernateException("An item with the same key has already been added to getTypeDelegatesWithLength."); + _getTypeDelegatesWithLengthOrScale[alias] = ctorLengthOrScale; + } + else if (!_getTypeDelegatesWithLengthOrScale.TryAdd(alias, ctorLengthOrScale)) + { + throw new HibernateException( + "An item with the same key has already been added to getTypeDelegatesWithLength."); } } } - private static void RegisterType(IType nhibernateType, IEnumerable aliases, GetNullableTypeWithPrecision ctorPrecision) + private static void RegisterType(IType nhibernateType, IEnumerable aliases, GetNullableTypeWithPrecision ctorPrecision, bool @override = false) { var typeAliases = new List(aliases) { nhibernateType.Name }; foreach (var alias in typeAliases) { RegisterTypeAlias(nhibernateType, alias); - if (!getTypeDelegatesWithPrecision.TryAdd(alias, ctorPrecision)) + if (@override) + { + getTypeDelegatesWithPrecision[alias] = ctorPrecision; + } + else if (!getTypeDelegatesWithPrecision.TryAdd(alias, ctorPrecision)) { - throw new HibernateException("An item with the same key has already been added to getTypeDelegatesWithPrecision."); + throw new HibernateException( + "An item with the same key has already been added to getTypeDelegatesWithPrecision."); } } } @@ -204,13 +258,31 @@ private static void RegisterTypeAlias(IType nhibernateType, string alias) /// static TypeFactory() { - // set up the mappings of .NET Classes/Structs to their NHibernate types. + RegisterTypes(); + } + + private static void RegisterTypes() + { + // set up the mappings of .NET Classes/Structs to their NHibernate types. RegisterDefaultNetTypes(); // add the mappings of the NHibernate specific names that are used in type="" RegisterBuiltInTypes(); } + /// + /// Clears all custom type registrations and re-register all default NHibernate types + /// + public static void ClearCustomRegistrations() + { + typeByTypeOfName.Clear(); + _obsoleteMessageByAlias.Clear(); + _getTypeDelegatesWithLengthOrScale.Clear(); + getTypeDelegatesWithPrecision.Clear(); + + RegisterTypes(); + } + /// /// Register other Default .NET type /// @@ -221,39 +293,47 @@ private static void RegisterDefaultNetTypes() { // NOTE: each .NET type should appear only one time RegisterType(typeof (Byte[]), NHibernateUtil.Binary, new[] {"binary"}, - l => GetType(NHibernateUtil.Binary, l, len => new BinaryType(SqlTypeFactory.GetBinary(len)))); + l => GetType(NHibernateUtil.Binary, l, len => new BinaryType(SqlTypeFactory.GetBinary(len))), + false); - RegisterType(typeof(Boolean), NHibernateUtil.Boolean, new[] { "boolean", "bool" }); + RegisterType(typeof (Boolean), NHibernateUtil.Boolean, new[] { "boolean", "bool" }); RegisterType(typeof (Byte), NHibernateUtil.Byte, new[]{ "byte"}); RegisterType(typeof (Char), NHibernateUtil.Character, new[] {"character", "char"}); RegisterType(typeof (CultureInfo), NHibernateUtil.CultureInfo, new[]{ "locale"}); - RegisterType(typeof(DateTime), NHibernateUtil.DateTime, new[] { "datetime" }, - s => GetType(NHibernateUtil.DateTime, s, scale => new DateTimeType(SqlTypeFactory.GetDateTime((byte)scale)))); + RegisterType(typeof (DateTime), NHibernateUtil.DateTime, new[] { "datetime" }, + s => GetType(NHibernateUtil.DateTime, s, scale => new DateTimeType(SqlTypeFactory.GetDateTime((byte)scale))), + false); RegisterType(typeof (DateTimeOffset), NHibernateUtil.DateTimeOffset, new[]{ "datetimeoffset"}, - s => GetType(NHibernateUtil.DateTimeOffset, s, scale => new DateTimeOffsetType(SqlTypeFactory.GetDateTimeOffset((byte)scale)))); + s => GetType(NHibernateUtil.DateTimeOffset, s, scale => new DateTimeOffsetType(SqlTypeFactory.GetDateTimeOffset((byte)scale))), + false); RegisterType(typeof (Decimal), NHibernateUtil.Decimal, new[] {"big_decimal", "decimal"}, - (p, s) => GetType(NHibernateUtil.Decimal, p, s, st => new DecimalType(st))); + (p, s) => GetType(NHibernateUtil.Decimal, p, s, st => new DecimalType(st)), + false); RegisterType(typeof (Double), NHibernateUtil.Double, new[] {"double"}, - (p, s) => GetType(NHibernateUtil.Double, p, s, st => new DoubleType(st))); + (p, s) => GetType(NHibernateUtil.Double, p, s, st => new DoubleType(st)), + false); RegisterType(typeof (Guid), NHibernateUtil.Guid, new[]{ "guid"}); RegisterType(typeof (Int16), NHibernateUtil.Int16, new[]{ "short"}); RegisterType(typeof (Int32), NHibernateUtil.Int32, new[] {"integer", "int"}); RegisterType(typeof (Int64), NHibernateUtil.Int64, new[]{ "long"}); - RegisterType(typeof(SByte), NHibernateUtil.SByte, EmptyAliases); + RegisterType(typeof (SByte), NHibernateUtil.SByte, EmptyAliases); RegisterType(typeof (Single), NHibernateUtil.Single, new[] {"float", "single"}, - (p, s) => GetType(NHibernateUtil.Single, p, s, st => new SingleType(st))); + (p, s) => GetType(NHibernateUtil.Single, p, s, st => new SingleType(st)), + false); RegisterType(typeof (String), NHibernateUtil.String, new[] {"string"}, - l => GetType(NHibernateUtil.String, l, len => new StringType(SqlTypeFactory.GetString(len)))); + l => GetType(NHibernateUtil.String, l, len => new StringType(SqlTypeFactory.GetString(len))), + false); RegisterType(typeof (TimeSpan), NHibernateUtil.TimeSpan, new[] {"timespan"}); RegisterType(typeof (System.Type), NHibernateUtil.Class, new[] {"class"}, - l => GetType(NHibernateUtil.Class, l, len => new TypeType(SqlTypeFactory.GetString(len)))); + l => GetType(NHibernateUtil.Class, l, len => new TypeType(SqlTypeFactory.GetString(len))), + false); RegisterType(typeof (UInt16), NHibernateUtil.UInt16, new[] {"ushort"}); RegisterType(typeof (UInt32), NHibernateUtil.UInt32, new[] {"uint"}); @@ -263,7 +343,7 @@ private static void RegisterDefaultNetTypes() RegisterType(typeof (Uri), NHibernateUtil.Uri, new[] {"uri", "url"}); - RegisterType(typeof(XDocument), NHibernateUtil.XDoc, new[] { "xdoc", "xdocument" }); + RegisterType(typeof (XDocument), NHibernateUtil.XDoc, new[] { "xdoc", "xdocument" }); // object needs to have both class and serializable setup before it can // be created. @@ -574,7 +654,7 @@ public static IType HeuristicType(string typeName, IDictionary p if (typeClassification == TypeClassification.LengthOrScale) { parsedTypeName = typeName.Split(LengthSplit); - if (!int.TryParse(parsedTypeName[1], out int parsedLength)) + if (!Int32.TryParse(parsedTypeName[1], out int parsedLength)) { throw new MappingException($"Could not parse length value '{parsedTypeName[1]}' as int for type '{typeName}'"); } @@ -992,4 +1072,4 @@ public static void InjectParameters(Object type, IDictionary par } } } -} \ No newline at end of file +}