Skip to content

Commit 9dba161

Browse files
franklinic and claude committed
fix: generator deduplicates and merges attribute + registration discovery
TrainableParameterGenerator now ALWAYS runs RegisterTrainableParameter discovery (not just when attributes are missing) and merges results, deduplicating by field name. This fixes layers like LayerNormalization where only 1 of 2 fields had [TrainableParameter] — the generator now finds both _gamma and _beta. Also skips nullable fields from registration discovery to avoid CS8601 errors (e.g., DeformableConvolutionalLayer's optional mask). Fixed AudioVisualCorrespondence InitializeLayers build errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bd3c54b commit 9dba161

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/AiDotNet.Generators/TrainableParameterGenerator.cs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,24 +119,32 @@ private static void Execute(Compilation compilation, ImmutableArray<ClassDeclara
119119
}
120120
}
121121

122-
// Fallback: discover trainable parameters from RegisterTrainableParameter() calls
122+
// Discover trainable parameters from RegisterTrainableParameter() calls
123123
// in the class body. This handles layers that use the registration API but haven't
124-
// been annotated with [TrainableParameter] yet. Same pattern as sublayer discovery.
125-
if (paramFields.Count == 0)
124+
// been fully annotated with [TrainableParameter]. Fields found via registration are
125+
// merged with the attribute-discovered fields, deduplicated by field name,
126+
// to ensure all trainable fields are included in the tape autodiff system.
126127
{
128+
// Always run discovery so registered fields can be merged with attribute-discovered ones
127129
var registeredFields = DiscoverFromRegisterCalls(classDecl, model, "RegisterTrainableParameter");
130+
var existingFieldNames = new HashSet<string>(paramFields.Select(p => p.Name));
128131
foreach (var (fieldName, role) in registeredFields)
129132
{
130-
// Verify the field exists and is a Tensor<T>
133+
// Skip fields already discovered via [TrainableParameter] attributes
134+
if (existingFieldNames.Contains(fieldName)) continue;
135+
136+
// Verify the field exists, is a Tensor<T>, and is non-nullable
131137
var matchingField = classSymbol.GetMembers()
132138
.OfType<IFieldSymbol>()
133-
.FirstOrDefault(f => f.Name == fieldName && IsTensorType(f.Type));
139+
.FirstOrDefault(f => f.Name == fieldName && IsTensorType(f.Type)
140+
&& f.NullableAnnotation != NullableAnnotation.Annotated
141+
&& f.Type.NullableAnnotation != NullableAnnotation.Annotated);
134142
if (matchingField is not null)
135143
{
136144
paramFields.Add(new ParameterFieldInfo(matchingField.Name, role, paramFields.Count));
145+
existingFieldNames.Add(fieldName);
137146
}
138147
}
139-
140148
}
141149

142150
if (paramFields.Count == 0 && subLayerFields.Count == 0) continue;

0 commit comments

Comments
 (0)