Skip to content

Commit 6f8b5a3

Browse files
committed
fix: update TeacherModelWrapper calls and implement missing IFullModel members
1. Updated all TeacherModelWrapper instantiations to include output dimension: - Lines 50-51: Ensemble test case - Lines 114-115: MultiModal test case - Lines 214-215: Distributed test case - All now use: new TeacherModelWrapper<double>(model, outputDimension: 5) 2. Implemented missing IFullModel interface members in MockFullModel: - IFeatureAware: GetActiveFeatureIndices, SetActiveFeatureIndices, IsFeatureUsed - IFeatureImportance: GetFeatureImportance - ICloneable: DeepCopy, Clone - IGradientComputable: ComputeGradients, ApplyGradients - IParameterizable: GetParameters, SetParameters, ParameterCount, WithParameters - IModelSerializer: Serialize, Deserialize, SaveModel, LoadModel - DefaultLossFunction property All implementations are placeholder/mock implementations suitable for testing.
1 parent 33295e2 commit 6f8b5a3

File tree

1 file changed

+87
-6
lines changed

1 file changed

+87
-6
lines changed

tests/AiDotNet.Tests/UnitTests/KnowledgeDistillation/TeacherModelFactoryTests.cs

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ public void CreateTeacher_Ensemble_WithModels_ReturnsEnsembleTeacher()
4747
var model2 = new MockFullModel(inputDim: 10, outputDim: 5);
4848
var teachers = new[]
4949
{
50-
new TeacherModelWrapper<double>(model1),
51-
new TeacherModelWrapper<double>(model2)
50+
new TeacherModelWrapper<double>(model1, outputDimension: 5),
51+
new TeacherModelWrapper<double>(model2, outputDimension: 5)
5252
};
5353

5454
// Act
@@ -111,8 +111,8 @@ public void CreateTeacher_MultiModal_WithModels_ReturnsMultiModalTeacher()
111111
var model2 = new MockFullModel(inputDim: 10, outputDim: 5);
112112
var teachers = new[]
113113
{
114-
new TeacherModelWrapper<double>(model1),
115-
new TeacherModelWrapper<double>(model2)
114+
new TeacherModelWrapper<double>(model1, outputDimension: 5),
115+
new TeacherModelWrapper<double>(model2, outputDimension: 5)
116116
};
117117

118118
// Act
@@ -211,8 +211,8 @@ public void CreateTeacher_Distributed_WithModels_ReturnsDistributedTeacher()
211211
var model2 = new MockFullModel(inputDim: 10, outputDim: 5);
212212
var teachers = new[]
213213
{
214-
new TeacherModelWrapper<double>(model1),
215-
new TeacherModelWrapper<double>(model2)
214+
new TeacherModelWrapper<double>(model1, outputDimension: 5),
215+
new TeacherModelWrapper<double>(model2, outputDimension: 5)
216216
};
217217

218218
// Act
@@ -231,11 +231,16 @@ private class MockFullModel : IFullModel<double, Vector<double>, Vector<double>>
231231
private readonly int _inputDim;
232232
private readonly int _outputDim;
233233
private readonly Random _random = new Random(42);
234+
private readonly HashSet<int> _activeFeatures = new();
234235

235236
public MockFullModel(int inputDim, int outputDim)
236237
{
237238
_inputDim = inputDim;
238239
_outputDim = outputDim;
240+
241+
// Initialize with all features active
242+
for (int i = 0; i < inputDim; i++)
243+
_activeFeatures.Add(i);
239244
}
240245

241246
public Vector<double> Predict(Vector<double> input)
@@ -265,5 +270,81 @@ public void Train(Vector<double> input, Vector<double> target) { }
265270

266271
public void SaveState(Stream stream) { }
267272
public void LoadState(Stream stream) { }
273+
274+
public byte[] Serialize() => Array.Empty<byte>();
275+
public void Deserialize(byte[] data) { }
276+
public void SaveModel(string filePath) { }
277+
public void LoadModel(string filePath) { }
278+
279+
// IFeatureAware implementation
280+
public IEnumerable<int> GetActiveFeatureIndices() => _activeFeatures;
281+
282+
public void SetActiveFeatureIndices(IEnumerable<int> featureIndices)
283+
{
284+
_activeFeatures.Clear();
285+
foreach (var idx in featureIndices)
286+
_activeFeatures.Add(idx);
287+
}
288+
289+
public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex);
290+
291+
// IFeatureImportance implementation
292+
public Dictionary<string, double> GetFeatureImportance()
293+
{
294+
var importance = new Dictionary<string, double>();
295+
for (int i = 0; i < _inputDim; i++)
296+
importance[$"feature_{i}"] = 1.0 / _inputDim;
297+
return importance;
298+
}
299+
300+
// ICloneable implementation
301+
public IFullModel<double, Vector<double>, Vector<double>> DeepCopy()
302+
{
303+
var copy = new MockFullModel(_inputDim, _outputDim);
304+
copy.SetActiveFeatureIndices(_activeFeatures);
305+
return copy;
306+
}
307+
308+
public IFullModel<double, Vector<double>, Vector<double>> Clone()
309+
{
310+
return new MockFullModel(_inputDim, _outputDim);
311+
}
312+
313+
// IGradientComputable implementation
314+
public Vector<double> ComputeGradients(Vector<double> input, Vector<double> target, ILossFunction<double>? lossFunction = null)
315+
{
316+
// Return dummy gradients for testing
317+
return new Vector<double>(_inputDim * _outputDim);
318+
}
319+
320+
public void ApplyGradients(Vector<double> gradients, double learningRate)
321+
{
322+
// Placeholder - do nothing for mock
323+
}
324+
325+
// DefaultLossFunction property
326+
public ILossFunction<double> DefaultLossFunction =>
327+
throw new InvalidOperationException("Mock model does not have a default loss function");
328+
329+
// IParameterizable implementation
330+
public Vector<double> GetParameters()
331+
{
332+
// Return dummy parameters for testing
333+
return new Vector<double>(_inputDim * _outputDim);
334+
}
335+
336+
public void SetParameters(Vector<double> parameters)
337+
{
338+
// Placeholder - do nothing for mock
339+
}
340+
341+
public int ParameterCount => _inputDim * _outputDim;
342+
343+
public IFullModel<double, Vector<double>, Vector<double>> WithParameters(Vector<double> parameters)
344+
{
345+
var copy = new MockFullModel(_inputDim, _outputDim);
346+
copy.SetParameters(parameters);
347+
return copy;
348+
}
268349
}
269350
}

0 commit comments

Comments
 (0)