@@ -391,7 +391,7 @@ internal void Generate(in GenerateState ctx)
391391
392392 if ( flags . HasAny ( OperationFlags . GetRowParser ) )
393393 {
394- WriteGetRowParser ( sb , resultType , readers ) ;
394+ WriteGetRowParser ( sb , resultType , readers , grp . Key . AdditionalCommandState ? . StrictBind ?? default ) ;
395395 }
396396 else if ( ! TryWriteMultiExecImplementation ( sb , flags , commandTypeMode , parameterType , grp . Key . ParameterMap , grp . Key . UniqueLocation is not null , methodParameters , factories , fixedSql , additionalCommandState ) )
397397 {
@@ -449,9 +449,9 @@ internal void Generate(in GenerateState ctx)
449449 }
450450 sb . NewLine ( ) ;
451451
452- foreach ( var pair in readers )
452+ foreach ( var tuple in readers )
453453 {
454- WriteRowFactory ( ctx , sb , pair . Type , pair . Index ) ;
454+ WriteRowFactory ( ctx , sb , tuple . Type , tuple . Index , tuple . StrictBind , null /* TODO */ ) ;
455455 }
456456
457457 foreach ( var tuple in factories )
@@ -468,9 +468,9 @@ internal void Generate(in GenerateState ctx)
468468 ctx . ReportDiagnostic ( Diagnostic . Create ( Diagnostics . InterceptorsGenerated , null , callSiteCount , ctx . Nodes . Length , methodIndex , factories . Count ( ) , readers . Count ( ) ) ) ;
469469 }
470470
471- private static void WriteGetRowParser ( CodeWriter sb , ITypeSymbol ? resultType , in RowReaderState readers )
471+ private static void WriteGetRowParser ( CodeWriter sb , ITypeSymbol ? resultType , in RowReaderState readers , ImmutableArray < string > strictBind )
472472 {
473- sb . Append ( "return " ) . AppendReader ( resultType , readers )
473+ sb . Append ( "return " ) . AppendReader ( resultType , readers , strictBind )
474474 . Append ( ".GetRowParser(reader, startIndex, length, returnNullIfFirstMissing);" ) . NewLine ( ) ;
475475 }
476476
@@ -732,7 +732,7 @@ static bool IsReserved(string name)
732732 }
733733 }
734734
735- private static void WriteRowFactory ( in GenerateState context , CodeWriter sb , ITypeSymbol type , int index )
735+ private static void WriteRowFactory ( in GenerateState context , CodeWriter sb , ITypeSymbol type , int index , ImmutableArray < string > strictBind , Location ? location )
736736 {
737737 var map = MemberMap . CreateForResults ( type ) ;
738738 if ( map is null ) return ;
@@ -761,7 +761,7 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy
761761 WriteRowFactoryHeader ( ) ;
762762
763763 WriteTokenizeMethod ( ) ;
764- WriteReadMethod ( ) ;
764+ WriteReadMethod ( context ) ;
765765
766766 WriteRowFactoryFooter ( ) ;
767767
@@ -780,31 +780,57 @@ void WriteRowFactoryFooter()
780780 void WriteTokenizeMethod ( )
781781 {
782782 sb . Append ( "public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span<int> tokens, int columnOffset)" ) . Indent ( ) . NewLine ( ) ;
783- sb . Append ( "for (int i = 0; i < tokens.Length; i++)" ) . Indent ( ) . NewLine ( )
784- . Append ( "int token = -1;" ) . NewLine ( )
785- . Append ( "var name = reader.GetName(columnOffset);" ) . NewLine ( )
786- . Append ( "var type = reader.GetFieldType(columnOffset);" ) . NewLine ( )
787- . Append ( "switch (NormalizedHash(name))" ) . Indent ( ) . NewLine ( ) ;
783+ if ( strictBind . IsDefault ) // don't emit any tokens for strict binding
784+ {
785+ sb . Append ( "for (int i = 0; i < tokens.Length; i++)" ) . Indent ( ) . NewLine ( )
786+ . Append ( "int token = -1;" ) . NewLine ( )
787+ . Append ( "var name = reader.GetName(columnOffset);" ) . NewLine ( )
788+ . Append ( "var type = reader.GetFieldType(columnOffset);" ) . NewLine ( )
789+ . Append ( "switch (NormalizedHash(name))" ) . Indent ( ) . NewLine ( ) ;
788790
789- int token = 0 ;
790- foreach ( var member in members )
791- {
792- var dbName = member . DbName ;
793- sb . Append ( "case " ) . Append ( StringHashing . NormalizedHash ( dbName ) )
794- . Append ( " when NormalizedEquals(name, " )
795- . AppendVerbatimLiteral ( StringHashing . Normalize ( dbName ) ) . Append ( "):" ) . Indent ( false ) . NewLine ( )
796- . Append ( "token = type == typeof(" ) . Append ( Inspection . MakeNonNullable ( member . CodeType ) ) . Append ( ") ? " ) . Append ( token )
797- . Append ( " : " ) . Append ( token + map . Members . Length ) . Append ( ";" )
798- . Append ( token == 0 ? " // two tokens for right-typed and type-flexible" : "" ) . NewLine ( )
799- . Append ( "break;" ) . Outdent ( false ) . NewLine ( ) ;
800- token ++ ;
791+ int token = 0 ;
792+ foreach ( var member in members )
793+ {
794+ var dbName = member . DbName ;
795+ sb . Append ( "case " ) . Append ( StringHashing . NormalizedHash ( dbName ) )
796+ . Append ( " when NormalizedEquals(name, " )
797+ . AppendVerbatimLiteral ( StringHashing . Normalize ( dbName ) ) . Append ( "):" ) . Indent ( false ) . NewLine ( )
798+ . Append ( "token = type == typeof(" ) . Append ( Inspection . MakeNonNullable ( member . CodeType ) ) . Append ( ") ? " ) . Append ( token )
799+ . Append ( " : " ) . Append ( token + map . Members . Length ) . Append ( ";" )
800+ . Append ( token == 0 ? " // two tokens for right-typed and type-flexible" : "" ) . NewLine ( )
801+ . Append ( "break;" ) . Outdent ( false ) . NewLine ( ) ;
802+ token ++ ;
803+ }
804+ sb . Outdent ( ) . NewLine ( )
805+ . Append ( "tokens[i] = token;" ) . NewLine ( )
806+ . Append ( "columnOffset++;" ) . NewLine ( )
807+ . Outdent ( ) . NewLine ( ) ;
801808 }
802- sb . Outdent ( ) . NewLine ( )
803- . Append ( "tokens[i] = token;" ) . NewLine ( )
804- . Append ( "columnOffset++;" ) . NewLine ( ) ;
805- sb . Outdent ( ) . NewLine ( ) . Append ( "return null;" ) . Outdent ( ) . NewLine ( ) ;
809+ else
810+ {
811+ sb . Append ( "// strict-bind: " ) ;
812+ for ( int i = 0 ; i < strictBind . Length ; i ++ )
813+ {
814+ if ( i != 0 ) sb . Append ( ", " ) ;
815+ var name = strictBind [ i ] ;
816+ if ( string . IsNullOrWhiteSpace ( name ) )
817+ {
818+ sb . Append ( "(n/a)" ) ;
819+ }
820+ else if ( CompiledRegex . SimpleName . IsMatch ( name ) )
821+ {
822+ sb . Append ( name ) ;
823+ }
824+ else
825+ {
826+ sb . Append ( "'" ) . Append ( name ) . Append ( "'" ) ;
827+ }
828+ }
829+ sb . NewLine ( ) . Append ( "global::System.Diagnostics.Debug.Assert(tokens.Length == " ) . Append ( strictBind . Length ) . Append ( """, "Strict-bind column count mismatch");""" ) . NewLine ( ) ;
830+ }
831+ sb . Append ( "return null;" ) . Outdent ( ) . NewLine ( ) ;
806832 }
807- void WriteReadMethod ( )
833+ void WriteReadMethod ( in GenerateState context )
808834 {
809835 const string DeferredConstructionVariableName = "value" ;
810836
@@ -852,47 +878,63 @@ void WriteReadMethod()
852878 ? type . WithNullableAnnotation ( NullableAnnotation . None ) : type ) . Append ( " result = new();" ) . NewLine ( ) ;
853879 }
854880
855- sb . Append ( "foreach (var token in tokens)" ) . Indent ( ) . NewLine ( )
856- . Append ( "switch (token)" ) . Indent ( ) . NewLine ( ) ;
881+ ImmutableArray < ElementMember > readMembers ;
882+ if ( strictBind . IsDefault )
883+ {
884+ readMembers = members ; // try to parse everything
885+ sb . Append ( "foreach (var token in tokens)" ) ;
886+ }
887+ else
888+ {
889+ readMembers = MapStrictBind ( context , members , strictBind , location ) ;
890+ sb . Append ( "for (int token = 0; token < tokens.Length; token++) // strict-bind" ) ;
891+ }
892+ sb . Indent ( ) . NewLine ( ) . Append ( "switch (token)" ) . Indent ( ) . NewLine ( ) ;
857893
858894 token = 0 ;
859- foreach ( var member in members )
895+ foreach ( var member in readMembers )
860896 {
861- var memberType = member . CodeType ;
897+ if ( member . Member is not null ) // exclude non-mapped bindings
898+ {
899+ var memberType = member . CodeType ;
862900
863- member . GetDbType ( out var readerMethod ) ;
864- var nullCheck = Inspection . CouldBeNullable ( memberType ) ? $ "reader.IsDBNull(columnOffset) ? ({ CodeWriter . GetTypeName ( memberType . WithNullableAnnotation ( NullableAnnotation . Annotated ) ) } )null : " : "" ;
865- sb . Append ( "case " ) . Append ( token ) . Append ( ":" ) . NewLine ( ) . Indent ( false ) ;
901+ member . GetDbType ( out var readerMethod ) ;
902+ var nullCheck = Inspection . CouldBeNullable ( memberType ) ? $ "reader.IsDBNull(columnOffset) ? ({ CodeWriter . GetTypeName ( memberType . WithNullableAnnotation ( NullableAnnotation . Annotated ) ) } )null : " : "" ;
903+ sb . Append ( "case " ) . Append ( token ) . Append ( ":" ) . NewLine ( ) . Indent ( false ) ;
866904
867- // write `result.X = ` or `member0 = `
868- if ( useDeferredConstruction ) sb . Append ( DeferredConstructionVariableName ) . Append ( token ) ;
869- else sb . Append ( "result." ) . Append ( member . CodeName ) ;
870- sb . Append ( " = " ) ;
905+ // write `result.X = ` or `member0 = `
906+ if ( useDeferredConstruction ) sb . Append ( DeferredConstructionVariableName ) . Append ( token ) ;
907+ else sb . Append ( "result." ) . Append ( member . CodeName ) ;
908+ sb . Append ( " = " ) ;
871909
872- sb . Append ( nullCheck ) ;
873- if ( readerMethod is null )
874- {
875- sb . Append ( "reader.GetFieldValue<" ) . Append ( memberType ) . Append ( ">(columnOffset);" ) ;
876- }
877- else
878- {
879- sb . Append ( "reader." ) . Append ( readerMethod ) . Append ( "(columnOffset);" ) ;
880- }
910+ sb . Append ( nullCheck ) ;
911+ if ( readerMethod is null )
912+ {
913+ sb . Append ( "reader.GetFieldValue<" ) . Append ( memberType ) . Append ( ">(columnOffset);" ) ;
914+ }
915+ else
916+ {
917+ sb . Append ( "reader." ) . Append ( readerMethod ) . Append ( "(columnOffset);" ) ;
918+ }
881919
882920
883- sb . NewLine ( ) . Append ( "break;" ) . NewLine ( ) . Outdent ( false )
884- . Append ( "case " ) . Append ( token + map . Members . Length ) . Append ( ":" ) . NewLine ( ) . Indent ( false ) ;
921+ sb . NewLine ( ) . Append ( "break;" ) . NewLine ( ) . Outdent ( false ) ;
885922
886- // write `result.X = ` or `member0 = `
887- if ( useDeferredConstruction ) sb . Append ( DeferredConstructionVariableName ) . Append ( token ) ;
888- else sb . Append ( "result. " ) . Append ( member . CodeName ) ;
923+ if ( strictBind . IsDefault ) // type-forgiving version; only emitted when not using strict-bind
924+ {
925+ sb . Append ( "case " ) . Append ( token + map . Members . Length ) . Append ( ":" ) . NewLine ( ) . Indent ( false ) ;
889926
890- sb . Append ( " = " )
891- . Append ( nullCheck )
892- . Append ( "GetValue<" )
893- . Append ( Inspection . MakeNonNullable ( memberType ) ) . Append ( ">(reader, columnOffset);" ) . NewLine ( )
894- . Append ( "break;" ) . NewLine ( ) . Outdent ( false ) ;
927+ // write `result.X = ` or `member0 = `
928+ if ( useDeferredConstruction ) sb . Append ( DeferredConstructionVariableName ) . Append ( token ) ;
929+ else sb . Append ( "result." ) . Append ( member . CodeName ) ;
895930
931+ sb . Append ( " = " )
932+ . Append ( nullCheck )
933+ . Append ( "GetValue<" )
934+ . Append ( Inspection . MakeNonNullable ( memberType ) ) . Append ( ">(reader, columnOffset);" ) . NewLine ( )
935+ . Append ( "break;" ) . NewLine ( ) . Outdent ( false ) ;
936+ }
937+ }
896938 token ++ ;
897939 }
898940
@@ -978,6 +1020,46 @@ void WriteDeferredMethodArgs()
9781020 }
9791021 }
9801022
1023+ private static ImmutableArray < ElementMember > MapStrictBind ( in GenerateState state , ImmutableArray < ElementMember > members , ImmutableArray < string > strictBind , Location ? location )
1024+ {
1025+ if ( strictBind . IsDefault ) return members ; // not bound
1026+
1027+ var result = ImmutableArray . CreateBuilder < ElementMember > ( strictBind . Length ) ;
1028+ foreach ( var seek in strictBind )
1029+ {
1030+ ElementMember found = default ;
1031+ if ( ! string . IsNullOrWhiteSpace ( seek ) )
1032+ {
1033+ foreach ( var member in members )
1034+ {
1035+ if ( member . CodeName == seek )
1036+ {
1037+ found = member ;
1038+ break ;
1039+ }
1040+ }
1041+ if ( found . Member is null )
1042+ {
1043+ var normalizedSeek = StringHashing . Normalize ( seek ) ;
1044+ foreach ( var member in members )
1045+ {
1046+ if ( StringHashing . NormalizedEquals ( member . CodeName , normalizedSeek ) )
1047+ {
1048+ found = member ;
1049+ break ;
1050+ }
1051+ }
1052+ }
1053+ if ( found . Member is null )
1054+ {
1055+ state . ReportDiagnostic ( Diagnostic . Create ( DapperAnalyzer . Diagnostics . BoundMemberNotFound , location , seek ) ) ;
1056+ }
1057+ }
1058+ result . Add ( found ) ;
1059+ }
1060+ return result . ToImmutable ( ) ;
1061+ }
1062+
9811063 [ Flags ]
9821064 enum WriteArgsFlags
9831065 {
0 commit comments