Skip to content

Commit a5f219d

Browse files
committed
Add translation for string methods with char arguments
Enable SQL translation for string.IndexOf, Replace, StartsWith, EndsWith, and Contains when called with char arguments. Update translators and type mapping to support char overloads, and implement corresponding tests to verify correct SQL generation.
1 parent f4c1183 commit a5f219d

File tree

3 files changed

+143
-25
lines changed

3 files changed

+143
-25
lines changed

src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator
3939
private static readonly MethodInfo Replace = typeof(string).GetRuntimeMethod(
4040
nameof(string.Replace), [typeof(string), typeof(string)])!;
4141

42+
private static readonly MethodInfo Replace_Char = typeof(string).GetRuntimeMethod(
43+
nameof(string.Replace), [typeof(char), typeof(char)])!;
44+
4245
private static readonly MethodInfo Substring = typeof(string).GetTypeInfo().GetDeclaredMethods(nameof(string.Substring))
4346
.Single(m => m.GetParameters().Length == 1);
4447

@@ -204,7 +207,7 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I
204207
{
205208
var argument = arguments[0];
206209
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, argument);
207-
210+
argument = _sqlExpressionFactory.ApplyTypeMapping(argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
208211
return _sqlExpressionFactory.Subtract(
209212
_sqlExpressionFactory.Function(
210213
"strpos",
@@ -218,12 +221,15 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I
218221
_sqlExpressionFactory.Constant(1));
219222
}
220223

221-
if (method == Replace)
224+
if (method == Replace || method == Replace_Char)
222225
{
223226
var oldValue = arguments[0];
224227
var newValue = arguments[1];
225228
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, oldValue, newValue);
226229

230+
oldValue = _sqlExpressionFactory.ApplyTypeMapping(oldValue, oldValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
231+
newValue = _sqlExpressionFactory.ApplyTypeMapping(newValue, newValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
232+
227233
return _sqlExpressionFactory.Function(
228234
"replace",
229235
[

src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,21 @@ public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExp
4545
private static readonly MethodInfo StringStartsWithMethod
4646
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!;
4747

48+
private static readonly MethodInfo StringStartsWithMethodChar
49+
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!;
50+
4851
private static readonly MethodInfo StringEndsWithMethod
4952
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!;
5053

54+
private static readonly MethodInfo StringEndsWithMethodChar
55+
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!;
56+
5157
private static readonly MethodInfo StringContainsMethod
5258
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!;
5359

60+
private static readonly MethodInfo StringContainsMethodChar
61+
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!;
62+
5463
private static readonly MethodInfo EscapeLikePatternParameterMethod =
5564
typeof(NpgsqlSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!;
5665

@@ -405,21 +414,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
405414
return TranslateCubeToSubset(sqlCubeInstance, sqlIndexes) ?? QueryCompilationContext.NotTranslatedExpression;
406415
}
407416

408-
if (method == StringStartsWithMethod
417+
if ((method == StringStartsWithMethod || method == StringStartsWithMethodChar)
409418
&& TryTranslateStartsEndsWithContains(
410419
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.StartsWith, out var translation1))
411420
{
412421
return translation1;
413422
}
414423

415-
if (method == StringEndsWithMethod
424+
if ((method == StringEndsWithMethod || method == StringEndsWithMethodChar)
416425
&& TryTranslateStartsEndsWithContains(
417426
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.EndsWith, out var translation2))
418427
{
419428
return translation2;
420429
}
421430

422-
if (method == StringContainsMethod
431+
if ((method == StringContainsMethod || method == StringContainsMethodChar)
423432
&& TryTranslateStartsEndsWithContains(
424433
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.Contains, out var translation3))
425434
{
@@ -719,6 +728,32 @@ private bool TryTranslateStartsEndsWithContains(
719728
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
720729
})),
721730

731+
char s when !IsLikeWildChar(s)
732+
=> _sqlExpressionFactory.Like(
733+
translatedInstance,
734+
_sqlExpressionFactory.Constant(
735+
methodType switch
736+
{
737+
StartsEndsWithContains.StartsWith => s + "%",
738+
StartsEndsWithContains.EndsWith => "%" + s,
739+
StartsEndsWithContains.Contains => $"%{s}%",
740+
741+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
742+
})),
743+
744+
char s => _sqlExpressionFactory.Like(
745+
translatedInstance,
746+
_sqlExpressionFactory.Constant(
747+
methodType switch
748+
{
749+
StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%",
750+
StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s,
751+
StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%",
752+
753+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
754+
}),
755+
_sqlExpressionFactory.Constant(LikeEscapeChar)),
756+
722757
_ => throw new UnreachableException()
723758
};
724759

@@ -834,6 +869,22 @@ private bool TryTranslateStartsEndsWithContains(
834869
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
835870
},
836871

872+
char s when !IsLikeWildChar(s) => methodType switch
873+
{
874+
StartsEndsWithContains.StartsWith => s + "%",
875+
StartsEndsWithContains.EndsWith => "%" + s,
876+
StartsEndsWithContains.Contains => $"%{s}%",
877+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
878+
},
879+
880+
char s => methodType switch
881+
{
882+
StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%",
883+
StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s,
884+
StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%",
885+
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
886+
},
887+
837888
_ => throw new UnreachableException()
838889
};
839890

test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,17 @@ WHERE strpos(b."String", 'eattl') - 1 <> -1
131131
""");
132132
}
133133

134-
// TODO: #3547
135-
public override Task IndexOf_Char()
136-
=> Assert.ThrowsAsync<InvalidCastException>(() => base.IndexOf_Char());
134+
public override async Task IndexOf_Char()
135+
{
136+
await base.IndexOf_Char();
137+
138+
AssertSql(
139+
"""
140+
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
141+
FROM "BasicTypesEntities" AS b
142+
WHERE strpos(b."String", 'e') - 1 <> -1
143+
""");
144+
}
137145

138146
public override async Task IndexOf_with_empty_string()
139147
{
@@ -231,9 +239,17 @@ WHERE replace(b."String", 'Sea', 'Rea') = 'Reattle'
231239
""");
232240
}
233241

234-
// TODO: #3547
235-
public override Task Replace_Char()
236-
=> AssertTranslationFailed(() => base.Replace_Char());
242+
public override async Task Replace_Char()
243+
{
244+
await base.Replace_Char();
245+
246+
AssertSql(
247+
"""
248+
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
249+
FROM "BasicTypesEntities" AS b
250+
WHERE replace(b."String", 'S', 'R') = 'Reattle'
251+
""");
252+
}
237253

238254
public override async Task Replace_with_empty_string()
239255
{
@@ -429,9 +445,17 @@ WHERE b."String" LIKE 'Se%'
429445
""");
430446
}
431447

432-
// TODO: #3547
433-
public override Task StartsWith_Literal_Char()
434-
=> AssertTranslationFailed(() => base.StartsWith_Literal_Char());
448+
public override async Task StartsWith_Literal_Char()
449+
{
450+
await base.StartsWith_Literal_Char();
451+
452+
AssertSql(
453+
"""
454+
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
455+
FROM "BasicTypesEntities" AS b
456+
WHERE b."String" LIKE 'S%'
457+
""");
458+
}
435459

436460
public override async Task StartsWith_Parameter()
437461
{
@@ -447,8 +471,19 @@ WHERE b."String" LIKE @pattern_startswith
447471
""");
448472
}
449473

450-
public override Task StartsWith_Parameter_Char()
451-
=> AssertTranslationFailed(() => base.StartsWith_Parameter_Char());
474+
public override async Task StartsWith_Parameter_Char()
475+
{
476+
await base.StartsWith_Parameter_Char();
477+
478+
AssertSql(
479+
"""
480+
@pattern_startswith='S%'
481+
482+
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
483+
FROM "BasicTypesEntities" AS b
484+
WHERE b."String" LIKE @pattern_startswith
485+
""");
486+
}
452487

453488
public override async Task StartsWith_Column()
454489
{
@@ -499,9 +534,17 @@ WHERE b."String" LIKE '%le'
499534
""");
500535
}
501536

502-
// TODO: #3547
503-
public override Task EndsWith_Literal_Char()
504-
=> AssertTranslationFailed(() => base.EndsWith_Literal_Char());
537+
public override async Task EndsWith_Literal_Char()
538+
{
539+
await base.EndsWith_Literal_Char();
540+
541+
AssertSql(
542+
"""
543+
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
544+
FROM "BasicTypesEntities" AS b
545+
WHERE b."String" LIKE '%e'
546+
""");
547+
}
505548

506549
public override async Task EndsWith_Parameter()
507550
{
@@ -517,9 +560,19 @@ WHERE b."String" LIKE @pattern_endswith
517560
""");
518561
}
519562

520-
// TODO: #3547
521-
public override Task EndsWith_Parameter_Char()
522-
=> AssertTranslationFailed(() => base.EndsWith_Parameter_Char());
563+
public override async Task EndsWith_Parameter_Char()
564+
{
565+
await base.EndsWith_Parameter_Char();
566+
567+
AssertSql(
568+
"""
569+
@pattern_endswith='%e'
570+
571+
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
572+
FROM "BasicTypesEntities" AS b
573+
WHERE b."String" LIKE @pattern_endswith
574+
""");
575+
}
523576

524577
public override async Task EndsWith_Column()
525578
{
@@ -575,9 +628,17 @@ WHERE b."String" LIKE '%eattl%'
575628
""");
576629
}
577630

578-
// TODO: #3547
579-
public override Task Contains_Literal_Char()
580-
=> AssertTranslationFailed(() => base.Contains_Literal_Char());
631+
public override async Task Contains_Literal_Char()
632+
{
633+
await base.Contains_Literal_Char();
634+
635+
AssertSql(
636+
"""
637+
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
638+
FROM "BasicTypesEntities" AS b
639+
WHERE b."String" LIKE '%e%'
640+
""");
641+
}
581642

582643
public override async Task Contains_Column()
583644
{

0 commit comments

Comments
 (0)