Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit 0883d19

Browse files
committed
Support basic join operations with SqlExpression
1 parent b7a062a commit 0883d19

File tree

7 files changed

+196
-8
lines changed

7 files changed

+196
-8
lines changed

src/ServiceStack.OrmLite/Expressions/ReadConnectionExtensions.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ public static SqlExpression<T> From<T>(this IDbConnection dbConn)
4343
return OrmLiteConfig.ExecFilter.SqlExpression<T>(dbConn);
4444
}
4545

46+
public static SqlExpression<T> From<T, JoinWith>(this IDbConnection dbConn, Expression<Func<T, JoinWith, bool>> joinExpr=null)
47+
{
48+
var sql = OrmLiteConfig.ExecFilter.SqlExpression<T>(dbConn);
49+
sql.Join(joinExpr);
50+
return sql;
51+
}
52+
4653
/// <summary>
4754
/// Creates a new SqlExpression builder for the specified type using a user-defined FROM sql expression.
4855
/// </summary>

src/ServiceStack.OrmLite/Expressions/ReadExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ public static SqlExpression<T> SqlExpression<T>()
1616
internal static List<T> Select<T>(this IDbCommand dbCmd, Func<SqlExpression<T>, SqlExpression<T>> expression)
1717
{
1818
var expr = OrmLiteConfig.DialectProvider.SqlExpression<T>();
19-
string sql = expression(expr).ToSelectStatement();
19+
string sql = expression(expr).SelectInto<T>();
2020

2121
return dbCmd.ExprConvertToList<T>(sql);
2222
}
2323

