33 using Configuration ;
44 using Configuration . Conventions ;
55 using Extensions ;
6+ using InMemory . Properties ;
67 using Query . Strategies ;
78 using System ;
89 using System . Collections . Generic ;
910 using System . Data ;
11+ using System . Globalization ;
1012 using System . Linq ;
1113 using Transactions ;
1214 using Transactions . Internal ;
@@ -21,7 +23,7 @@ internal class InMemoryRepositoryContext : LinqEnumerableRepositoryContextBase,
2123 private readonly bool _ignoreTransactionWarning ;
2224 private readonly bool _ignoreSqlQueryWarning ;
2325
24- private readonly InMemoryUnderlyingDbContext _underlyingContext ;
26+ private readonly InMemoryDatabase _db ;
2527
2628 #endregion
2729
@@ -39,7 +41,7 @@ public InMemoryRepositoryContext(string databaseName, bool ignoreTransactionWarn
3941 _ignoreTransactionWarning = ignoreTransactionWarning ;
4042 _ignoreSqlQueryWarning = ignoreSqlQueryWarning ;
4143
42- _underlyingContext = new InMemoryUnderlyingDbContext ( databaseName , Conventions ) ;
44+ _db = InMemoryDatabasesCache . Instance . GetDatabase ( databaseName ) ;
4345 }
4446
4547 #endregion
@@ -48,9 +50,9 @@ public InMemoryRepositoryContext(string databaseName, bool ignoreTransactionWarn
4850
4951 protected override IEnumerable < TEntity > AsEnumerable < TEntity > ( IFetchQueryStrategy < TEntity > fetchStrategy )
5052 {
51- return _underlyingContext
53+ return _db
5254 . FindAll < TEntity > ( )
53- . ApplyFetchingOptions ( fetchStrategy , _underlyingContext . FindAll ) ;
55+ . ApplyFetchingOptions ( fetchStrategy , _db . FindAll ) ;
5456 }
5557
5658 #endregion
@@ -61,7 +63,7 @@ protected override IEnumerable<TEntity> AsEnumerable<TEntity>(IFetchQueryStrateg
6163
6264 public void EnsureDeleted ( )
6365 {
64- _underlyingContext . ClearDatabase ( ) ;
66+ _db . Clear ( ) ;
6567 }
6668
6769 #endregion
@@ -80,22 +82,61 @@ public override ITransactionManager BeginTransaction()
8082
8183 public override void Add < TEntity > ( TEntity entity )
8284 {
83- _underlyingContext . Add ( Guard . NotNull ( entity , nameof ( entity ) ) ) ;
85+ Guard . NotNull ( entity , nameof ( entity ) ) ;
86+
87+ var entityType = typeof ( TEntity ) ;
88+ var keyValues = Conventions . GetPrimaryKeyValues ( entity ) ;
89+
90+ if ( _db . TryFind < TEntity > ( keyValues , out object _ ) )
91+ {
92+ throw new InvalidOperationException (
93+ string . Format (
94+ CultureInfo . CurrentCulture ,
95+ Resources . EntityAlreadyBeingTrackedInStore ,
96+ entityType ) ) ;
97+ }
98+
99+ if ( TryGeneratePrimaryKey < TEntity > ( entity , out var newKey ) )
100+ {
101+ // assumes we only have a single key since
102+ // we cannot generated for a composite key anyways
103+ keyValues [ 0 ] = newKey ;
104+ }
105+
106+ _db . AddOrUpdate < TEntity > ( entity , keyValues ) ;
84107 }
85108
86109 public override void Update < TEntity > ( TEntity entity )
87110 {
88- _underlyingContext . Update ( Guard . NotNull ( entity , nameof ( entity ) ) ) ;
111+ Guard . NotNull ( entity , nameof ( entity ) ) ;
112+
113+ var keyValues = Conventions . GetPrimaryKeyValues ( entity ) ;
114+
115+ if ( ! _db . TryFind < TEntity > ( keyValues , out object _ ) )
116+ {
117+ throw new InvalidOperationException ( Resources . EntityNotFoundInStore ) ;
118+ }
119+
120+ _db . AddOrUpdate < TEntity > ( entity , keyValues ) ;
89121 }
90122
91123 public override void Remove < TEntity > ( TEntity entity )
92124 {
93- _underlyingContext . Remove ( Guard . NotNull ( entity , nameof ( entity ) ) ) ;
125+ Guard . NotNull ( entity , nameof ( entity ) ) ;
126+
127+ var keyValues = Conventions . GetPrimaryKeyValues ( entity ) ;
128+
129+ if ( ! _db . TryFind < TEntity > ( keyValues , out object _ ) )
130+ {
131+ throw new InvalidOperationException ( Resources . EntityNotFoundInStore ) ;
132+ }
133+
134+ _db . Remove < TEntity > ( keyValues ) ;
94135 }
95136
96137 public override int SaveChanges ( )
97138 {
98- return _underlyingContext . SaveChanges ( ) ;
139+ return - 1 ;
99140 }
100141
101142 public override IEnumerable < TEntity > ExecuteSqlQuery < TEntity > ( string sql , CommandType cmdType , Dictionary < string , object > parameters , Func < IDataReader , TEntity > projector )
@@ -125,9 +166,12 @@ public override TEntity Find<TEntity>(IFetchQueryStrategy<TEntity> fetchStrategy
125166
126167 if ( fetchStrategy == null )
127168 {
128- var result = _underlyingContext . Find < TEntity > ( keyValues ) ;
169+ if ( _db . TryFind < TEntity > ( keyValues , out object entity ) )
170+ {
171+ return ( TEntity ) Convert . ChangeType ( entity , typeof ( TEntity ) ) ;
172+ }
129173
130- return result ;
174+ return default ( TEntity ) ;
131175 }
132176
133177 return base . Find ( fetchStrategy , keyValues ) ;
@@ -139,7 +183,6 @@ public override TEntity Find<TEntity>(IFetchQueryStrategy<TEntity> fetchStrategy
139183
140184 public override void Dispose ( )
141185 {
142- _underlyingContext . Dispose ( ) ;
143186 base . Dispose ( ) ;
144187 }
145188
0 commit comments