Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions src/OpenPolicyAgent.Ucast.Linq/QueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public static Type GetHigherPrecedenceNumericType(Type a, Type b)
throw new InvalidOperationException($"Cannot determine precedence between {a} and {b}");
}

public static bool IsGuidType(Type type)
{
return type == typeof(Guid);
}

/// <summary>
/// Builds a LINQ Lambda Expression from the UCAST tree, and then invokes
/// it under a LINQ Where expression on some queryable data source.<br />
Expand Down Expand Up @@ -105,14 +110,18 @@ private static Expression BuildFieldExpression<T>(UCASTNode node, ParameterExpre
/// </summary>
/// <typeparam name="T">The LINQ</typeparam>
/// <param name="node">Current UCAST node in the conditions tree.</param>
/// <param name="property">Derefered property lookup expression on the LINQ data source.</param>
/// <param name="property">Deferred property lookup expression on the LINQ data source.</param>
/// <param name="mapper">Dictionary mapping UCAST property names to lambdas that generate LINQ Expressions.</param>
/// <returns>Result, a LINQ Expression (Usually a BinaryExpression).</returns>
/// <returns>Result, a LINQ BinaryExpression.</returns>
/// <exception cref="ArgumentException">Thrown when arguments are of incompatible types.</exception>
private static Expression BuildFieldExpressionFromProperty<T>(UCASTNode node, Expression property, MappingConfiguration<T> mapper)
private static BinaryExpression BuildFieldExpressionFromProperty<T>(UCASTNode node, Expression property, MappingConfiguration<T> mapper)
{
Expression value = Expression.Constant(node.Value);

// If there is a type mismatch in an expression, it is usually from
// differing numeric types, or from things like a GUID vs String
// comparison. We try to make smart conversions here to ensure types
// are matched for the BinaryExpression result.
Type lhsType = property.Type;
Type rhsType = value.Type;
if (lhsType != rhsType)
Expand All @@ -130,6 +139,21 @@ private static Expression BuildFieldExpressionFromProperty<T>(UCASTNode node, Ex
value = Expression.Convert(value, exprType);
}
}
// Convert GUID strings automatically when the property is a Guid.
// Rego doesn't have native GUID typess, so it'll always be a
// GUID-formatted string that the policy is trying to match
// against the property.
else if (IsGuidType(lhsType) && rhsType == typeof(string))
{
if (Guid.TryParse(node.Value?.ToString(), out Guid guid))
{
value = Expression.Constant(guid);
}
else
{
throw new ArgumentException($"Expected a GUID-formatted string, but got '{node.Value}'");
}
}
}

// Switch expression:
Expand Down Expand Up @@ -164,7 +188,6 @@ private static Expression BuildFieldInExpression<T>(UCASTNode node, Expression p
var eq = new UCASTNode("field", "eq", node.Field);
var childValues = (List<object>)node.Value;


// Iterate over all children, and determine highest-precedent type among
// them. Convert LHS value if needed. RHS type conversions will happen
// automatically during expression building later.
Expand Down
15 changes: 10 additions & 5 deletions test/OpenPolicyAgent.Ucast.Linq.Tests/UnitTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public static IEnumerable<object[]> EqTestData()
yield return new object[] { new UCASTNode { Type = "field", Op = "eq", Field = "data.id", Value = (long)2 }, testdata.Where(d => d.Id == 2).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "eq", Field = "data.flood_stage", Value = true }, testdata.Where(d => d.FloodStage).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "eq", Field = "data.water_level_meters", Value = 5.8 }, testdata.Where(d => d.WaterLevelMeters == 5.8).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "eq", Field = "data.uuid", Value = "123e4567-e89b-12d3-a456-426614174000" }, testdata.Where(d => d.Uuid == new Guid("123e4567-e89b-12d3-a456-426614174000")).ToList() };
}

public static IEnumerable<object[]> NeTestData()
Expand All @@ -89,6 +90,7 @@ public static IEnumerable<object[]> NeTestData()
yield return new object[] { new UCASTNode { Type = "field", Op = "ne", Field = "data.id", Value = (long)2 }, testdata.Where(d => d.Id != 2).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "ne", Field = "data.flood_stage", Value = true }, testdata.Where(d => !d.FloodStage).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "ne", Field = "data.water_level_meters", Value = 5.8 }, testdata.Where(d => d.WaterLevelMeters != 5.8).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "ne", Field = "data.uuid", Value = "123e4567-e89b-12d3-a456-426614174000" }, testdata.Where(d => d.Uuid != new Guid("123e4567-e89b-12d3-a456-426614174000")).ToList() };
}

