Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit bac4b11

Browse files
committed
Make all DbScripts context aware
1 parent ac773ed commit bac4b11

File tree

3 files changed

+106
-42
lines changed

3 files changed

+106
-42
lines changed

src/ServiceStack.OrmLite/DbScripts.cs

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,21 @@ public IDbConnection OpenDbConnection(ScriptScopeContext scope, Dictionary<strin
4040
return DbFactory.OpenDbConnection();
4141
}
4242

43+
T dialect<T>(ScriptScopeContext scope, Func<IOrmLiteDialectProvider, T> fn)
44+
{
45+
if (scope.PageResult != null)
46+
{
47+
if (scope.PageResult.Args.TryGetValue(DbInfo, out var oDbInfo) && oDbInfo is ConnectionInfo dbInfo)
48+
return fn(DbFactory.GetDialectProvider(dbInfo));
49+
50+
if (scope.PageResult.Args.TryGetValue(DbConnection, out var oDbConn) && oDbConn is Dictionary<string, object> useDb)
51+
return fn(DbFactory.GetDialectProvider(
52+
providerName:useDb.GetValueOrDefault("providerName")?.ToString(),
53+
namedConnection:useDb.GetValueOrDefault("namedConnection")?.ToString()));
54+
}
55+
return fn(OrmLiteConfig.DialectProvider);
56+
}
57+
4358
public IgnoreResult useDb(ScriptScopeContext scope, Dictionary<string, object> dbConnOptions)
4459
{
4560
if (dbConnOptions == null)
@@ -80,10 +95,8 @@ T exec<T>(Func<IDbConnection, T> fn, ScriptScopeContext scope, object options)
8095
{
8196
try
8297
{
83-
using (var db = OpenDbConnection(scope, options as Dictionary<string, object>))
84-
{
85-
return fn(db);
86-
}
98+
using var db = OpenDbConnection(scope, options as Dictionary<string, object>);
99+
return fn(db);
87100
}
88101
catch (Exception ex)
89102
{
@@ -173,24 +186,31 @@ public string[] dbColumnNames(ScriptScopeContext scope, string tableName, object
173186

174187
public ColumnSchema[] dbColumns(ScriptScopeContext scope, string tableName) => dbColumns(scope, tableName, null);
175188
public ColumnSchema[] dbColumns(ScriptScopeContext scope, string tableName, object options) =>
176-
exec(db => db.GetTableColumns($"SELECT * FROM {sqlQuote(tableName)}"), scope, options);
189+
exec(db => db.GetTableColumns($"SELECT * FROM {sqlQuote(scope,tableName)}"), scope, options);
177190

178191
public ColumnSchema[] dbDesc(ScriptScopeContext scope, string sql) => dbDesc(scope, sql, null);
179192
public ColumnSchema[] dbDesc(ScriptScopeContext scope, string sql, object options) => exec(db => db.GetTableColumns(sql), scope, options);
180193

181-
public string sqlQuote(string name) => OrmLiteConfig.DialectProvider.GetQuotedName(name);
182-
public string sqlConcat(IEnumerable<object> values) => OrmLiteConfig.DialectProvider.SqlConcat(values);
183-
public string sqlCurrency(string fieldOrValue) => OrmLiteConfig.DialectProvider.SqlCurrency(fieldOrValue);
184-
public string sqlCurrency(string fieldOrValue, string symbol) => OrmLiteConfig.DialectProvider.SqlCurrency(fieldOrValue, symbol);
185-
186-
public string sqlBool(bool value) => OrmLiteConfig.DialectProvider.SqlBool(value);
187-
public string sqlTrue() => OrmLiteConfig.DialectProvider.SqlBool(true);
188-
public string sqlFalse() => OrmLiteConfig.DialectProvider.SqlBool(false);
189-
public string sqlLimit(int? offset, int? limit) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(offset, limit));
190-
public string sqlLimit(int? limit) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(null, limit));
191-
public string sqlSkip(int? offset) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(offset, null));
192-
public string sqlTake(int? limit) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(null, limit));
193-
public string ormliteVar(string name) => OrmLiteConfig.DialectProvider.Variables.TryGetValue(name, out var value) ? value : null;
194+
195+
public string sqlQuote(ScriptScopeContext scope, string name) => dialect(scope, d => d.GetQuotedName(name));
196+
public string sqlConcat(ScriptScopeContext scope, IEnumerable<object> values) => dialect(scope, d => d.SqlConcat(values));
197+
public string sqlCurrency(ScriptScopeContext scope, string fieldOrValue) => dialect(scope, d => d.SqlCurrency(fieldOrValue));
198+
public string sqlCurrency(ScriptScopeContext scope, string fieldOrValue, string symbol) =>
199+
dialect(scope, d => d.SqlCurrency(fieldOrValue, symbol));
200+
201+
public string sqlBool(ScriptScopeContext scope, bool value) => dialect(scope, d => d.SqlBool(value));
202+
public string sqlTrue(ScriptScopeContext scope) => dialect(scope, d => d.SqlBool(true));
203+
public string sqlFalse(ScriptScopeContext scope) => dialect(scope, d => d.SqlBool(false));
204+
public string sqlLimit(ScriptScopeContext scope, int? offset, int? limit) =>
205+
dialect(scope, d => padCondition(d.SqlLimit(offset, limit)));
206+
public string sqlLimit(ScriptScopeContext scope, int? limit) =>
207+
dialect(scope, d => padCondition(d.SqlLimit(null, limit)));
208+
public string sqlSkip(ScriptScopeContext scope, int? offset) =>
209+
dialect(scope, d => padCondition(d.SqlLimit(offset, null)));
210+
public string sqlTake(ScriptScopeContext scope, int? limit) =>
211+
dialect(scope, d => padCondition(d.SqlLimit(null, limit)));
212+
public string ormliteVar(ScriptScopeContext scope, string name) =>
213+
dialect(scope, d => d.Variables.TryGetValue(name, out var value) ? value : null);
194214