2424
internal static List<T> Select<T>(this IDbCommand dbCmd, SqlExpression<T> expression)
2525
{
26-
string sql = expression.ToSelectStatement();
26+
string sql = expression.SelectInto<T>();
2727

2828
return dbCmd.ExprConvertToList<T>(sql);
2929
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq.Expressions;
4+
using System.Text;
5+
6+
namespace ServiceStack.OrmLite
7+
{
8+
public abstract partial class SqlExpression<T> : ISqlExpression, ISelectableSqlExpression
9+
{
10+
List<ModelDefinition> tableDefs = new List<ModelDefinition>();
11+
12+
public SqlExpression<T> Join<Source, Target>(Expression<Func<Source, Target, bool>> joinExpr = null)
13+
{
14+
PrefixFieldWithTableName = true;
15+
16+
var sourceDef = typeof(Source).GetModelDefinition();
17+
var targetDef = typeof(Target).GetModelDefinition();
18+
19+
if (tableDefs.Count == 0)
20+
tableDefs.Add(modelDef);
21+
if (!tableDefs.Contains(sourceDef))
22+
tableDefs.Add(sourceDef);
23+
if (!tableDefs.Contains(targetDef))
24+
tableDefs.Add(targetDef);
25+
26+
var fromExpr = FromExpression;
27+
var sbJoin = new StringBuilder();
28+
29+
string sqlExpr;
30+
31+
if (joinExpr != null)
32+
{
33+
sqlExpr = Visit(joinExpr).ToString();
34+
}
35+
else
36+
{
37+
var refField = OrmLiteReadExtensions.GetRefFieldDef(sourceDef, targetDef, typeof(Source));
38+
if (refField == null)
39+
throw new ArgumentException("Could not infer relationship between {0} and {1}"
40+
.Fmt(sourceDef.ModelName, targetDef.ModelName));
41+
42+
sqlExpr = "\n({0}.{1} = {2}.{3})".Fmt(
43+
sourceDef.ModelName.SqlTable(),
44+
sourceDef.PrimaryKey.FieldName.SqlColumn(),
45+
targetDef.ModelName.SqlTable(),
46+
refField.FieldName.SqlColumn());
47+
}
48+
49+
sbJoin.Append(" INNER JOIN {0} ".Fmt(targetDef.ModelName.SqlTable()));
50+
sbJoin.Append(" ON ");
51+
sbJoin.Append(sqlExpr);
52+
53+
FromExpression = fromExpr + sbJoin;
54+
55+
return this;
56+
}
57+
58+
public string SelectInto<TModel>()
59+
{
60+
if (typeof(TModel) == typeof(T) && !PrefixFieldWithTableName)
61+
{
62+
return ToSelectStatement();
63+
}
64+
65+
if (this.tableDefs.Count == 0)
66+
this.tableDefs.Add(modelDef);
67+
68+
var sbSelect = new StringBuilder();
69+
foreach (var fieldDef in modelDef.FieldDefinitions)
70+
{
71+
var found = false;
72+
73+
foreach (var tableDef in tableDefs)
74+
{
75+
foreach (var tableFieldDef in tableDef.FieldDefinitions)
76+
{
77+
if (tableFieldDef.Name == fieldDef.Name)
78+
{
79+
found = true;
80+
if (sbSelect.Length > 0)
81+
sbSelect.Append(", ");
82+
83+
sbSelect.AppendFormat("{0}.{1}",
84+
tableDef.ModelName.SqlTable(),
85+
tableFieldDef.IsRowVersion
86+
? OrmLiteConfig.DialectProvider.GetRowVersionColumnName(tableFieldDef)
87+
: tableFieldDef.FieldName.SqlColumn());
88+
break;
89+
}
90+
}
91+
92+
if (found)
93+
break;
94+
}
95+
96+
if (!found)
97+
{
98+
if (sbSelect.Length > 0)
99+
sbSelect.Append(", ");
100+
101+
sbSelect.AppendFormat("{0}.{1}",
102+
modelDef.ModelName.SqlTable(),
103+
fieldDef.IsRowVersion
104+
? OrmLiteConfig.DialectProvider.GetRowVersionColumnName(fieldDef)
105+
: fieldDef.FieldName.SqlColumn());
106+
}
107+
}
108+
109+
SelectExpression = "SELECT " + sbSelect;
110+
111+
return ToSelectStatement();
112+
}
113+
}
114+
}

src/ServiceStack.OrmLite/Expressions/SqlExpression.cs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace ServiceStack.OrmLite
1111
{
12-
public abstract class SqlExpression<T> : ISqlExpression
12+
public abstract partial class SqlExpression<T> : ISqlExpression
1313
{
1414
private Expression<Func<T, bool>> underlyingExpression;
1515
private List<string> orderByProperties = new List<string>();
@@ -846,9 +846,16 @@ protected virtual object VisitMemberAccess(MemberExpression m)
846846
var propertyInfo = m.Member as PropertyInfo;
847847

848848
if (propertyInfo.PropertyType.IsEnum)
849-
return new EnumMemberAccess((PrefixFieldWithTableName ? OrmLiteConfig.DialectProvider.GetQuotedTableName(modelDef.ModelName) + "." : "") + GetQuotedColumnName(m.Member.Name), propertyInfo.PropertyType);
850-
851-
return new PartialSqlString((PrefixFieldWithTableName ? OrmLiteConfig.DialectProvider.GetQuotedTableName(modelDef.ModelName) + "." : "") + GetQuotedColumnName(m.Member.Name));
849+
return new EnumMemberAccess(
850+
(PrefixFieldWithTableName
851+
? OrmLiteConfig.DialectProvider.GetQuotedTableName(propertyInfo.DeclaringType.GetModelDefinition().ModelName) + "."
852+
: "")
853+
+ GetQuotedColumnName(m.Member.Name), propertyInfo.PropertyType);
854+
855+
return new PartialSqlString((PrefixFieldWithTableName
856+
? OrmLiteConfig.DialectProvider.GetQuotedTableName(propertyInfo.DeclaringType.GetModelDefinition().ModelName) + "."
857+
: "")
858+
+ GetQuotedColumnName(m.Member.Name));
852859
}
853860

854861
var member = Expression.Convert(m, typeof(object));
@@ -1348,6 +1355,11 @@ public interface ISqlExpression
13481355
string ToSelectStatement();
13491356
}
13501357

1358+
public interface ISelectableSqlExpression
1359+
{
1360+
string SelectInto<TModel>();
1361+
}
1362+
13511363
public class PartialSqlString
13521364
{
13531365
public PartialSqlString(string text)

src/ServiceStack.OrmLite/OrmLiteReadExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ public static void LoadReferences<T>(this IDbCommand dbCmd, T instance)
842842
}
843843
}
844844

