Skip to content

Commit 26ddf8e

Browse files
franklinicclaude
andcommitted
fix: generator emits params in RegisterTrainableParameter order
The generator previously emitted GetTrainableParameters in [attributes-first, register-discovered] order, but _registeredTensors in LayerBase stores them in constructor registration order. base.SetTrainableParameters assigns by position, so mismatched ordering caused view tensors to be assigned to wrong fields — swapping shapes between weights and biases. Now when RegisterTrainableParameter calls are found, they define the canonical parameter order, replacing the attribute-discovered order. This ensures generated order matches _registeredTensors order exactly. Fixes MemoryNetwork rank-1 crash, SPLADE LayerNorm count mismatch, and other parameter buffer view replacement bugs. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9dba161 commit 26ddf8e

File tree

1 file changed

+31
-20
lines changed

1 file changed

+31
-20
lines changed

src/AiDotNet.Generators/TrainableParameterGenerator.cs

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -119,30 +119,41 @@ private static void Execute(Compilation compilation, ImmutableArray<ClassDeclara
119119
}
120120
}
121121

122-
// Discover trainable parameters from RegisterTrainableParameter() calls
123-
// in the class body. This handles layers that use the registration API but haven't
124-
// been fully annotated with [TrainableParameter]. When the attribute count is less
125-
// than the registration count, the registration-based discovery takes precedence
126-
// to ensure all trainable fields are included in the tape autodiff system.
122+
// Discover trainable parameters from RegisterTrainableParameter() calls.
123+
// Registration order is the canonical order — it matches _registeredTensors
124+
// in LayerBase, so base.SetTrainableParameters positional assignment works.
125+
// If RegisterTrainableParameter calls exist, they REPLACE attribute-discovered
126+
// params to ensure correct ordering (attributes may be in declaration order
127+
// which differs from registration order).
127128
{
128-
// Always run discovery to compare counts
129129
var registeredFields = DiscoverFromRegisterCalls(classDecl, model, "RegisterTrainableParameter");
130-
var existingFieldNames = new HashSet<string>(paramFields.Select(p => p.Name));
131-
foreach (var (fieldName, role) in registeredFields)
130+
if (registeredFields.Count > 0)
132131
{
133-
// Skip fields already discovered via [TrainableParameter] attributes
134-
if (existingFieldNames.Contains(fieldName)) continue;
135-
136-
// Verify the field exists, is a Tensor<T>, and is non-nullable
137-
var matchingField = classSymbol.GetMembers()
138-
.OfType<IFieldSymbol>()
139-
.FirstOrDefault(f => f.Name == fieldName && IsTensorType(f.Type)
140-
&& f.NullableAnnotation != NullableAnnotation.Annotated
141-
&& f.Type.NullableAnnotation != NullableAnnotation.Annotated);
142-
if (matchingField is not null)
132+
// Build attribute-discovered roles map for enrichment
133+
var attrRoles = new Dictionary<string, string>();
134+
foreach (var pf in paramFields)
135+
attrRoles[pf.Name] = pf.Role;
136+
137+
// Replace paramFields with registration-ordered list
138+
paramFields.Clear();
139+
var seen = new HashSet<string>();
140+
foreach (var (fieldName, role) in registeredFields)
143141
{
144-
paramFields.Add(new ParameterFieldInfo(matchingField.Name, role, paramFields.Count));
145-
existingFieldNames.Add(fieldName);
142+
if (!seen.Add(fieldName)) continue;
143+
144+
// Verify the field exists, is a Tensor<T>, and is non-nullable
145+
var matchingField = classSymbol.GetMembers()
146+
.OfType<IFieldSymbol>()
147+
.FirstOrDefault(f => f.Name == fieldName && IsTensorType(f.Type)
148+
&& f.NullableAnnotation != NullableAnnotation.Annotated
149+
&& f.Type.NullableAnnotation != NullableAnnotation.Annotated);
150+
if (matchingField is not null)
151+
{
152+
// Prefer attribute role if available (more specific)
153+
var finalRole = attrRoles.TryGetValue(fieldName, out var attrRole)
154+
? attrRole : role;
155+
paramFields.Add(new ParameterFieldInfo(matchingField.Name, finalRole, paramFields.Count));
156+
}
146157
}
147158
}
148159
}

0 commit comments

Comments
 (0)