Skip to content

Commit 00af3f8

Browse files
ooplesclaude
andcommitted
fix: implement SaveState/LoadState methods for regression bases
Resolves review comments on NonLinearRegressionBase.cs:1126 and DecisionTreeRegressionBase.cs:709 - Implemented SaveState by delegating to Serialize() with proper stream validation - Implemented LoadState by delegating to Deserialize() with proper stream validation - Added null checks and stream capability validation as recommended - Added empty stream validation for LoadState 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 9f303d8 commit 00af3f8

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/Regression/DecisionTreeRegressionBase.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,14 +1115,24 @@ public virtual void ApplyGradients(Vector<T> gradients, T learningRate)
11151115
/// </summary>
11161116
public virtual void SaveState(Stream stream)
11171117
{
1118-
throw new NotImplementedException("SaveState is not yet implemented for DecisionTreeRegressionBase. Consider serializing the tree structure explicitly.");
1118+
if (stream == null) throw new ArgumentNullException(nameof(stream));
1119+
if (!stream.CanWrite) throw new ArgumentException("Stream must be writable.", nameof(stream));
1120+
var data = Serialize();
1121+
stream.Write(data, 0, data.Length);
1122+
stream.Flush();
11191123
}
11201124

11211125
/// <summary>
11221126
/// Loads the model's state from a stream.
11231127
/// </summary>
11241128
public virtual void LoadState(Stream stream)
11251129
{
1126-
throw new NotImplementedException("LoadState is not yet implemented for DecisionTreeRegressionBase. Consider deserializing the tree structure explicitly.");
1130+
if (stream == null) throw new ArgumentNullException(nameof(stream));
1131+
if (!stream.CanRead) throw new ArgumentException("Stream must be readable.", nameof(stream));
1132+
using var ms = new MemoryStream();
1133+
stream.CopyTo(ms);
1134+
var data = ms.ToArray();
1135+
if (data.Length == 0) throw new InvalidOperationException("Stream contains no data.");
1136+
Deserialize(data);
11271137
}
11281138
}

src/Regression/NonLinearRegressionBase.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,14 +1114,24 @@ public virtual void ApplyGradients(Vector<T> gradients, T learningRate)
11141114
/// </summary>
11151115
public virtual void SaveState(Stream stream)
11161116
{
1117-
throw new NotImplementedException("SaveState is not yet implemented for NonLinearRegressionBase. Consider using Serialize() method to get serialized data.");
1117+
if (stream == null) throw new ArgumentNullException(nameof(stream));
1118+
if (!stream.CanWrite) throw new ArgumentException("Stream must be writable.", nameof(stream));
1119+
var data = Serialize();
1120+
stream.Write(data, 0, data.Length);
1121+
stream.Flush();
11181122
}
11191123

11201124
/// <summary>
11211125
/// Loads the model's state from a stream.
11221126
/// </summary>
11231127
public virtual void LoadState(Stream stream)
11241128
{
1125-
throw new NotImplementedException("LoadState is not yet implemented for NonLinearRegressionBase. Consider using Deserialize() method with byte array.");
1129+
if (stream == null) throw new ArgumentNullException(nameof(stream));
1130+
if (!stream.CanRead) throw new ArgumentException("Stream must be readable.", nameof(stream));
1131+
using var ms = new MemoryStream();
1132+
stream.CopyTo(ms);
1133+
var data = ms.ToArray();
1134+
if (data.Length == 0) throw new InvalidOperationException("Stream contains no data.");
1135+
Deserialize(data);
11261136
}
11271137
}

0 commit comments

Comments
 (0)