public static IEnumerable<object[]> GtTestData()
Expand Down Expand Up @@ -128,6 +130,7 @@ public static IEnumerable<object[]> InTestData()
yield return new object[] { new UCASTNode { Type = "field", Op = "in", Field = "data.id", Value = new List<object>() { (long)2, (long)5 } }, testdata.Where(d => new List<object>() { (long)2, (long)5 }.Contains((long)d.Id)).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "in", Field = "data.flood_stage", Value = new List<object>() { true } }, testdata.Where(d => new List<object>() { true }.Contains(d.FloodStage)).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "in", Field = "data.water_level_meters", Value = new List<object>() { 2.5, 5.8 } }, testdata.Where(d => new List<object>() { 2.5, 5.8 }.Contains(d.WaterLevelMeters)).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "in", Field = "data.uuid", Value = new List<object>() { "123e4567-e89b-12d3-a456-426614174000", "123e4567-e89b-12d3-a456-426614174001" } }, testdata.Where(d => new List<object>() { new Guid("123e4567-e89b-12d3-a456-426614174000"), new Guid("123e4567-e89b-12d3-a456-426614174001") }.Contains(d.Uuid)).ToList() };
}
}

Expand All @@ -140,6 +143,7 @@ public static IEnumerable<object[]> NinTestData()
yield return new object[] { new UCASTNode { Type = "field", Op = "nin", Field = "data.id", Value = new List<object>() { (long)2, (long)5 } }, testdata.Where(d => !new List<object>() { (long)2, (long)5 }.Contains((long)d.Id)).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "nin", Field = "data.flood_stage", Value = new List<object>() { true } }, testdata.Where(d => !new List<object>() { true }.Contains(d.FloodStage)).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "nin", Field = "data.water_level_meters", Value = new List<object>() { 2.5, 5.8 } }, testdata.Where(d => !new List<object>() { 2.5, 5.8 }.Contains(d.WaterLevelMeters)).ToList() };
yield return new object[] { new UCASTNode { Type = "field", Op = "nin", Field = "data.uuid", Value = new List<object>() { "123e4567-e89b-12d3-a456-426614174000", "123e4567-e89b-12d3-a456-426614174001" } }, testdata.Where(d => !new List<object>() { new Guid("123e4567-e89b-12d3-a456-426614174000"), new Guid("123e4567-e89b-12d3-a456-426614174001") }.Contains(d.Uuid)).ToList() };
}
}
}
Expand Down Expand Up @@ -503,6 +507,7 @@ public class UnitTestDataSource
public class HydrologyData
{
public int Id { get; set; }
public Guid Uuid { get; set; }
public string? Name { get; set; }
public DateTime LastUpdated { get; set; }
public bool FloodStage { get; set; }
Expand All @@ -513,11 +518,11 @@ public class HydrologyData
public static List<HydrologyData> GetTestHydrologyData()
{
return [
new HydrologyData { Id = 1, Name = "River Alpha", LastUpdated = new DateTime(2024, 12, 10, 8, 30, 0), FloodStage = false, WaterLevelMeters = 2.5, FlowRateMinute = 100.5 },
new HydrologyData { Id = 2, Name = "Lake Beta", LastUpdated = new DateTime(2024, 12, 9, 15, 45, 0), FloodStage = true, WaterLevelMeters = 5.8, FlowRateMinute = null },
new HydrologyData { Id = 3, Name = "Stream Gamma", LastUpdated = new DateTime(2024, 12, 8, 12, 0, 0), FloodStage = false, WaterLevelMeters = 0.75, FlowRateMinute = 25.3 },
new HydrologyData { Id = 4, Name = "Reservoir Delta", LastUpdated = new DateTime(2024, 12, 7, 9, 15, 0), FloodStage = false, WaterLevelMeters = 15.2, FlowRateMinute = 500.0 },
new HydrologyData { Id = 5, Name = null, LastUpdated = new DateTime(2024, 12, 6, 18, 30, 0), FloodStage = true, WaterLevelMeters = 3.1, FlowRateMinute = 75.8 }
new HydrologyData { Id = 1, Uuid = new Guid("123e4567-e89b-12d3-a456-426614174000"), Name = "River Alpha", LastUpdated = new DateTime(2024, 12, 10, 8, 30, 0), FloodStage = false, WaterLevelMeters = 2.5, FlowRateMinute = 100.5 },
new HydrologyData { Id = 2, Uuid = new Guid("123e4567-e89b-12d3-a456-426614174001"), Name = "Lake Beta", LastUpdated = new DateTime(2024, 12, 9, 15, 45, 0), FloodStage = true, WaterLevelMeters = 5.8, FlowRateMinute = null },
new HydrologyData { Id = 3, Uuid = new Guid("123e4567-e89b-12d3-a456-426614174002"), Name = "Stream Gamma", LastUpdated = new DateTime(2024, 12, 8, 12, 0, 0), FloodStage = false, WaterLevelMeters = 0.75, FlowRateMinute = 25.3 },
new HydrologyData { Id = 4, Uuid = new Guid("123e4567-e89b-12d3-a456-426614174003"), Name = "Reservoir Delta", LastUpdated = new DateTime(2024, 12, 7, 9, 15, 0), FloodStage = false, WaterLevelMeters = 15.2, FlowRateMinute = 500.0 },
new HydrologyData { Id = 5, Uuid = new Guid("123e4567-e89b-12d3-a456-426614174004"), Name = null, LastUpdated = new DateTime(2024, 12, 6, 18, 30, 0), FloodStage = true, WaterLevelMeters = 3.1, FlowRateMinute = 75.8 }
];
}

Expand Down
Loading