195215
public string sqlVerifyFragment(string sql) => sql.SqlVerifyFragment();
196216
public bool isUnsafeSql(string sql) => OrmLiteUtils.isUnsafeSql(sql, OrmLiteUtils.VerifySqlRegEx);

src/ServiceStack.OrmLite/DbScriptsAsync.cs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System.Threading.Tasks;
66
using ServiceStack.Data;
77
using ServiceStack.Script;
8-
using ServiceStack.Text;
98

109
namespace ServiceStack.OrmLite
1110
{
@@ -15,7 +14,7 @@ public class TemplateDbFiltersAsync : DbScriptsAsync {}
1514
public partial class DbScriptsAsync : ScriptMethods
1615
{
1716
private const string DbInfo = "__dbinfo"; // Keywords.DbInfo
18-
private const string DbConnection = "__dbconnection"; // useDbConnection global
17+
private const string DbConnection = "__dbconnection"; // useDb global
1918

2019
private IDbConnectionFactory dbFactory;
2120
public IDbConnectionFactory DbFactory
@@ -42,6 +41,21 @@ public async Task<IDbConnection> OpenDbConnectionAsync(ScriptScopeContext scope,
4241
return await DbFactory.OpenAsync();
4342
}
4443

44+
T dialect<T>(ScriptScopeContext scope, Func<IOrmLiteDialectProvider, T> fn)
45+
{
46+
if (scope.PageResult != null)
47+
{
48+
if (scope.PageResult.Args.TryGetValue(DbInfo, out var oDbInfo) && oDbInfo is ConnectionInfo dbInfo)
49+
return fn(DbFactory.GetDialectProvider(dbInfo));
50+
51+
if (scope.PageResult.Args.TryGetValue(DbConnection, out var oDbConn) && oDbConn is Dictionary<string, object> useDb)
52+
return fn(DbFactory.GetDialectProvider(
53+
providerName:useDb.GetValueOrDefault("providerName")?.ToString(),
54+
namedConnection:useDb.GetValueOrDefault("namedConnection")?.ToString()));
55+
}
56+
return fn(OrmLiteConfig.DialectProvider);
57+
}
58+
4559
public IgnoreResult useDb(ScriptScopeContext scope, Dictionary<string, object> dbConnOptions)
4660
{
4761
if (dbConnOptions == null)
@@ -82,11 +96,9 @@ async Task<object> exec<T>(Func<IDbConnection, Task<T>> fn, ScriptScopeContext s
8296
{
8397
try
8498
{
85-
using (var db = await OpenDbConnectionAsync(scope, options as Dictionary<string, object>))
86-
{
87-
var result = await fn(db);
88-
return result;
89-
}
99+
using var db = await OpenDbConnectionAsync(scope, options as Dictionary<string, object>);
100+
var result = await fn(db);
101+
return result;
90102
}
91103
catch (Exception ex)
92104
{
@@ -172,28 +184,34 @@ public Task<object> dbTableNamesWithRowCounts(ScriptScopeContext scope, Dictiona
172184

173185
public Task<object> dbColumnNames(ScriptScopeContext scope, string tableName) => dbColumnNames(scope, tableName, null);
174186
public Task<object> dbColumnNames(ScriptScopeContext scope, string tableName, object options) =>
175-
exec(async db => (await db.GetTableColumnsAsync($"SELECT * FROM {sqlQuote(tableName)}")).Select(x => x.ColumnName).ToArray(), scope, options);
187+
exec(async db => (await db.GetTableColumnsAsync($"SELECT * FROM {sqlQuote(scope,tableName)}")).Select(x => x.ColumnName).ToArray(), scope, options);
176188

177189
public Task<object> dbColumns(ScriptScopeContext scope, string tableName) => dbColumns(scope, tableName, null);
178190
public Task<object> dbColumns(ScriptScopeContext scope, string tableName, object options) =>
179-
exec(db => db.GetTableColumnsAsync($"SELECT * FROM {sqlQuote(tableName)}"), scope, options);
191+
exec(db => db.GetTableColumnsAsync($"SELECT * FROM {sqlQuote(scope,tableName)}"), scope, options);
180192

181193
public Task<object> dbDesc(ScriptScopeContext scope, string sql) => dbDesc(scope, sql, null);
182194
public Task<object> dbDesc(ScriptScopeContext scope, string sql, object options) => exec(db => db.GetTableColumnsAsync(sql), scope, options);
183195

184-
public string sqlQuote(string name) => OrmLiteConfig.DialectProvider.GetQuotedName(name);
185-
public string sqlConcat(IEnumerable<object> values) => OrmLiteConfig.DialectProvider.SqlConcat(values);
186-
public string sqlCurrency(string fieldOrValue) => OrmLiteConfig.DialectProvider.SqlCurrency(fieldOrValue);
187-
public string sqlCurrency(string fieldOrValue, string symbol) => OrmLiteConfig.DialectProvider.SqlCurrency(fieldOrValue, symbol);
188-
189-
public string sqlBool(bool value) => OrmLiteConfig.DialectProvider.SqlBool(value);
190-
public string sqlTrue() => OrmLiteConfig.DialectProvider.SqlBool(true);
191-
public string sqlFalse() => OrmLiteConfig.DialectProvider.SqlBool(false);
192-
public string sqlLimit(int? offset, int? limit) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(offset, limit));
193-
public string sqlLimit(int? limit) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(null, limit));
194-
public string sqlSkip(int? offset) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(offset, null));
195-
public string sqlTake(int? limit) => padCondition(OrmLiteConfig.DialectProvider.SqlLimit(null, limit));
196-
public string ormliteVar(string name) => OrmLiteConfig.DialectProvider.Variables.TryGetValue(name, out var value) ? value : null;
196+
public string sqlQuote(ScriptScopeContext scope, string name) => dialect(scope, d => d.GetQuotedName(name));
197+
public string sqlConcat(ScriptScopeContext scope, IEnumerable<object> values) => dialect(scope, d => d.SqlConcat(values));
198+
public string sqlCurrency(ScriptScopeContext scope, string fieldOrValue) => dialect(scope, d => d.SqlCurrency(fieldOrValue));
199+
public string sqlCurrency(ScriptScopeContext scope, string fieldOrValue, string symbol) =>
200+
dialect(scope, d => d.SqlCurrency(fieldOrValue, symbol));
201+
202+
public string sqlBool(ScriptScopeContext scope, bool value) => dialect(scope, d => d.SqlBool(value));
203+
public string sqlTrue(ScriptScopeContext scope) => dialect(scope, d => d.SqlBool(true));
204+
public string sqlFalse(ScriptScopeContext scope) => dialect(scope, d => d.SqlBool(false));
205+
public string sqlLimit(ScriptScopeContext scope, int? offset, int? limit) =>
206+
dialect(scope, d => padCondition(d.SqlLimit(offset, limit)));
207+
public string sqlLimit(ScriptScopeContext scope, int? limit) =>
208+
dialect(scope, d => padCondition(d.SqlLimit(null, limit)));
209+
public string sqlSkip(ScriptScopeContext scope, int? offset) =>
210+
dialect(scope, d => padCondition(d.SqlLimit(offset, null)));
211+
public string sqlTake(ScriptScopeContext scope, int? limit) =>
212+
dialect(scope, d => padCondition(d.SqlLimit(null, limit)));
213+
public string ormliteVar(ScriptScopeContext scope, string name) =>
214+
dialect(scope, d => d.Variables.TryGetValue(name, out var value) ? value : null);
197215

198216
public string sqlVerifyFragment(string sql) => sql.SqlVerifyFragment();
199217
public bool isUnsafeSql(string sql) => OrmLiteUtils.isUnsafeSql(sql, OrmLiteUtils.VerifySqlRegEx);

src/ServiceStack.OrmLite/OrmLiteConnectionFactory.cs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,15 @@ public virtual IDbConnection OpenDbConnection(string namedConnection)
196196
}
197197

198198
private static Dictionary<string, IOrmLiteDialectProvider> dialectProviders;
199-
public static Dictionary<string, IOrmLiteDialectProvider> DialectProviders => dialectProviders ?? (dialectProviders = new Dictionary<string, IOrmLiteDialectProvider>());
199+
public static Dictionary<string, IOrmLiteDialectProvider> DialectProviders => dialectProviders ??= new Dictionary<string, IOrmLiteDialectProvider>();
200200

201201
public virtual void RegisterDialectProvider(string providerName, IOrmLiteDialectProvider dialectProvider)
202202
{
203203
DialectProviders[providerName] = dialectProvider;
204204
}
205205

206206
private static Dictionary<string, OrmLiteConnectionFactory> namedConnections;
207-
public static Dictionary<string, OrmLiteConnectionFactory> NamedConnections => namedConnections ?? (namedConnections = new Dictionary<string, OrmLiteConnectionFactory>());
207+
public static Dictionary<string, OrmLiteConnectionFactory> NamedConnections => namedConnections ??= new Dictionary<string, OrmLiteConnectionFactory>();
208208

209209
public virtual void RegisterConnection(string namedConnection, string connectionString, IOrmLiteDialectProvider dialectProvider)
210210
{
@@ -291,6 +291,32 @@ public static Task<IDbConnection> OpenDbConnectionStringAsync(this IDbConnection
291291
return ((OrmLiteConnectionFactory)connectionFactory).OpenDbConnectionStringAsync(connectionString, providerName, token);
292292
}
293293

294+
295+
public static IOrmLiteDialectProvider GetDialectProvider(this IDbConnectionFactory connectionFactory, ConnectionInfo dbInfo)
296+
{
297+
return dbInfo != null
298+
? GetDialectProvider(connectionFactory, providerName:dbInfo.ProviderName, namedConnection:dbInfo.NamedConnection)
299+
: ((OrmLiteConnectionFactory) connectionFactory).DialectProvider;
300+
}
301+
302+
public static IOrmLiteDialectProvider GetDialectProvider(this IDbConnectionFactory connectionFactory,
303+
string providerName = null, string namedConnection = null)
304+
{
305+
var dbFactory = (OrmLiteConnectionFactory) connectionFactory;
306+
307+
if (!string.IsNullOrEmpty(providerName))
308+
return OrmLiteConnectionFactory.DialectProviders.TryGetValue(providerName, out var provider)
309+
? provider
310+
: throw new NotSupportedException($"Dialect provider is not registered '{provider}'");
311+
312+
if (!string.IsNullOrEmpty(namedConnection))
313+
return OrmLiteConnectionFactory.NamedConnections.TryGetValue(namedConnection, out var namedFactory)
314+
? namedFactory.DialectProvider
315+
: throw new NotSupportedException($"Named connection is not registered '{namedConnection}'");
316+
317+
return dbFactory.DialectProvider;
318+
}
319+
294320
public static IDbConnection ToDbConnection(this IDbConnection db)
295321
{
296322
return db is IHasDbConnection hasDb

0 commit comments

Comments
 (0)