diff --git a/src/NHibernate.Test/NHSpecificTest/NH1452/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/NH1452/Fixture.cs new file mode 100644 index 00000000000..d82d0ceae14 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH1452/Fixture.cs @@ -0,0 +1,135 @@ +using NHibernate.Cfg; +using NHibernate.Criterion; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.NH1452 +{ + [TestFixture] + public class Fixture : BugTestCase + { + protected override void Configure(Configuration configuration) + { + base.Configure(configuration); + configuration.SetProperty(Environment.FormatSql, "false"); + } + + /// + /// push some data into the database + /// Really functions as a save test also + /// + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var tran = session.BeginTransaction()) + { + session.Save(new Product + { + ProductId = "XO1234", + Id = 1, + Name = "Some product", + Description = "Very good" + }); + + session.Save(new Product + { + ProductId = "XO54321", + Id = 2, + Name = "Other product", + Description = "Very bad" + }); + + tran.Commit(); + } + } + + protected override void OnTearDown() + { + base.OnTearDown(); + + using (var session = OpenSession()) + using (var tran = session.BeginTransaction()) + { + session.Delete("from Product"); + tran.Commit(); + } + } + + [Test] + public void Delete_single_record() + { + using (var session = OpenSession()) + { + var product = new Product + { + ProductId = "XO1111", + Id = 3, + Name = "Test", + Description = "Test" + }; + + session.Save(product); + + session.Flush(); + + session.Delete(product); + session.Flush(); + + session.Clear(); + + //try to query for this product + product = session.CreateCriteria(typeof (Product)) + .Add(Restrictions.Eq("ProductId", "XO1111")) + .UniqueResult(); + + Assert.That(product, Is.Null); + } + } + + [Test] + public void Query_records() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var product = session.CreateCriteria(typeof (Product)) + .Add(Restrictions.Eq("ProductId", "XO1234")) + .UniqueResult(); + + Assert.That(product, Is.Not.Null); + Assert.That(product.Description, Is.EqualTo("Very good")); + + var log = sqlLog.GetWholeLog(); + //needs to be joining on the Id column not the productId + Assert.That(log.Contains("inner join ProductLocalized this_1_ on this_.Id=this_1_.Id"), Is.True); + } + } + + [Test] + public void Update_record() + { + using (var session = OpenSession()) + { + var product = session.CreateCriteria(typeof (Product)) + .Add(Restrictions.Eq("ProductId", "XO1234")) + .UniqueResult(); + + Assert.That(product, Is.Not.Null); + + product.Name = "TestValue"; + product.Description = "TestValue"; + + session.Flush(); + session.Clear(); + + //pull again + product = session.CreateCriteria(typeof (Product)) + .Add(Restrictions.Eq("ProductId", "XO1234")) + .UniqueResult(); + + Assert.That(product, Is.Not.Null); + Assert.That(product.Name, Is.EqualTo("TestValue")); + Assert.That(product.Description, Is.EqualTo("TestValue")); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH1452/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/NH1452/Mappings.hbm.xml new file mode 100644 index 00000000000..5760a4d4b0f --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH1452/Mappings.hbm.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + diff --git a/src/NHibernate.Test/NHSpecificTest/NH1452/Product.cs b/src/NHibernate.Test/NHSpecificTest/NH1452/Product.cs new file mode 100644 index 00000000000..751efe1f4ea --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH1452/Product.cs @@ -0,0 +1,13 @@ +namespace NHibernate.Test.NHSpecificTest.NH1452 +{ + public class Product + { + public virtual string ProductId { get; set; } + + public virtual int Id { get; set; } + + public virtual string Name { get; set; } + + public virtual string Description { get; set; } + } +} diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 569f8ab024d..e0d4a3f97c9 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -847,6 +847,8 @@ + + @@ -3142,6 +3144,7 @@ + diff --git a/src/NHibernate/Cfg/XmlHbmBinding/ClassBinder.cs b/src/NHibernate/Cfg/XmlHbmBinding/ClassBinder.cs index c7ca07972bf..8e3237858b6 100644 --- a/src/NHibernate/Cfg/XmlHbmBinding/ClassBinder.cs +++ b/src/NHibernate/Cfg/XmlHbmBinding/ClassBinder.cs @@ -195,7 +195,33 @@ private void BindJoin(HbmJoin joinMapping, Join join, IDictionary {1}", persistentClass.EntityName, join.Table.Name); // KEY - SimpleValue key = new DependantValue(table, persistentClass.Identifier); + SimpleValue key; + if (!String.IsNullOrEmpty(joinMapping.key.propertyref)) + { + string propertyRef = joinMapping.key.propertyref; + var propertyRefKey = new SimpleValue(persistentClass.Table) + { + IsAlternateUniqueKey = true + }; + var property = persistentClass.GetProperty(propertyRef); + join.RefIdProperty = property; + //we only want one column + var column = (Column) property.ColumnIterator.First(); + if (!column.Unique) + throw new MappingException( + string.Format( + "Property {0}, on class {1} must be marked as unique to be joined to with a property-ref.", + property.Name, + persistentClass.ClassName)); + propertyRefKey.AddColumn(column); + propertyRefKey.TypeName = property.Type.Name; + key = new ReferenceDependantValue(table, propertyRefKey); + } + else + { + key = new DependantValue(table, persistentClass.Identifier); + } + key.ForeignKeyName = joinMapping.key.foreignkey; join.Key = key; key.IsCascadeDeleteEnabled = joinMapping.key.ondelete == HbmOndelete.Cascade; diff --git a/src/NHibernate/Mapping/Join.cs b/src/NHibernate/Mapping/Join.cs index 5ca984e98a2..84e452be4ba 100644 --- a/src/NHibernate/Mapping/Join.cs +++ b/src/NHibernate/Mapping/Join.cs @@ -37,6 +37,9 @@ public void AddProperty(Property prop) prop.PersistentClass = PersistentClass; } + //if we are joining to a non pk, this is the property of the class that serves as id + public Property RefIdProperty { get; set; } + public bool ContainsProperty(Property prop) { return properties.Contains(prop); diff --git a/src/NHibernate/Mapping/ReferenceDependantValue.cs b/src/NHibernate/Mapping/ReferenceDependantValue.cs new file mode 100644 index 00000000000..52100c513d1 --- /dev/null +++ b/src/NHibernate/Mapping/ReferenceDependantValue.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; + +namespace NHibernate.Mapping +{ + /// + /// + /// + [Serializable] + public class ReferenceDependantValue : DependantValue + { + private readonly SimpleValue _prototype; + + public ReferenceDependantValue(Table table, SimpleValue prototype) + : base(table, prototype) + { + _prototype = prototype; + } + + public IEnumerable ReferenceColumns + { + get { return _prototype.ConstraintColumns; } + } + + public override void CreateForeignKeyOfEntity(string entityName) + { + if (!HasFormula && !string.Equals("none", ForeignKeyName, StringComparison.InvariantCultureIgnoreCase)) + { + var referencedColumns = new List(_prototype.ColumnSpan); + foreach (Column column in _prototype.ColumnIterator) + { + referencedColumns.Add(column); + } + + ForeignKey fk = Table.CreateForeignKey(ForeignKeyName, ConstraintColumns, entityName, referencedColumns); + fk.CascadeDeleteEnabled = IsCascadeDeleteEnabled; + } + } + } +} diff --git a/src/NHibernate/Mapping/SimpleValue.cs b/src/NHibernate/Mapping/SimpleValue.cs index 46ab042e228..c9ad4cca562 100644 --- a/src/NHibernate/Mapping/SimpleValue.cs +++ b/src/NHibernate/Mapping/SimpleValue.cs @@ -73,7 +73,7 @@ public virtual bool IsComposite #region IKeyValue Members - public void CreateForeignKeyOfEntity(string entityName) + public virtual void CreateForeignKeyOfEntity(string entityName) { if (!HasFormula && ! "none".Equals(ForeignKeyName, StringComparison.InvariantCultureIgnoreCase)) { diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index 54fa348f4e2..aac4fdb4b00 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -526,6 +526,7 @@ + diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 62ed069d4b2..f847f5eedd5 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -1011,6 +1011,23 @@ protected virtual void AddDiscriminatorToSelect(SelectFragment select, string na public abstract string GetSubclassTableName(int j); + //gets the identifier for a join table if other than pk + protected virtual object GetJoinTableId(int j, object[] fields) + { + return null; + } + + protected virtual object GetJoinTableId(int table, object obj, EntityMode entityMode) + { + return null; + } + + //for joining to other keys than pk + protected virtual string[] GetJoinIdKeyColumns(int j) + { + return IdentifierColumnNames; + } + protected abstract string[] GetSubclassTableKeyColumns(int j); protected abstract bool IsClassOrSuperclassTable(int j); @@ -1027,6 +1044,25 @@ protected virtual void AddDiscriminatorToSelect(SelectFragment select, string na protected abstract bool IsPropertyOfTable(int property, int table); + protected virtual int? GetRefIdColumnOfTable(int table) + { + return null; + } + + protected virtual Tuple.Property GetIdentiferProperty(int table) + { + var refId = GetRefIdColumnOfTable(table); + if (refId == null) + return entityMetamodel.IdentifierProperty; + + return entityMetamodel.Properties[refId.Value]; + } + + protected virtual bool IsIdOfTable(int property, int table) + { + return false; + } + protected abstract int GetSubclassPropertyTableNumber(int i); public abstract string FilterFragment(string alias); @@ -2423,8 +2459,9 @@ protected int Dehydrate(object id, object[] fields, object rowId, bool[] include } else if (id != null) { - IdentifierType.NullSafeSet(statement, id, index, session); - index += IdentifierColumnSpan; + var property = GetIdentiferProperty(table); + property.Type.NullSafeSet(statement, id, index, session); + index += property.Type.GetColumnSpan(factory); } return index; @@ -2591,6 +2628,9 @@ public virtual SqlString GetSelectByUniqueKeyString(string propertyName) protected void Insert(object id, object[] fields, bool[] notNull, int j, SqlCommandInfo sql, object obj, ISessionImplementor session) { + //check if the id comes from an alternate column + object tableId = GetJoinTableId(j, fields) ?? id; + if (IsInverseTable(j)) { return; @@ -2605,7 +2645,7 @@ protected void Insert(object id, object[] fields, bool[] notNull, int j, if (log.IsDebugEnabled) { - log.Debug("Inserting entity: " + MessageHelper.InfoString(this, id, Factory)); + log.Debug("Inserting entity: " + MessageHelper.InfoString(this, tableId, Factory)); if (j == 0 && IsVersioned) { log.Debug("Version: " + Versioning.GetVersion(fields, this)); @@ -2634,7 +2674,7 @@ protected void Insert(object id, object[] fields, bool[] notNull, int j, // state at the time the insert was issued (cos of foreign key constraints). // Not necessarily the obect's current state - Dehydrate(id, fields, null, notNull, propertyColumnInsertable, j, insertCmd, session, index); + Dehydrate(tableId, fields, null, notNull, propertyColumnInsertable, j, insertCmd, session, index); if (useBatch) { @@ -2666,10 +2706,10 @@ protected void Insert(object id, object[] fields, bool[] notNull, int j, var exceptionContext = new AdoExceptionContextInfo { SqlException = sqle, - Message = "could not insert: " + MessageHelper.InfoString(this, id), + Message = "could not insert: " + MessageHelper.InfoString(this, tableId), Sql = sql.ToString(), EntityName = EntityName, - EntityId = id + EntityId = tableId }; throw ADOExceptionHelper.Convert(Factory.SQLExceptionConverter, exceptionContext); } @@ -2681,6 +2721,9 @@ protected internal virtual void UpdateOrInsert(object id, object[] fields, objec { if (!IsInverseTable(j)) { + //check if the id comes from an alternate column + object tableId = GetJoinTableId(j, fields) ?? id; + bool isRowToUpdate; if (IsNullableTable(j) && oldFields != null && IsAllNull(oldFields, j)) { @@ -2691,13 +2734,13 @@ protected internal virtual void UpdateOrInsert(object id, object[] fields, objec { //if all fields are null, we might need to delete existing row isRowToUpdate = true; - Delete(id, oldVersion, j, obj, SqlDeleteStrings[j], session, null); + Delete(tableId, oldVersion, j, obj, SqlDeleteStrings[j], session, null); } else { //there is probably a row there, so try to update //if no rows were updated, we will find out - isRowToUpdate = Update(id, fields, oldFields, rowId, includeProperty, j, oldVersion, obj, sql, session); + isRowToUpdate = Update(tableId, fields, oldFields, rowId, includeProperty, j, oldVersion, obj, sql, session); } if (!isRowToUpdate && !IsAllNull(fields, j)) @@ -2705,7 +2748,7 @@ protected internal virtual void UpdateOrInsert(object id, object[] fields, objec // assume that the row was not there since it previously had only null // values, so do an INSERT instead //TODO: does not respect dynamic-insert - Insert(id, fields, PropertyInsertability, j, SqlInsertStrings[j], obj, session); + Insert(tableId, fields, PropertyInsertability, j, SqlInsertStrings[j], obj, session); } } } @@ -2824,6 +2867,9 @@ protected bool Update(object id, object[] fields, object[] oldFields, object row public void Delete(object id, object version, int j, object obj, SqlCommandInfo sql, ISessionImplementor session, object[] loadedState) { + //check if the id should come from another column + object tableId = GetJoinTableId(j, obj, session.EntityMode) ?? id; + if (IsInverseTable(j)) { return; @@ -2838,7 +2884,7 @@ public void Delete(object id, object version, int j, object obj, SqlCommandInfo if (log.IsDebugEnabled) { - log.Debug("Deleting entity: " + MessageHelper.InfoString(this, id, Factory)); + log.Debug("Deleting entity: " + MessageHelper.InfoString(this, tableId, Factory)); if (useVersion) { log.Debug("Version: " + version); @@ -2873,8 +2919,9 @@ public void Delete(object id, object version, int j, object obj, SqlCommandInfo // Do the key. The key is immutable so we can use the _current_ object state - not necessarily // the state at the time the delete was issued - IdentifierType.NullSafeSet(statement, id, index, session); - index += IdentifierColumnSpan; + var property = GetIdentiferProperty(j); + property.Type.NullSafeSet(statement, tableId, index, session); + index += property.Type.GetColumnSpan(factory); // We should use the _current_ object state (ie. after any updates that occurred during flush) if (useVersion) @@ -2905,7 +2952,7 @@ public void Delete(object id, object version, int j, object obj, SqlCommandInfo } else { - Check(session.Batcher.ExecuteNonQuery(statement), id, j, expectation, statement); + Check(session.Batcher.ExecuteNonQuery(statement), tableId, j, expectation, statement); } } catch (Exception e) @@ -2929,10 +2976,10 @@ public void Delete(object id, object version, int j, object obj, SqlCommandInfo var exceptionContext = new AdoExceptionContextInfo { SqlException = sqle, - Message = "could not delete: " + MessageHelper.InfoString(this, id, Factory), + Message = "could not delete: " + MessageHelper.InfoString(this, tableId, Factory), Sql = sql.Text.ToString(), EntityName = EntityName, - EntityId = id + EntityId = tableId }; throw ADOExceptionHelper.Convert(Factory.SQLExceptionConverter, exceptionContext); } @@ -3296,11 +3343,12 @@ protected internal virtual bool IsSubclassTableLazy(int j) private JoinFragment CreateJoin(string name, bool innerjoin, bool includeSubclasses) { - string[] idCols = StringHelper.Qualify(name, IdentifierColumnNames); //all joins join to the pk of the driving table JoinFragment join = Factory.Dialect.CreateOuterJoinFragment(); int tableSpan = SubclassTableSpan; for (int j = 1; j < tableSpan; j++) //notice that we skip the first table; it is the driving table! { + string[] idCols = StringHelper.Qualify(name, GetJoinIdKeyColumns(j)); //some joins may be to non primary keys + bool joinIsIncluded = IsClassOrSuperclassTable(j) || (includeSubclasses && !IsSubclassTableSequentialSelect(j) && !IsSubclassTableLazy(j)); if (joinIsIncluded) diff --git a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs index 32d3a86b77a..5db5744edb0 100644 --- a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs @@ -38,6 +38,10 @@ public class SingleTableEntityPersister : AbstractEntityPersister, IQueryable // properties of this class, including inherited properties private readonly int[] propertyTableNumbers; + // if the id is a property of the base table eg join to property-ref + // if the id is not a property the value will be -1 + private readonly Dictionary tableIdPropertyNumbers; + // the closure of all columns used by the entire hierarchy including // subclasses and superclasses of this class private readonly int[] subclassPropertyTableNumberClosure; @@ -68,6 +72,9 @@ public class SingleTableEntityPersister : AbstractEntityPersister, IQueryable private static readonly object NullDiscriminator = new object(); private static readonly object NotNullDiscriminator = new object(); + //provided so we can join to keys other than the primary key + private readonly Dictionary joinToKeyColumns; + public SingleTableEntityPersister(PersistentClass persistentClass, ICacheConcurrencyStrategy cache, ISessionFactoryImplementor factory, IMapping mapping) : base(persistentClass, cache, factory) @@ -158,6 +165,11 @@ public SingleTableEntityPersister(PersistentClass persistentClass, ICacheConcurr bool hasDeferred = false; List subclassTables = new List(); List joinKeyColumns = new List(); + //provided so we can join to keys other than the primary key + joinToKeyColumns = new Dictionary(); + //Columns that also function as Id's + List idColumns = new List(); + tableIdPropertyNumbers = new Dictionary(); List isConcretes = new List(); List isDeferreds = new List(); List isInverses = new List(); @@ -183,6 +195,33 @@ public SingleTableEntityPersister(PersistentClass persistentClass, ICacheConcurr var keyCols = join.Key.ColumnIterator.OfType().Select(col => col.GetQuotedName(factory.Dialect)).ToArray(); joinKeyColumns.Add(keyCols); + + //are we joining to other than the primary key? + if (join.RefIdProperty != null) + { + var curTableIndex = joinKeyColumns.Count - 1; + //there should only ever be one key + var toKeyCols = new List(join.RefIdProperty.ColumnSpan); + foreach (Column col in join.RefIdProperty.ColumnIterator) + { + toKeyCols.Add(col.GetQuotedName(factory.Dialect)); + + //find out what property index this is + int i = 0; + foreach (var prop in persistentClass.PropertyClosureIterator) + { + if (prop == @join.RefIdProperty) + { + tableIdPropertyNumbers.Add(curTableIndex, i); + break; + } + i++; + } + + idColumns.Add(col); + } + joinToKeyColumns.Add(curTableIndex, toKeyCols.ToArray()); + } } subclassTableSequentialSelect = isDeferreds.ToArray(); @@ -487,6 +526,55 @@ protected override bool IsTableCascadeDeleteEnabled(int j) return cascadeDeleteEnabled[j]; } + protected override object GetJoinTableId(int table, object obj, EntityMode entityMode) + { + //0 is the base table there is no join + if (table == 0) + return null; + + //check index first for speed + var refIdColumn = GetRefIdColumnOfTable(table); + if (refIdColumn == null) + return null; + + object[] fields = GetPropertyValues(obj, entityMode); + return GetJoinTableId(table, refIdColumn, fields); + } + + //gets the identifier for a join table if other than pk + protected override object GetJoinTableId(int table, object[] fields) + { + //0 is the base table there is no join + if (table == 0) + return null; + + return GetJoinTableId(table, GetRefIdColumnOfTable(table), fields); + } + + private static object GetJoinTableId(int table, int? index, object[] fields) + { + if (index == null) + return null; + + return fields[index.Value]; + } + + //if the table's id is a reference column, returns the index of that property + //returns null if not found + protected override int? GetRefIdColumnOfTable(int table) + { + int value; + if (tableIdPropertyNumbers.TryGetValue(table, out value)) + return value; + + return null; + } + + protected override bool IsIdOfTable(int property, int table) + { + return GetRefIdColumnOfTable(table) == property; + } + protected override bool IsPropertyOfTable(int property, int table) { return propertyTableNumbers[property] == table; @@ -499,7 +587,7 @@ protected override bool IsSubclassTableSequentialSelect(int table) public override string FromTableFragment(string name) { - return TableName + ' ' + name; + return TableName + " " + name; } public override string FilterFragment(string alias) @@ -671,6 +759,16 @@ private SqlString GenerateSequentialSelect(ILoadable persister) return RenderSelect(tableNumbers.ToArray(), columnNumbers.ToArray(), formulaNumbers.ToArray()); } + //provide columns to join to if the key is other than the primary key + protected override string[] GetJoinIdKeyColumns(int j) + { + string[] key; + if (joinToKeyColumns.TryGetValue(j, out key)) + return key; + + return base.GetJoinIdKeyColumns(j); + } + protected override string[] GetSubclassTableKeyColumns(int j) { return subclassTableKeyColumnClosure[j];