Skip to content

Commit 6fe1ed1

Browse files
author
Viktor Tochonov
committed
Implemented tests for different discriminator cases
1 parent 0fa9d2a commit 6fe1ed1

File tree

3 files changed

+108
-133
lines changed

3 files changed

+108
-133
lines changed

src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs

Lines changed: 34 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -82,36 +82,6 @@ module ObjectListFilterExtensions =
8282
// | t when Type.(=)(t, typeof<Building>) -> ResidentialPropertiesConstants.Discriminators.Building
8383
// )
8484

85-
type DiscriminatorExpression<'T, 'D> =
86-
| GetDiscriminatorValue of ('T -> 'D)
87-
| CompareDiscriminator of Expression<Func<'T, 'D, bool>>
88-
89-
[<Struct>]
90-
type ObjectListFilterLinqOptions<'T, 'D> (
91-
discriminatorExpression : DiscriminatorExpression<'T, 'D> | null,
92-
[<Optional>] getDiscriminatorValue: (Type -> 'D) | null,
93-
[<Optional>] serializeMemberName: (MemberInfo -> string) | null) =
94-
95-
member _.DiscriminatorExpression = discriminatorExpression |> ValueOption.ofObj
96-
member _.GetDiscriminatorValue = getDiscriminatorValue |> ValueOption.ofObj
97-
member _.SerializeMemberName = serializeMemberName |> ValueOption.ofObj
98-
99-
static member None = ObjectListFilterLinqOptions<'T, 'D> (null, null, null)
100-
101-
new (getDiscriminatorValue : 'T -> 'D) = ObjectListFilterLinqOptions<'T, 'D> (GetDiscriminatorValue getDiscriminatorValue, null, null)
102-
new (compareDiscriminator : Expression<Func<'T, 'D, bool>>) = ObjectListFilterLinqOptions<'T, 'D> (CompareDiscriminator compareDiscriminator, null, null)
103-
new (getDiscriminatorValue : Type -> 'D) = ObjectListFilterLinqOptions<'T, 'D> (null, getDiscriminatorValue, null)
104-
new (serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (null, null, serializeMemberName)
105-
106-
new (getDiscriminatorValue : 'T -> 'D, serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (GetDiscriminatorValue getDiscriminatorValue, null, serializeMemberName)
107-
new (compareDiscriminator : Expression<Func<'T, 'D, bool>>, serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (CompareDiscriminator compareDiscriminator, null, serializeMemberName)
108-
new (getDiscriminatorValue : Type -> 'D, serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (null, getDiscriminatorValue, serializeMemberName)
109-
110-
// Helper to create lambda from body expression
111-
let makeLambda<'T> (param : ParameterExpression) (body : Expression) =
112-
let delegateType = typedefof<Func<_, _>>.MakeGenericType ([| typeof<'T>; body.Type |])
113-
Expression.Lambda (delegateType, body, param)
114-
11585
let private genericWhereMethod =
11686
typeof<Queryable>.GetMethods ()
11787
|> Seq.where (fun m -> m.Name = "Where")
@@ -123,7 +93,7 @@ module ObjectListFilterExtensions =
12393
// Helper to create Where expression
12494
let whereExpr<'T> (query : IQueryable<'T>) (param : ParameterExpression) predicate =
12595
let whereMethod = genericWhereMethod.MakeGenericMethod ([| typeof<'T> |])
126-
Expression.Call (whereMethod, [| query.Expression; makeLambda<'T> param predicate |])
96+
Expression.Call (whereMethod, [| query.Expression; Expression.Lambda<Func<'T, bool>> (predicate, param) |])
12797

12898
let private StringStartsWithMethod = typeof<string>.GetMethod ("StartsWith", [| typeof<string> |])
12999
let private StringEndsWithMethod = typeof<string>.GetMethod ("EndsWith", [| typeof<string> |])
@@ -164,67 +134,61 @@ module ObjectListFilterExtensions =
164134
let paramExpr = Expression.PropertyOrField (param, f.FieldName)
165135
buildFilterExpr (SourceExpression paramExpr) buildTypeDiscriminatorCheck f.Value
166136

167-
type ObjectListFilter with
137+
[<Struct>]
138+
type ObjectListFilterLinqOptions<'T, 'D> (
139+
[<Optional>] compareDiscriminator : Expression<Func<'T, 'D, bool>> | null,
140+
[<Optional>] getDiscriminatorValue : (Type -> 'D) | null,
141+
[<Optional>] serializeMemberName : (MemberInfo -> string) | null) =
168142

169-
member filter.Apply<'T, 'D>
170-
(query : IQueryable<'T>, compareDiscriminator : Expression<Func<'T, 'D, bool>>, getDiscriminatorValue : (Type -> 'D))
171-
=
143+
member _.CompareDiscriminator = compareDiscriminator |> ValueOption.ofObj
144+
member _.GetDiscriminatorValue = getDiscriminatorValue |> ValueOption.ofObj
145+
//member _.SerializeMemberName = serializeMemberName |> ValueOption.ofObj
172146

173-
match filter with
174-
| NoFilter -> query
175-
| _ ->
147+
static member None = ObjectListFilterLinqOptions<'T, 'D> (null, null, null)
176148

177-
// Helper for discriminator comparison
178-
let buildTypeDiscriminatorCheck (param : SourceExpression) (t : Type) =
179-
match compareDiscriminator, getDiscriminatorValue with
180-
| null, discValueFn when obj.Equals (discValueFn, null) ->
181-
// use __typename from filter and do type.ToSting() for values
182-
Unchecked.defaultof<Expression>
183-
| discExpr, discValueFn when obj.Equals (discValueFn, null) ->
184-
// use discriminator and do type.ToSting() for values
185-
Unchecked.defaultof<Expression>
186-
| null, discValueFn ->
187-
// use __typename from filter and execute discValueFn for values
188-
Unchecked.defaultof<Expression>
189-
| discExpr, discValueFn ->
190-
// use discriminator and execute discValueFn for values
149+
static member GetCompareDiscriminator (getDiscriminatorValue : Expression<Func<'T, 'D>>) =
150+
let tParam = Expression.Parameter (typeof<'T>, "x")
151+
let dParam = Expression.Parameter (typeof<'D>, "d")
152+
let body = Expression.Equal(Expression.Invoke(getDiscriminatorValue, tParam), dParam)
153+
Expression.Lambda<Func<'T, 'D, bool>> (body, tParam, dParam)
191154

192-
let discriminatorValue = discValueFn t
193-
Expression.Equal (Expression.PropertyOrField (param, "__discriminator"), Expression.Constant (discriminatorValue))
155+
new (getDiscriminator : Expression<Func<'T, 'D>>) = ObjectListFilterLinqOptions<'T, 'D> (ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, null, null)
156+
new (compareDiscriminator : Expression<Func<'T, 'D, bool>>) = ObjectListFilterLinqOptions<'T, 'D> (compareDiscriminator, null, null)
157+
new (getDiscriminatorValue : Type -> 'D) = ObjectListFilterLinqOptions<'T, 'D> (null, getDiscriminatorValue, null)
158+
//new (serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (null, null, serializeMemberName)
194159

195-
let queryExpr =
196-
let param = Expression.Parameter (typeof<'T>, "x")
197-
let body = buildFilterExpr (SourceExpression param) buildTypeDiscriminatorCheck filter
198-
whereExpr<'T> query param body
199-
// Create and execute the final expression
200-
query.Provider.CreateQuery<'T> (queryExpr)
160+
new (getDiscriminator : Expression<Func<'T, 'D>>, getDiscriminatorValue : Type -> 'D) = ObjectListFilterLinqOptions<'T, 'D> (ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, getDiscriminatorValue, null)
161+
162+
//new (getDiscriminator : Expression<Func<'T, 'D>>, serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, null, serializeMemberName)
163+
//new (compareDiscriminator : Expression<Func<'T, 'D, bool>>, serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (compareDiscriminator, null, serializeMemberName)
164+
//new (getDiscriminatorValue : Type -> 'D, serializeMemberName : MemberInfo -> string) = ObjectListFilterLinqOptions<'T, 'D> (null, getDiscriminatorValue, serializeMemberName)
165+
166+
type ObjectListFilter with
201167

202-
member filter.Apply<'T, 'D>
203-
(query : IQueryable<'T>, [<Optional>] getDiscriminator : Expression<Func<'T, 'D>> | null, [<Optional>] getDiscriminatorValue : Type -> 'D)
204-
=
168+
member filter.Apply<'T, 'D> (query : IQueryable<'T>, [<Optional>] options : ObjectListFilterLinqOptions<'T, 'D>) =
205169

206170
match filter with
207171
| NoFilter -> query
208172
| _ ->
209173
// Helper for discriminator comparison
210174
let buildTypeDiscriminatorCheck (param : SourceExpression) (t : Type) =
211-
match getDiscriminator, getDiscriminatorValue with
212-
| null, discValueFn when obj.Equals(discValueFn, null) ->
175+
match options.CompareDiscriminator, options.GetDiscriminatorValue with
176+
| ValueNone, ValueNone ->
213177
// use __typename from filter and do type.ToSting() for values
214178
let typename = t.FullName
215179
Expression.Equal(Expression.PropertyOrField(param, "__typename"), Expression.Constant(typename)) :> Expression
216-
| discExpr, discValueFn when obj.Equals(discValueFn, null) ->
180+
| ValueSome discExpr, ValueNone ->
217181
// use discriminator and do type.ToSting() for values
218182
let typename = t.FullName
219-
Expression.Equal(Expression.Invoke(discExpr, param), Expression.Constant(typename)) :> Expression
220-
| null, discValueFn ->
183+
Expression.Invoke(discExpr, param, Expression.Constant(typename)) :> Expression
184+
| ValueNone, ValueSome discValueFn ->
221185
// use __typename from filter and execute discValueFn for values
222186
let discriminatorValue = discValueFn t
223187
Expression.Equal(Expression.PropertyOrField(param, "__typename"), Expression.Constant(discriminatorValue)) :> Expression
224-
| discExpr, discValueFn ->
188+
| ValueSome discExpr, ValueSome discValueFn ->
225189
// use discriminator and execute discValueFn for values
226190
let discriminatorValue = discValueFn t
227-
Expression.Equal (Expression.Invoke(discExpr, param), Expression.Constant (discriminatorValue))
191+
Expression.Invoke(discExpr, param, Expression.Constant (discriminatorValue))
228192

229193
let queryExpr =
230194
let param = Expression.Parameter (typeof<'T>, "x")

src/FSharp.Data.GraphQL.Server.Middleware/TypeSystemExtensions.fs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,21 @@ module TypeSystemExtensions =
2323

2424
open ObjectListFilter.Operators
2525

26+
type ExecutionInfo with
27+
28+
member this.ResolveAbstractionFilter (typeMap : TypeMap) =
29+
match this.Kind with
30+
| ResolveAbstraction typeFields ->
31+
let getType name =
32+
match typeMap.TryFind name with
33+
| ValueSome tdef -> tdef.Type
34+
| ValueNone -> raise (MalformedGQLQueryException ($"Type '{name}' not found in schema."))
35+
match typeFields.Keys |> Seq.map getType |> Seq.toList with
36+
| [] -> ValueNone
37+
| filters -> ValueSome (OfTypes filters)
38+
| _ -> ValueNone
39+
40+
2641
type ResolveFieldContext with
2742

2843
/// <summary>
@@ -32,16 +47,9 @@ module TypeSystemExtensions =
3247
member this.Filter =
3348
match this.Args.TryGetValue "filter" with
3449
| true, (:? ObjectListFilter as f) ->
35-
match this.ExecutionInfo.Kind with
36-
| ResolveAbstraction typeFields ->
37-
let getType name =
38-
match this.Context.Schema.TypeMap.TryFind name with
39-
| ValueSome tdef -> tdef.Type
40-
| ValueNone -> raise (MalformedGQLQueryException ($"Type '{name}' not found in schema."))
41-
match typeFields.Keys |> Seq.map getType |> Seq.toList with
42-
| [] -> ValueNone
43-
| filters -> ValueSome (f &&& (OfTypes filters))
44-
| _ -> ValueSome f
50+
match this.ExecutionInfo.ResolveAbstractionFilter (this.Context.Schema.TypeMap) with
51+
| ValueSome ofTypes -> ValueSome (ofTypes &&& f)
52+
| ValueNone -> ValueSome f
4553
| false, _ -> ValueNone
4654
| true, _ -> raise (InvalidOperationException "Invalid filter argument type.")
4755

tests/FSharp.Data.GraphQL.Tests/LinqTests.fs

Lines changed: 56 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -453,11 +453,12 @@ type Property =
453453
| Building of Building
454454
| Community of Community
455455

456+
456457
[<Fact>]
457458
let ``ObjectListFilter works with getDiscriminator for Complex``() =
458459
let propertyData: Property list =
459460
[
460-
Complex { ID = 1; Name = "Complex A"; Discriminator = typeof<Complex>.FullName}
461+
Complex { ID = 1; Name = "Complex A"; Discriminator = typeof<Complex>.FullName }
461462
Building { ID = 2; Name = "Building B"; Discriminator = typeof<Building>.FullName }
462463
Community { ID = 3; Name = "Community C"; Discriminator = typeof<Community>.FullName; Complexes = [1]; Buildings = [2] }
463464
Complex { ID = 4; Name = "Complex AA"; Discriminator = typeof<Complex>.FullName }
@@ -466,17 +467,13 @@ let ``ObjectListFilter works with getDiscriminator for Complex``() =
466467
]
467468
let queryable = propertyData.AsQueryable()
468469
let filter = OfTypes [typeof<Complex>]
469-
let filteredData =
470-
filter.Apply(
471-
queryable,
472-
getDiscriminator =
473-
fun p ->
474-
match p with
475-
| Complex c -> c.Discriminator
476-
| Building b -> b.Discriminator
477-
| Community c -> c.Discriminator
478-
)
479-
|> Seq.toList
470+
let options =
471+
ObjectListFilterLinqOptions(
472+
(function
473+
| Complex c -> c.Discriminator
474+
| Building b -> b.Discriminator
475+
| Community c -> c.Discriminator))
476+
let filteredData = filter.Apply(queryable,options) |> Seq.toList
480477
List.length filteredData |> equals 2
481478
let result1 = List.head filteredData
482479
match result1 with
@@ -496,30 +493,28 @@ let ``ObjectListFilter works with getDiscriminator for Complex``() =
496493
let ``ObjectListFilter works with getDiscriminator and getDiscriminatorValue for Complex``() =
497494
let propertyData: Property list =
498495
[
499-
Complex { ID = 1; Name = "Complex A"; Discriminator = "Complex" }
500-
Building { ID = 2; Name = "Building B"; Discriminator = "Building" }
501-
Community { ID = 3; Name = "Community C"; Discriminator = "Community"; Complexes = [1]; Buildings = [2] }
502-
Complex { ID = 4; Name = "Complex AA"; Discriminator = "Complex" }
503-
Building { ID = 5; Name = "Building BB"; Discriminator = "Building" }
504-
Community { ID = 6; Name = "Community CC"; Discriminator = "Community"; Complexes = [4]; Buildings = [5] }
496+
Complex { ID = 1; Name = "Complex A"; Discriminator = typeof<Complex>.Name}
497+
Building { ID = 2; Name = "Building B"; Discriminator = typeof<Building>.Name }
498+
Community { ID = 3; Name = "Community C"; Discriminator = typeof<Community>.Name; Complexes = [1]; Buildings = [2] }
499+
Complex { ID = 4; Name = "Complex AA"; Discriminator = typeof<Complex>.Name }
500+
Building { ID = 5; Name = "Building BB"; Discriminator = typeof<Building>.Name }
501+
Community { ID = 6; Name = "Community CC"; Discriminator = typeof<Community>.Name; Complexes = [4]; Buildings = [5] }
505502
]
506503
let queryable = propertyData.AsQueryable()
507504
let filter = OfTypes [typeof<Complex>]
508-
let filteredData =
509-
filter.Apply(
510-
queryable,
511-
(fun p ->
512-
match p with
513-
| Complex c -> c.Discriminator
514-
| Building b -> b.Discriminator
515-
| Community c -> c.Discriminator),
505+
let options =
506+
ObjectListFilterLinqOptions(
507+
(function
508+
| Complex c -> c.Discriminator
509+
| Building b -> b.Discriminator
510+
| Community c -> c.Discriminator),
516511
(function
517512
| t when t = typeof<Complex> -> "Complex"
518513
| t when t = typeof<Building> -> "Building"
519514
| t when t = typeof<Community> -> "Community"
520515
| _ -> raise (NotSupportedException "Type not supported"))
521516
)
522-
|> Seq.toList
517+
let filteredData = filter.Apply(queryable,options) |> Seq.toList
523518
List.length filteredData |> equals 2
524519
let result1 = List.head filteredData
525520
match result1 with
@@ -534,37 +529,45 @@ let ``ObjectListFilter works with getDiscriminator and getDiscriminatorValue for
534529
c.Name |> equals "Complex AA"
535530
| _ -> failwith "Expected Complex"
536531

532+
type Cow =
533+
{ ID : int
534+
Name : string
535+
__typename : string }
537536

537+
type Horse =
538+
{ ID : int
539+
Name : string
540+
__typename : string }
541+
542+
let animalData =
543+
[
544+
{ ID = 1; Name = "Cow A"; __typename = typeof<Cow>.Name }
545+
{ ID = 2; Name = "Horse B"; __typename = typeof<Horse>.Name }
546+
{ ID = 3; Name = "Cow C"; __typename = typeof<Cow>.Name }
547+
{ ID = 4; Name = "Horse D"; __typename = typeof<Horse>.Name }
548+
]
538549

539550
[<Fact>]
540-
let ``ObjectListFilter works with getDiscriminatorValue for Complex``() =
541-
let propertyData: Property list =
542-
[
543-
Complex { ID = 1; Name = "Complex A"; Discriminator = typeof<Complex>.FullName}
544-
Building { ID = 2; Name = "Building B"; Discriminator = typeof<Building>.FullName }
545-
Community { ID = 3; Name = "Community C"; Discriminator = typeof<Community>.FullName; Complexes = [1]; Buildings = [2] }
546-
Complex { ID = 4; Name = "Complex AA"; Discriminator = typeof<Complex>.FullName }
547-
Building { ID = 5; Name = "Building BB"; Discriminator = typeof<Building>.FullName }
548-
Community { ID = 6; Name = "Community CC"; Discriminator = typeof<Community>.FullName; Complexes = [4]; Buildings = [5] }
549-
]
550-
let queryable = propertyData.AsQueryable()
551-
let filter = OfTypes [typeof<Complex>]
552-
let filteredData =
553-
filter.Apply(
554-
queryable,
555-
getDiscriminatorValue = (fun t -> t.FullName)
556-
)
557-
|> Seq.toList
551+
let ``ObjectListFilter works with getDiscriminatorValue for Horse``() =
552+
let queryable = animalData.AsQueryable()
553+
let filter = OfTypes [typeof<Horse>]
554+
let options =
555+
ObjectListFilterLinqOptions(
556+
getDiscriminatorValue = (function
557+
| t when t = typeof<Cow> -> t.Name
558+
| t when t = typeof<Horse> -> t.Name
559+
| _ -> raise (NotSupportedException "Type not supported"))
560+
)
561+
let filteredData = filter.Apply(queryable, options) |> Seq.toList
558562
List.length filteredData |> equals 2
559563
let result1 = List.head filteredData
560564
match result1 with
561-
| Complex c ->
562-
c.ID |> equals 1
563-
c.Name |> equals "Complex A"
564-
| _ -> failwith "Expected Complex"
565+
| h ->
566+
h.ID |> equals 2
567+
h.Name |> equals "Horse B"
568+
| _ -> failwith "Expected Horse"
565569
let result2 = List.last filteredData
566570
match result2 with
567-
| Complex c ->
568-
c.ID |> equals 4
569-
c.Name |> equals "Complex AA"
570-
| _ -> failwith "Expected Complex"
571+
| h ->
572+
h.ID |> equals 4
573+
h.Name |> equals "Horse D"

0 commit comments

Comments
 (0)