845-
private static FieldDefinition GetRefFieldDef(ModelDefinition modelDef, ModelDefinition refModelDef, Type refType)
845+
public static FieldDefinition GetRefFieldDef(ModelDefinition modelDef, ModelDefinition refModelDef, Type refType)
846846
{
847847
var refNameConvention = modelDef.ModelName + "Id";
848848
var refField = refModelDef.FieldDefinitions.FirstOrDefault(x => x.FieldName == refNameConvention)

src/ServiceStack.OrmLite/ServiceStack.OrmLite.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
</ItemGroup>
106106
<ItemGroup>
107107
<Compile Include="AliasNamingStrategy.cs" />
108+
<Compile Include="Expressions\SqlExpression.Join.cs" />
108109
<Compile Include="Expressions\SqlExpressionVisitor.cs" />
109110
<Compile Include="Expressions\ParameterRebinder.cs" />
110111
<Compile Include="Expressions\PredicateBuilder.cs" />

tests/ServiceStack.OrmLite.Tests/Expression/SqlExpressionTests.cs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ public class LetterWeighting
2121
public int Weighting { get; set; }
2222
}
2323

24+
public class LetterStat
25+
{
26+
[AutoIncrement]
27+
public int Id { get; set; }
28+
public long LetterFrequencyId { get; set; }
29+
public string Letter { get; set; }
30+
public int Weighting { get; set; }
31+
}
32+
2433
public class SqlExpressionTests : ExpressionsTestBase
2534
{
2635
private static void InitLetters(IDbConnection db)
@@ -130,7 +139,7 @@ public void Can_select_limit_with_JoinSqlBuilder()
130139
db.Insert(new LetterWeighting { LetterFrequencyId = id, Weighting = ++i * 10 });
131140
});
132141

133-
var joinFn = new Func<JoinSqlBuilder<LetterFrequency, LetterWeighting>>(() =>
142+
var joinFn = new Func<JoinSqlBuilder<LetterFrequency, LetterWeighting>>(() =>
134143
new JoinSqlBuilder<LetterFrequency, LetterWeighting>()
135144
.Join<LetterFrequency, LetterWeighting>(x => x.Id, x => x.LetterFrequencyId)
136145
);
@@ -157,5 +166,50 @@ public void Can_select_limit_with_JoinSqlBuilder()
157166
Assert.That(results.ConvertAll(x => x.Letter), Is.EquivalentTo(new[] { "D", "C" }));
158167
}
159168
}
169+
170+
[Test]
171+
public void Can_join_with_SqlExpression()
172+
{
173+
using (var db = OpenDbConnection())
174+
{
175+
db.DropAndCreateTable<LetterFrequency>();
176+
db.DropAndCreateTable<LetterStat>();
177+
178+
var letters = "A,B,C,D,E".Split(',');
179+
var i = 0;
180+
letters.Each(letter =>
181+
{
182+
var id = db.Insert(new LetterFrequency { Letter = letter }, selectIdentity: true);
183+
db.Insert(new LetterStat
184+
{
185+
LetterFrequencyId = id,
186+
Letter = letter,
187+
Weighting = ++i * 10
188+
});
189+
});
190+
191+
db.Insert(new LetterFrequency { Letter = "F" });
192+
193+
Assert.That(db.Count<LetterFrequency>(), Is.EqualTo(6));
194+
195+
var results = db.Select(db.From<LetterFrequency, LetterStat>());
196+
db.GetLastSql().Print();
197+
Assert.That(results.Count, Is.EqualTo(5));
198+
199+
results = db.Select(db.From<LetterFrequency, LetterStat>((x, y) => x.Id == y.LetterFrequencyId));
200+
db.GetLastSql().Print();
201+
Assert.That(results.Count, Is.EqualTo(5));
202+
203+
results = db.Select(db.From<LetterFrequency>()
204+
.Join<LetterFrequency, LetterStat>((x, y) => x.Id == y.LetterFrequencyId));
205+
db.GetLastSql().Print();
206+
Assert.That(results.Count, Is.EqualTo(5));
207+
208+
results = db.Select<LetterFrequency>(q =>
209+
q.Join<LetterFrequency, LetterStat>((x, y) => x.Id == y.LetterFrequencyId));
210+
db.GetLastSql().Print();
211+
Assert.That(results.Count, Is.EqualTo(5));
212+
}
213+
}
160214
}
161215
}

0 commit comments

Comments
 (0)